@@ -305,53 +305,86 @@ def test_stop_server_does_nothing_if_server_exists(self):
305305 profiling .start_server (9000 )
306306 profiling .stop_server () # Should not raise
307307
308- def test_monkey_patch_jax (self ):
309- original_jax_start_trace = jax .profiler .start_trace
310- original_jax_stop_trace = jax .profiler .stop_trace
311- original_jax_start_server = jax .profiler .start_server
312- original_jax_stop_server = jax .profiler .stop_server
308+ def _setup_monkey_patch (self ):
309+ """Saves originals, applies monkey patch, and sets up mocks."""
310+ targets = [
311+ (jax .profiler , "start_trace" ),
312+ (jax .profiler , "stop_trace" ),
313+ (jax .profiler , "start_server" ),
314+ (jax .profiler , "stop_server" ),
315+ (jax ._src .profiler , "start_trace" ),
316+ (jax ._src .profiler , "stop_trace" ),
317+ ]
318+ original_jax_funcs = {}
319+ for module , func_name in targets :
320+ original_func = getattr (module , func_name )
321+ original_jax_funcs [(module , func_name )] = original_func
322+ self .addCleanup (setattr , module , func_name , original_func )
313323
314324 profiling .monkey_patch_jax ()
315325
316- self .assertNotEqual (jax .profiler .start_trace , original_jax_start_trace )
317- self .assertNotEqual (jax .profiler .stop_trace , original_jax_stop_trace )
318- self .assertNotEqual (jax .profiler .start_server , original_jax_start_server )
319- self .assertNotEqual (jax .profiler .stop_server , original_jax_stop_server )
320-
321- with mock .patch .object (
322- profiling , "start_trace" , autospec = True
323- ) as mock_pw_start_trace :
324- jax .profiler .start_trace ("gs://bucket/dir" )
325- mock_pw_start_trace .assert_called_once_with (
326- "gs://bucket/dir" ,
327- create_perfetto_link = False ,
328- create_perfetto_trace = False ,
329- profiler_options = None ,
326+ for module , func_name in targets :
327+ self .assertNotEqual (
328+ getattr (module , func_name ),
329+ original_jax_funcs [(module , func_name )],
330330 )
331331
332- with mock .patch .object (
333- profiling , "stop_trace" , autospec = True
334- ) as mock_pw_stop_trace :
335- jax .profiler .stop_trace ()
336- mock_pw_stop_trace .assert_called_once ()
337-
338- with mock .patch .object (
339- profiling , "start_server" , autospec = True
340- ) as mock_pw_start_server :
341- jax .profiler .start_server (1234 )
342- mock_pw_start_server .assert_called_once_with (1234 )
343-
344- with mock .patch .object (
345- profiling , "stop_server" , autospec = True
346- ) as mock_pw_stop_server :
347- jax .profiler .stop_server ()
348- mock_pw_stop_server .assert_called_once ()
349-
350- # Restore original jax functions
351- jax .profiler .start_trace = original_jax_start_trace
352- jax .profiler .stop_trace = original_jax_stop_trace
353- jax .profiler .start_server = original_jax_start_server
354- jax .profiler .stop_server = original_jax_stop_server
332+ mocks = {
333+ "start_trace" : self .enter_context (
334+ mock .patch .object (profiling , "start_trace" , autospec = True )
335+ ),
336+ "stop_trace" : self .enter_context (
337+ mock .patch .object (profiling , "stop_trace" , autospec = True )
338+ ),
339+ "start_server" : self .enter_context (
340+ mock .patch .object (profiling , "start_server" , autospec = True )
341+ ),
342+ "stop_server" : self .enter_context (
343+ mock .patch .object (profiling , "stop_server" , autospec = True )
344+ ),
345+ }
346+ return mocks
347+
348+ @parameterized .named_parameters (
349+ dict (testcase_name = "jax_profiler" , profiler_module = jax .profiler ),
350+ dict (testcase_name = "jax_src_profiler" , profiler_module = jax ._src .profiler ),
351+ )
352+ def test_monkey_patched_start_trace (self , profiler_module ):
353+ mocks = self ._setup_monkey_patch ()
354+
355+ profiler_module .start_trace ("gs://bucket/dir" )
356+
357+ mocks ["start_trace" ].assert_called_once_with (
358+ "gs://bucket/dir" ,
359+ create_perfetto_link = False ,
360+ create_perfetto_trace = False ,
361+ profiler_options = None ,
362+ )
363+
364+ @parameterized .named_parameters (
365+ dict (testcase_name = "jax_profiler" , profiler_module = jax .profiler ),
366+ dict (testcase_name = "jax_src_profiler" , profiler_module = jax ._src .profiler ),
367+ )
368+ def test_monkey_patched_stop_trace (self , profiler_module ):
369+ mocks = self ._setup_monkey_patch ()
370+
371+ profiler_module .stop_trace ()
372+
373+ mocks ["stop_trace" ].assert_called_once ()
374+
375+ def test_monkey_patched_start_server (self ):
376+ mocks = self ._setup_monkey_patch ()
377+
378+ jax .profiler .start_server (1234 )
379+
380+ mocks ["start_server" ].assert_called_once_with (1234 )
381+
382+ def test_monkey_patched_stop_server (self ):
383+ mocks = self ._setup_monkey_patch ()
384+
385+ jax .profiler .stop_server ()
386+
387+ mocks ["stop_server" ].assert_called_once ()
355388
356389 def test_create_profile_request_no_options (self ):
357390 request = profiling ._create_profile_request ("gs://bucket/dir" )
@@ -389,6 +422,7 @@ def test_create_profile_request_no_options(self):
389422 },
390423 },),
391424 )
425+
392426 def test_start_pathways_trace_from_profile_request (self , profile_request ):
393427 profiling ._start_pathways_trace_from_profile_request (profile_request )
394428
@@ -412,6 +446,15 @@ def test_original_stop_trace_called_on_stop_failure(self):
412446 profiling .stop_trace ()
413447 self .mock_original_stop_trace .assert_called_once ()
414448
449+ def test_jax_profiler_trace_calls_patched_functions (self ):
450+ mocks = self ._setup_monkey_patch ()
451+
452+ with jax .profiler .trace ("gs://bucket/dir" ):
453+ pass
454+
455+ mocks ["start_trace" ].assert_called_once ()
456+ mocks ["stop_trace" ].assert_called_once ()
457+
415458
416459if __name__ == "__main__" :
417460 absltest .main ()
0 commit comments