|
24 | 24 | from ..utils import logger |
25 | 25 |
|
26 | 26 | DLTS_HOSTFILE = "/job/hostfile" |
27 | | -EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", 'UCX'] |
| 27 | +EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", "UCX"] |
28 | 28 | DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env" |
29 | 29 | DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.'] |
30 | 30 | PDSH_MAX_FAN_OUT = 1024 |
@@ -104,6 +104,22 @@ def parse_args(args=None): |
104 | 104 | help="(optional) pass launcher specific arguments as a " |
105 | 105 | "single quoted argument.") |
106 | 106 |
|
| 107 | + parser.add_argument("--module", |
| 108 | + action="store_true", |
| 109 | + help="Change each process to interpret the launch " |
| 110 | + "script as a Python module, executing with the same " |
| 111 | + "behavior as 'python -m'.") |
| 112 | + |
| 113 | + parser.add_argument("--no_python", |
| 114 | + action="store_true", |
| 115 | + help="Skip prepending the training script with " |
| 116 | + "'python' - just execute it directly.") |
| 117 | + |
| 118 | + parser.add_argument("--no_local_rank", |
| 119 | + action="store_true", |
| 120 | + help="Do not pass local_rank as an argument when calling " |
| 121 | + "the user's training script.") |
| 122 | + |
107 | 123 | parser.add_argument("--force_multi", |
108 | 124 | action="store_true", |
109 | 125 | help="Force multi-node launcher mode, helps in cases where user " |
@@ -154,11 +170,9 @@ def fetch_hostfile(hostfile_path): |
154 | 170 |
|
155 | 171 | def parse_resource_filter(host_info, include_str="", exclude_str=""): |
156 | 172 | '''Parse an inclusion or exclusion string and filter a hostfile dictionary. |
157 | | -
|
158 | 173 | String format is NODE_SPEC[@NODE_SPEC ...], where |
159 | 174 | NODE_SPEC = NAME[:SLOT[,SLOT ...]]. |
160 | 175 | If :SLOT is omitted, include/exclude all slots on that host. |
161 | | -
|
162 | 176 | Examples: |
163 | 177 | include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and |
164 | 178 | slots [0, 2] on worker-1. |
@@ -326,7 +340,13 @@ def main(args=None): |
326 | 340 | "--master_port={}".format(args.master_port) |
327 | 341 | ] |
328 | 342 | if args.detect_nvlink_pairs: |
329 | | - deepspeed_launch += ["--detect_nvlink_pairs"] |
| 343 | + deepspeed_launch.append("--detect_nvlink_pairs") |
| 344 | + if args.no_python: |
| 345 | + deepspeed_launch.append("--no_python") |
| 346 | + if args.module: |
| 347 | + deepspeed_launch.append("--module") |
| 348 | + if args.no_local_rank: |
| 349 | + deepspeed_launch.append("--no_local_rank") |
330 | 350 | cmd = deepspeed_launch + [args.user_script] + args.user_args |
331 | 351 | else: |
332 | 352 | args.launcher = args.launcher.lower() |
@@ -358,7 +378,7 @@ def main(args=None): |
358 | 378 | if os.path.isfile(environ_file): |
359 | 379 | with open(environ_file, 'r') as fd: |
360 | 380 | for var in fd.readlines(): |
361 | | - key, val = var.split('=') |
| 381 | + key, val = var.split('=', maxsplit=1) |
362 | 382 | runner.add_export(key, val) |
363 | 383 |
|
364 | 384 | cmd = runner.get_cmd(env, active_resources) |
|
0 commit comments