Skip to content

Commit fc736dd

Browse files
lukebaumanncopybara-github
authored andcommitted
Patch internal JAX profiler functions (enabling jax.profiler.trace) and add a test for jax.profiler.trace.
The `jax.profiler.trace` context manager uses internal `jax._src.profiler` functions. This change ensures that these internal functions are also patched by `pathwaysutils.profiling.monkey_patch_jax` to correctly intercept profiling calls. A new test is added to verify that `with jax.profiler.trace(...)` now triggers the patched Pathways profiling functions. PiperOrigin-RevId: 845839776
1 parent 88572e4 commit fc736dd

2 files changed

Lines changed: 87 additions & 42 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,14 @@ def start_trace_patch(
286286
)
287287

288288
jax.profiler.start_trace = start_trace_patch
289+
jax._src.profiler.start_trace = start_trace_patch # pylint: disable=protected-access
289290

290291
def stop_trace_patch() -> None:
291292
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
292293
return stop_trace()
293294

294295
jax.profiler.stop_trace = stop_trace_patch
296+
jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access
295297

296298
def start_server_patch(port: int):
297299
_logger.debug(

pathwaysutils/test/profiling_test.py

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -305,53 +305,86 @@ def test_stop_server_does_nothing_if_server_exists(self):
305305
profiling.start_server(9000)
306306
profiling.stop_server() # Should not raise
307307

308-
def test_monkey_patch_jax(self):
309-
original_jax_start_trace = jax.profiler.start_trace
310-
original_jax_stop_trace = jax.profiler.stop_trace
311-
original_jax_start_server = jax.profiler.start_server
312-
original_jax_stop_server = jax.profiler.stop_server
308+
def _setup_monkey_patch(self):
309+
"""Saves originals, applies monkey patch, and sets up mocks."""
310+
targets = [
311+
(jax.profiler, "start_trace"),
312+
(jax.profiler, "stop_trace"),
313+
(jax.profiler, "start_server"),
314+
(jax.profiler, "stop_server"),
315+
(jax._src.profiler, "start_trace"),
316+
(jax._src.profiler, "stop_trace"),
317+
]
318+
original_jax_funcs = {}
319+
for module, func_name in targets:
320+
original_func = getattr(module, func_name)
321+
original_jax_funcs[(module, func_name)] = original_func
322+
self.addCleanup(setattr, module, func_name, original_func)
313323

314324
profiling.monkey_patch_jax()
315325

316-
self.assertNotEqual(jax.profiler.start_trace, original_jax_start_trace)
317-
self.assertNotEqual(jax.profiler.stop_trace, original_jax_stop_trace)
318-
self.assertNotEqual(jax.profiler.start_server, original_jax_start_server)
319-
self.assertNotEqual(jax.profiler.stop_server, original_jax_stop_server)
320-
321-
with mock.patch.object(
322-
profiling, "start_trace", autospec=True
323-
) as mock_pw_start_trace:
324-
jax.profiler.start_trace("gs://bucket/dir")
325-
mock_pw_start_trace.assert_called_once_with(
326-
"gs://bucket/dir",
327-
create_perfetto_link=False,
328-
create_perfetto_trace=False,
329-
profiler_options=None,
326+
for module, func_name in targets:
327+
self.assertNotEqual(
328+
getattr(module, func_name),
329+
original_jax_funcs[(module, func_name)],
330330
)
331331

332-
with mock.patch.object(
333-
profiling, "stop_trace", autospec=True
334-
) as mock_pw_stop_trace:
335-
jax.profiler.stop_trace()
336-
mock_pw_stop_trace.assert_called_once()
337-
338-
with mock.patch.object(
339-
profiling, "start_server", autospec=True
340-
) as mock_pw_start_server:
341-
jax.profiler.start_server(1234)
342-
mock_pw_start_server.assert_called_once_with(1234)
343-
344-
with mock.patch.object(
345-
profiling, "stop_server", autospec=True
346-
) as mock_pw_stop_server:
347-
jax.profiler.stop_server()
348-
mock_pw_stop_server.assert_called_once()
349-
350-
# Restore original jax functions
351-
jax.profiler.start_trace = original_jax_start_trace
352-
jax.profiler.stop_trace = original_jax_stop_trace
353-
jax.profiler.start_server = original_jax_start_server
354-
jax.profiler.stop_server = original_jax_stop_server
332+
mocks = {
333+
"start_trace": self.enter_context(
334+
mock.patch.object(profiling, "start_trace", autospec=True)
335+
),
336+
"stop_trace": self.enter_context(
337+
mock.patch.object(profiling, "stop_trace", autospec=True)
338+
),
339+
"start_server": self.enter_context(
340+
mock.patch.object(profiling, "start_server", autospec=True)
341+
),
342+
"stop_server": self.enter_context(
343+
mock.patch.object(profiling, "stop_server", autospec=True)
344+
),
345+
}
346+
return mocks
347+
348+
@parameterized.named_parameters(
349+
dict(testcase_name="jax_profiler", profiler_module=jax.profiler),
350+
dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler),
351+
)
352+
def test_monkey_patched_start_trace(self, profiler_module):
353+
mocks = self._setup_monkey_patch()
354+
355+
profiler_module.start_trace("gs://bucket/dir")
356+
357+
mocks["start_trace"].assert_called_once_with(
358+
"gs://bucket/dir",
359+
create_perfetto_link=False,
360+
create_perfetto_trace=False,
361+
profiler_options=None,
362+
)
363+
364+
@parameterized.named_parameters(
365+
dict(testcase_name="jax_profiler", profiler_module=jax.profiler),
366+
dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler),
367+
)
368+
def test_monkey_patched_stop_trace(self, profiler_module):
369+
mocks = self._setup_monkey_patch()
370+
371+
profiler_module.stop_trace()
372+
373+
mocks["stop_trace"].assert_called_once()
374+
375+
def test_monkey_patched_start_server(self):
376+
mocks = self._setup_monkey_patch()
377+
378+
jax.profiler.start_server(1234)
379+
380+
mocks["start_server"].assert_called_once_with(1234)
381+
382+
def test_monkey_patched_stop_server(self):
383+
mocks = self._setup_monkey_patch()
384+
385+
jax.profiler.stop_server()
386+
387+
mocks["stop_server"].assert_called_once()
355388

356389
def test_create_profile_request_no_options(self):
357390
request = profiling._create_profile_request("gs://bucket/dir")
@@ -389,6 +422,7 @@ def test_create_profile_request_no_options(self):
389422
},
390423
},),
391424
)
425+
392426
def test_start_pathways_trace_from_profile_request(self, profile_request):
393427
profiling._start_pathways_trace_from_profile_request(profile_request)
394428

@@ -412,6 +446,15 @@ def test_original_stop_trace_called_on_stop_failure(self):
412446
profiling.stop_trace()
413447
self.mock_original_stop_trace.assert_called_once()
414448

449+
def test_jax_profiler_trace_calls_patched_functions(self):
450+
mocks = self._setup_monkey_patch()
451+
452+
with jax.profiler.trace("gs://bucket/dir"):
453+
pass
454+
455+
mocks["start_trace"].assert_called_once()
456+
mocks["stop_trace"].assert_called_once()
457+
415458

416459
if __name__ == "__main__":
417460
absltest.main()

0 commit comments

Comments
 (0)