From d73ebad669b3c8c993d319d4cda54ff73184ca2b Mon Sep 17 00:00:00 2001 From: Zhihao Shan Date: Fri, 26 Apr 2024 13:33:48 -0700 Subject: [PATCH 1/2] Register IFRT proxy backend when proxy is defined in the jax_platforms --- jetstream/engine/__init__.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/jetstream/engine/__init__.py b/jetstream/engine/__init__.py index 6d5e14bc..750d2c34 100644 --- a/jetstream/engine/__init__.py +++ b/jetstream/engine/__init__.py @@ -11,3 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import jax + + +def register_proxy_backend(): + """Try to register IFRT Proxy backend if it's needed.""" + # TODO: find a more elegant way to do it. + if jax.config.jax_platforms and 'proxy' in jax.config.jax_platforms: + try: + jax.lib.xla_bridge.get_backend("proxy") + except RuntimeError: + try: + from jaxlib.xla_extension import ifrt_proxy + jax_backend_target = jax.config.read("jax_backend_target") + jax._src.xla_bridge.register_backend_factory( # pylint: disable=protected-access + "proxy", + lambda: ifrt_proxy.get_client( + jax_backend_target, + ifrt_proxy.ClientConnectionOptions(), + ), + priority=-1, + ) + print(f"Registered IFRT Proxy backend with address {jax_backend_target}") + except Exception as e: + print(f"Failed to register IFRT Proxy, exception: {e}") + pass + +register_proxy_backend() From b66fc8c70f16db9a7f14544ea440bf41cecd0227 Mon Sep 17 00:00:00 2001 From: Zhihao Shan Date: Fri, 26 Apr 2024 14:11:34 -0700 Subject: [PATCH 2/2] fix lint --- jetstream/engine/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/jetstream/engine/__init__.py b/jetstream/engine/__init__.py index 750d2c34..2ed0398d 100644 --- a/jetstream/engine/__init__.py +++ b/jetstream/engine/__init__.py @@ -12,18 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Initialization for any Engine implementation.""" + import jax def register_proxy_backend(): """Try to register IFRT Proxy backend if it's needed.""" # TODO: find a more elegant way to do it. - if jax.config.jax_platforms and 'proxy' in jax.config.jax_platforms: + if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms: try: jax.lib.xla_bridge.get_backend("proxy") except RuntimeError: try: - from jaxlib.xla_extension import ifrt_proxy + from jaxlib.xla_extension import ifrt_proxy # pylint: disable=import-outside-toplevel + jax_backend_target = jax.config.read("jax_backend_target") jax._src.xla_bridge.register_backend_factory( # pylint: disable=protected-access "proxy", @@ -33,9 +36,10 @@ def register_proxy_backend(): ), priority=-1, ) - print(f"Registered IFRT Proxy backend with address {jax_backend_target}") - except Exception as e: + print(f"Registered IFRT Proxy with address {jax_backend_target}") + except ImportError as e: print(f"Failed to register IFRT Proxy, exception: {e}") pass + register_proxy_backend()