Skip to content

Commit 3fdacc8

Browse files
zhihaoshan-googleZhihao Shan
andauthored
Register IFRT proxy backend when proxy is defined in the jax_platforms (#63)
* Register IFRT proxy backend when proxy is defined in the jax_platforms * fix lint --------- Co-authored-by: Zhihao Shan <zhihaoshan@google.com>
1 parent 7805b5d commit 3fdacc8

1 file changed

Lines changed: 32 additions & 0 deletions

File tree

jetstream/engine/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,35 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
"""Initialization for any Engine implementation."""
16+
17+
import jax
18+
19+
20+
def register_proxy_backend():
21+
"""Try to register IFRT Proxy backend if it's needed."""
22+
# TODO: find a more elegant way to do it.
23+
if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms:
24+
try:
25+
jax.lib.xla_bridge.get_backend("proxy")
26+
except RuntimeError:
27+
try:
28+
from jaxlib.xla_extension import ifrt_proxy # pylint: disable=import-outside-toplevel
29+
30+
jax_backend_target = jax.config.read("jax_backend_target")
31+
jax._src.xla_bridge.register_backend_factory( # pylint: disable=protected-access
32+
"proxy",
33+
lambda: ifrt_proxy.get_client(
34+
jax_backend_target,
35+
ifrt_proxy.ClientConnectionOptions(),
36+
),
37+
priority=-1,
38+
)
39+
print(f"Registered IFRT Proxy with address {jax_backend_target}")
40+
except ImportError as e:
41+
print(f"Failed to register IFRT Proxy, exception: {e}")
42+
pass
43+
44+
45+
register_proxy_backend()

0 commit comments

Comments
 (0)