22
33from collections .abc import Iterator , Mapping
44import contextlib
5+ import gc
56import logging
67import os
78import random
1011from typing import Any
1112
1213import jax
14+ import jax .extend .backend as jax_backend
1315import pathwaysutils
1416from pathwaysutils .experimental .shared_pathways_service import gke_utils
1517from pathwaysutils .experimental .shared_pathways_service import validators
2729_JAX_PLATFORM_PROXY = "proxy"
2830_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
2931_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
32+ _DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
3033
3134_logger = logging .getLogger (__name__ )
3235
@@ -36,6 +39,7 @@ def _deploy_pathways_proxy_server(
3639 proxy_job_name : str ,
3740 expected_instances : Mapping [Any , Any ],
3841 gcs_scratch_location : str ,
42+ proxy_server_image : str ,
3943) -> None :
4044 """Deploys the Pathways proxy pods to the GKE cluster.
4145
@@ -45,6 +49,7 @@ def _deploy_pathways_proxy_server(
4549 expected_instances: A dictionary mapping instance types to the number of
4650 instances.
4751 gcs_scratch_location: The Google Cloud Storage location to use.
52+ proxy_server_image: The image to use for the proxy server.
4853
4954 Raises:
5055 subprocess.CalledProcessError: If the kubectl command fails.
@@ -70,6 +75,7 @@ def _deploy_pathways_proxy_server(
7075 PATHWAYS_HEAD_PORT = pathways_head_port ,
7176 EXPECTED_INSTANCES = instances_str ,
7277 GCS_SCRATCH_LOCATION = gcs_scratch_location ,
78+ PROXY_SERVER_IMAGE = proxy_server_image ,
7379 )
7480
7581 _logger .info ("Deploying Pathways proxy: %s" , proxy_job_name )
@@ -89,6 +95,8 @@ class _ISCPathways:
8995 pathways_service: The service name and port of the Pathways head pod.
9096 expected_tpu_instances: A dictionary mapping TPU machine types to the number
9197 of instances.
98+ proxy_job_name: The name to use for the deployed proxy.
99+ proxy_server_image: The image to use for the proxy server.
92100 """
93101
94102 def __init__ (
@@ -99,7 +107,8 @@ def __init__(
99107 gcs_bucket : str ,
100108 pathways_service : str ,
101109 expected_tpu_instances : Mapping [Any , Any ],
102- proxy_job_name : str | None ,
110+ proxy_job_name : str ,
111+ proxy_server_image : str ,
103112 ):
104113 """Initializes the TPU manager."""
105114 self .cluster = cluster
@@ -108,13 +117,10 @@ def __init__(
108117 self .bucket = gcs_bucket
109118 self .pathways_service = pathways_service
110119 self .expected_tpu_instances = expected_tpu_instances
111- suffix = "" .join (
112- random .choices (string .ascii_lowercase + string .digits , k = 5 )
113- )
114- user = os .environ .get ("USER" , "user" )
115- self ._proxy_job_name = proxy_job_name or f"isc-proxy-{ user } -{ suffix } "
120+ self ._proxy_job_name = proxy_job_name
116121 self ._port_forward_process = None
117122 self ._proxy_port = None
123+ self .proxy_server_image = proxy_server_image
118124
119125 def __repr__ (self ):
120126 return (
@@ -133,6 +139,7 @@ def __enter__(self):
133139 proxy_job_name = self ._proxy_job_name ,
134140 expected_instances = self .expected_tpu_instances ,
135141 gcs_scratch_location = self .bucket ,
142+ proxy_server_image = self .proxy_server_image ,
136143 )
137144 # Print a link to Cloud Logging
138145 cloud_logging_link = gke_utils .get_log_link (
@@ -172,7 +179,16 @@ def __exit__(self, exc_type, exc_value, traceback):
172179
173180 def _cleanup (self ):
174181 """Cleans up resources created by the ISCPathways context."""
182+ # 1. Clear JAX caches and run garbage collection.
183+ _logger .info ("Starting Pathways proxy cleanup." )
184+ jax_backend .clear_backends ()
185+ jax .clear_caches ()
186+ gc .collect ()
187+ _logger .info ("Cleared JAX caches and ran garbage collection." )
188+
189+ # 2. Terminate the port forwarding process.
175190 if self ._port_forward_process :
191+ _logger .info ("Terminating port forwarding process..." )
176192 self ._port_forward_process .terminate ()
177193 try :
178194 self ._port_forward_process .wait (timeout = 10 )
@@ -183,19 +199,23 @@ def _cleanup(self):
183199 e ,
184200 )
185201
186- _logger .info ("Deleting Pathways proxy" )
202+ # 3. Delete the proxy GKE job.
203+ _logger .info ("Deleting Pathways proxy..." )
187204 gke_utils .delete_gke_job (self ._proxy_job_name )
205+ _logger .info ("Pathways proxy GKE job deletion complete." )
188206
189207
190208@contextlib .contextmanager
191209def connect (
192- * , cluster : str ,
210+ * ,
211+ cluster : str ,
193212 project : str ,
194213 region : str ,
195214 gcs_bucket : str ,
196215 pathways_service : str ,
197216 expected_tpu_instances : Mapping [str , int ],
198217 proxy_job_name : str | None = None ,
218+ proxy_server_image : str = _DEFAULT_PROXY_IMAGE ,
199219) -> Iterator ["_ISCPathways" ]:
200220 """Connects to a Pathways server if the cluster exists. If not, creates it.
201221
@@ -209,17 +229,26 @@ def connect(
209229 of instances. For example: {"tpuv6e:2x2": 2}
210230 proxy_job_name: The name to use for the deployed proxy. If not provided, a
211231 random name will be generated.
232+ proxy_server_image: The proxy server image to use. If not provided, a
233+ default will be used.
212234
213235 Yields:
214236 The Pathways manager.
215237 """
216238 _logger .info ("Validating Pathways service and TPU instances..." )
217239 validators .validate_pathways_service (pathways_service )
218240 validators .validate_tpu_instances (expected_tpu_instances )
241+ validators .validate_proxy_server_image (proxy_server_image )
219242 _logger .info ("Validation complete." )
220243 gke_utils .fetch_cluster_credentials (
221244 cluster_name = cluster , project_id = project , location = region
222245 )
246+ proxy_job_name = (
247+ proxy_job_name or f"isc-proxy-{ os .environ .get ('USER' , 'user' )} -{ '' .join (
248+ random .choices (string .ascii_lowercase + string .digits , k = 5 )
249+ )} "
250+ )
251+
223252 _logger .info ("Starting ISCPathways context." )
224253 with _ISCPathways (
225254 cluster = cluster ,
@@ -229,5 +258,6 @@ def connect(
229258 pathways_service = pathways_service ,
230259 expected_tpu_instances = expected_tpu_instances ,
231260 proxy_job_name = proxy_job_name ,
261+ proxy_server_image = proxy_server_image ,
232262 ) as t :
233263 yield t
0 commit comments