Skip to content

Commit e52d249

Browse files
committed
Merge remote-tracking branch 'determined/eleuther_dai' into igor
2 parents 3987139 + 15b4b10 commit e52d249

3 files changed

Lines changed: 67 additions & 14 deletions

File tree

deepspeed/launcher/launch.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
"""
33
DeepSpeed launcher, this is similar to torch.distributed.launch but supports
44
additional features such as arbitrary gpu exclusion.
5-
65
deepspeed.launcher.launch is intended to be run on a single worker node and
76
will spawn several worker sub-processes depending on how many devices/ranks
87
are on the worker.
@@ -53,6 +52,22 @@ def parse_args():
5352
parser.add_argument("--detect_nvlink_pairs", action="store_true",
5453
help="autodetects nvlink pairs and remaps CUDA_VISIBLE_DEVICES along the fastest connections")
5554

55+
parser.add_argument("--module",
56+
action="store_true",
57+
help="Change each process to interpret the launch "
58+
"script as a Python module, executing with the same "
59+
"behavior as 'python -m'.")
60+
61+
parser.add_argument("--no_python",
62+
action="store_true",
63+
help="Skip prepending the training script with "
64+
"'python' - just execute it directly.")
65+
66+
parser.add_argument("--no_local_rank",
67+
action="store_true",
68+
help="Do not pass local_rank as an argument when calling "
69+
"the user's training script.")
70+
5671
# positional
5772
parser.add_argument("training_script",
5873
type=str,
@@ -117,6 +132,9 @@ def main():
117132
current_env["MASTER_ADDR"] = args.master_addr
118133
current_env["MASTER_PORT"] = str(args.master_port)
119134
current_env["WORLD_SIZE"] = str(dist_world_size)
135+
current_env["CROSS_RANK"] = str(args.node_rank)
136+
current_env["CROSS_SIZE"] = str(args.nnodes)
137+
current_env["LOCAL_SIZE"] = str(num_local_procs)
120138

121139
processes = []
122140
for local_rank in range(0, num_local_procs):
@@ -126,12 +144,20 @@ def main():
126144
current_env["LOCAL_RANK"] = str(local_rank)
127145

128146
# spawn the processes
129-
cmd = [
130-
sys.executable,
131-
"-u",
132-
args.training_script,
133-
"--local_rank={}".format(local_rank)
134-
] + args.training_script_args
147+
cmd = []
148+
if not args.no_python:
149+
cmd = [sys.executable, "-u"]
150+
if args.module:
151+
cmd.append("-m")
152+
else:
153+
if args.module:
154+
raise ValueError("Don't use both the '--no_python' flag"
155+
" and the '--module' flag at the same time.")
156+
cmd.append(args.training_script)
157+
# A user may not want to pass local_rank as a keyword arg so we make this optional.
158+
if not args.no_local_rank:
159+
cmd.append(f"--local_rank={local_rank}")
160+
cmd += args.training_script_args
135161

136162
sig_names = {2: "SIGINT", 15: "SIGTERM"}
137163
last_return_code = None

deepspeed/launcher/multinode_runner.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import subprocess
55
import warnings
66
from abc import ABC, abstractmethod
7+
from shlex import quote
78

89
from ..utils import logger
910
from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
@@ -56,7 +57,7 @@ def get_cmd(self, environment, active_resources):
5657

5758
exports = ""
5859
for key, val in self.exports.items():
59-
exports += "export {}={}; ".format(key, val)
60+
exports += f"export {key}={quote(val)}; "
6061

6162
deepspeed_launch = [
6263
exports,
@@ -71,7 +72,13 @@ def get_cmd(self, environment, active_resources):
7172
"--master_port={}".format(self.args.master_port)
7273
]
7374
if self.args.detect_nvlink_pairs:
74-
deepspeed_launch += ["--detect_nvlink_pairs"]
75+
deepspeed_launch.append("--detect_nvlink_pairs")
76+
if self.args.no_python:
77+
deepspeed_launch.append("--no_python")
78+
if self.args.module:
79+
deepspeed_launch.append("--module")
80+
if self.args.no_local_rank:
81+
deepspeed_launch.append("--no_local_rank")
7582

7683
return pdsh_cmd_args + deepspeed_launch + [self.user_script
7784
] + self.user_arguments

deepspeed/launcher/runner.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ..utils import logger
2525

2626
DLTS_HOSTFILE = "/job/hostfile"
27-
EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", 'UCX']
27+
EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", "UCX"]
2828
DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
2929
DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']
3030
PDSH_MAX_FAN_OUT = 1024
@@ -104,6 +104,22 @@ def parse_args(args=None):
104104
help="(optional) pass launcher specific arguments as a "
105105
"single quoted argument.")
106106

107+
parser.add_argument("--module",
108+
action="store_true",
109+
help="Change each process to interpret the launch "
110+
"script as a Python module, executing with the same "
111+
"behavior as 'python -m'.")
112+
113+
parser.add_argument("--no_python",
114+
action="store_true",
115+
help="Skip prepending the training script with "
116+
"'python' - just execute it directly.")
117+
118+
parser.add_argument("--no_local_rank",
119+
action="store_true",
120+
help="Do not pass local_rank as an argument when calling "
121+
"the user's training script.")
122+
107123
parser.add_argument("--force_multi",
108124
action="store_true",
109125
help="Force multi-node launcher mode, helps in cases where user "
@@ -154,11 +170,9 @@ def fetch_hostfile(hostfile_path):
154170

155171
def parse_resource_filter(host_info, include_str="", exclude_str=""):
156172
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
157-
158173
String format is NODE_SPEC[@NODE_SPEC ...], where
159174
NODE_SPEC = NAME[:SLOT[,SLOT ...]].
160175
If :SLOT is omitted, include/exclude all slots on that host.
161-
162176
Examples:
163177
include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
164178
slots [0, 2] on worker-1.
@@ -326,7 +340,13 @@ def main(args=None):
326340
"--master_port={}".format(args.master_port)
327341
]
328342
if args.detect_nvlink_pairs:
329-
deepspeed_launch += ["--detect_nvlink_pairs"]
343+
deepspeed_launch.append("--detect_nvlink_pairs")
344+
if args.no_python:
345+
deepspeed_launch.append("--no_python")
346+
if args.module:
347+
deepspeed_launch.append("--module")
348+
if args.no_local_rank:
349+
deepspeed_launch.append("--no_local_rank")
330350
cmd = deepspeed_launch + [args.user_script] + args.user_args
331351
else:
332352
args.launcher = args.launcher.lower()
@@ -358,7 +378,7 @@ def main(args=None):
358378
if os.path.isfile(environ_file):
359379
with open(environ_file, 'r') as fd:
360380
for var in fd.readlines():
361-
key, val = var.split('=')
381+
key, val = var.split('=', maxsplit=1)
362382
runner.add_export(key, val)
363383

364384
cmd = runner.get_cmd(env, active_resources)

0 commit comments

Comments
 (0)