@@ -27,6 +27,8 @@ def __init__(
2727 cores : int = 1 ,
2828 threads_per_core : int = 1 ,
2929 gpus_per_core : int = 0 ,
30+ num_nodes : Optional [int ] = None ,
31+ exclusive : bool = False ,
3032 openmpi_oversubscribe : bool = False ,
3133 slurm_cmd_args : Optional [list [str ]] = None ,
3234 ):
@@ -38,6 +40,8 @@ def __init__(
3840 cores (int, optional): The number of cores to use. Defaults to 1.
3941 threads_per_core (int, optional): The number of threads per core. Defaults to 1.
4042 gpus_per_core (int, optional): The number of GPUs per core. Defaults to 0.
43+ num_nodes (int, optional): The number of compute nodes to use for executing the task. Defaults to None.
44+ exclusive (bool): Whether to exclusively reserve the compute nodes, or allow sharing compute nodes. Defaults to False.
4145 openmpi_oversubscribe (bool, optional): Whether to oversubscribe the cores. Defaults to False.
4246 slurm_cmd_args (list[str], optional): Additional command line arguments. Defaults to [].
4347 """
@@ -49,6 +53,8 @@ def __init__(
4953 )
5054 self ._gpus_per_core = gpus_per_core
5155 self ._slurm_cmd_args = slurm_cmd_args
56+ self ._num_nodes = num_nodes
57+ self ._exclusive = exclusive
5258
5359 def generate_command (self , command_lst : list [str ]) -> list [str ]:
5460 """
@@ -65,6 +71,8 @@ def generate_command(self, command_lst: list[str]) -> list[str]:
6571 cwd = self ._cwd ,
6672 threads_per_core = self ._threads_per_core ,
6773 gpus_per_core = self ._gpus_per_core ,
74+ num_nodes = self ._num_nodes ,
75+ exclusive = self ._exclusive ,
6876 openmpi_oversubscribe = self ._openmpi_oversubscribe ,
6977 slurm_cmd_args = self ._slurm_cmd_args ,
7078 )
@@ -78,6 +86,8 @@ def generate_slurm_command(
7886 cwd : Optional [str ],
7987 threads_per_core : int = 1 ,
8088 gpus_per_core : int = 0 ,
89+ num_nodes : Optional [int ] = None ,
90+ exclusive : bool = False ,
8191 openmpi_oversubscribe : bool = False ,
8292 slurm_cmd_args : Optional [list [str ]] = None ,
8393) -> list [str ]:
@@ -89,6 +99,8 @@ def generate_slurm_command(
8999 cwd (str): The current working directory.
90100 threads_per_core (int, optional): The number of threads per core. Defaults to 1.
91101 gpus_per_core (int, optional): The number of GPUs per core. Defaults to 0.
102+ num_nodes (int, optional): The number of compute nodes to use for executing the task. Defaults to None.
103+ exclusive (bool): Whether to exclusively reserve the compute nodes, or allow sharing compute nodes. Defaults to False.
92104 openmpi_oversubscribe (bool, optional): Whether to oversubscribe the cores. Defaults to False.
93105 slurm_cmd_args (list[str], optional): Additional command line arguments. Defaults to [].
94106
@@ -98,10 +110,14 @@ def generate_slurm_command(
98110 command_prepend_lst = [SLURM_COMMAND , "-n" , str (cores )]
99111 if cwd is not None :
100112 command_prepend_lst += ["-D" , cwd ]
113+ if num_nodes is not None :
114+ command_prepend_lst += ["-N" , str (num_nodes )]
101115 if threads_per_core > 1 :
102116 command_prepend_lst += ["--cpus-per-task=" + str (threads_per_core )]
103117 if gpus_per_core > 0 :
104118 command_prepend_lst += ["--gpus-per-task=" + str (gpus_per_core )]
119+ if exclusive :
120+ command_prepend_lst += ["--exact" ]
105121 if openmpi_oversubscribe :
106122 command_prepend_lst += ["--oversubscribe" ]
107123 if slurm_cmd_args is not None and len (slurm_cmd_args ) > 0 :
0 commit comments