File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+
15+ """Initialization for any Engine implementation."""
16+
17+ import jax
18+
19+
20+ def register_proxy_backend ():
21+ """Try to register IFRT Proxy backend if it's needed."""
22+ # TODO: find a more elegant way to do it.
23+ if jax .config .jax_platforms and "proxy" in jax .config .jax_platforms :
24+ try :
25+ jax .lib .xla_bridge .get_backend ("proxy" )
26+ except RuntimeError :
27+ try :
28+ from jaxlib .xla_extension import ifrt_proxy # pylint: disable=import-outside-toplevel
29+
30+ jax_backend_target = jax .config .read ("jax_backend_target" )
31+ jax ._src .xla_bridge .register_backend_factory ( # pylint: disable=protected-access
32+ "proxy" ,
33+ lambda : ifrt_proxy .get_client (
34+ jax_backend_target ,
35+ ifrt_proxy .ClientConnectionOptions (),
36+ ),
37+ priority = - 1 ,
38+ )
39+ print (f"Registered IFRT Proxy with address { jax_backend_target } " )
40+ except ImportError as e :
41+ print (f"Failed to register IFRT Proxy, exception: { e } " )
42+ pass
43+
44+
45+ register_proxy_backend ()
You can’t perform that action at this time.
0 commit comments