@@ -90,39 +90,22 @@ def execute_tasks_h5(
9090 if task_dict is not None and "shutdown" in task_dict and task_dict ["shutdown" ]:
9191 if task_dict ["wait" ] and wait :
9292 while len (memory_dict ) > 0 :
93- memory_dict = {
94- key : _check_task_output (
95- task_key = key ,
96- future_obj = value ,
97- cache_directory = cache_dir_dict [key ],
98- )
99- for key , value in memory_dict .items ()
100- if not value .done ()
101- }
93+ memory_dict = _refresh_memory_dict (
94+ memory_dict = memory_dict ,
95+ cache_dir_dict = cache_dir_dict ,
96+ )
10297 if not task_dict ["cancel_futures" ] and wait :
103- if (
104- terminate_function is not None
105- and terminate_function == terminate_subprocess
106- ):
107- for task in process_dict .values ():
108- terminate_function (task = task )
109- elif terminate_function is not None :
110- for queue_id in process_dict .values ():
111- terminate_function (
112- queue_id = queue_id ,
113- config_directory = pysqa_config_directory ,
114- backend = backend ,
115- )
98+ _cancel_processes (
99+ terminate_function = terminate_function ,
100+ process_dict = process_dict ,
101+ pysqa_config_directory = pysqa_config_directory ,
102+ backend = backend ,
103+ )
116104 else :
117- memory_dict = {
118- key : _check_task_output (
119- task_key = key ,
120- future_obj = value ,
121- cache_directory = cache_dir_dict [key ],
122- )
123- for key , value in memory_dict .items ()
124- if not value .done ()
125- }
105+ memory_dict = _refresh_memory_dict (
106+ memory_dict = memory_dict ,
107+ cache_dir_dict = cache_dir_dict ,
108+ )
126109 for value in memory_dict .values ():
127110 if not value .done ():
128111 value .cancel ()
@@ -193,15 +176,10 @@ def execute_tasks_h5(
193176 cache_dir_dict [task_key ] = cache_directory
194177 future_queue .task_done ()
195178 else :
196- memory_dict = {
197- key : _check_task_output (
198- task_key = key ,
199- future_obj = value ,
200- cache_directory = cache_dir_dict [key ],
201- )
202- for key , value in memory_dict .items ()
203- if not value .done ()
204- }
179+ memory_dict = _refresh_memory_dict (
180+ memory_dict = memory_dict ,
181+ cache_dir_dict = cache_dir_dict ,
182+ )
205183
206184
207185def _check_task_output (
@@ -275,3 +253,52 @@ def _convert_args_and_kwargs(
275253 else :
276254 task_kwargs [key ] = arg
277255 return task_args , task_kwargs , future_wait_key_lst
256+
257+
258+ def _refresh_memory_dict (memory_dict : dict , cache_dir_dict : dict ) -> dict :
259+ """
260+ Refresh memory dictionary
261+
262+ Args:
263+ memory_dict (dict): dictionary with task keys and future objects
264+ cache_dir_dict (dict): dictionary with task keys and cache directories
265+
266+ Returns:
267+ dict: Updated memory dictionary
268+ """
269+ return {
270+ key : _check_task_output (
271+ task_key = key ,
272+ future_obj = value ,
273+ cache_directory = cache_dir_dict [key ],
274+ )
275+ for key , value in memory_dict .items ()
276+ if not value .done ()
277+ }
278+
279+
280+ def _cancel_processes (
281+ process_dict : dict ,
282+ terminate_function : Optional [Callable ] = None ,
283+ pysqa_config_directory : Optional [str ] = None ,
284+ backend : Optional [str ] = None ,
285+ ):
286+ """
287+ Cancel processes
288+
289+ Args:
290+ process_dict (dict): dictionary with task keys and process reference.
291+ terminate_function (callable): The function to terminate the tasks.
292+ pysqa_config_directory (str): path to the pysqa config directory (only for pysqa based backend).
293+ backend (str): name of the backend used to spawn tasks.
294+ """
295+ if terminate_function is not None and terminate_function == terminate_subprocess :
296+ for task in process_dict .values ():
297+ terminate_function (task = task )
298+ elif terminate_function is not None and backend is not None :
299+ for queue_id in process_dict .values ():
300+ terminate_function (
301+ queue_id = queue_id ,
302+ config_directory = pysqa_config_directory ,
303+ backend = backend ,
304+ )
0 commit comments