Skip to content

Commit a1437f2

Browse files
committed
Use regular torch wheels index for xpu instead of nightly
1 parent e822f59 commit a1437f2

2 files changed

Lines changed: 5 additions & 16 deletions

File tree

tests/test_installer.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,25 +72,23 @@ def test_xpu_platform_windows_with_torch_only(monkeypatch):
7272
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
7373
packages = ["torch"]
7474
result = get_install_commands("xpu", packages)
75-
expected_url = "https://download.pytorch.org/whl/test/xpu"
75+
expected_url = "https://download.pytorch.org/whl/xpu"
7676
assert result == [packages + ["--index-url", expected_url]]
7777

7878

79-
def test_xpu_platform_windows_with_torchvision(monkeypatch, capsys):
79+
def test_xpu_platform_windows_with_torchvision(monkeypatch):
8080
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
8181
packages = ["torch", "torchvision"]
8282
result = get_install_commands("xpu", packages)
83-
expected_url = "https://download.pytorch.org/whl/nightly/xpu"
83+
expected_url = "https://download.pytorch.org/whl/xpu"
8484
assert result == [packages + ["--index-url", expected_url]]
85-
captured = capsys.readouterr()
86-
assert "[WARNING]" in captured.out
8785

8886

8987
def test_xpu_platform_linux(monkeypatch):
9088
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
9189
packages = ["torch", "torchvision"]
9290
result = get_install_commands("xpu", packages)
93-
expected_url = "https://download.pytorch.org/whl/test/xpu"
91+
expected_url = "https://download.pytorch.org/whl/xpu"
9492
triton_index_url = "https://download.pytorch.org/whl"
9593
assert result == [
9694
packages + ["--index-url", expected_url],

torchruntime/installer.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def get_install_commands(torch_platform, packages):
4141
ValueError: If an unsupported platform is provided.
4242
4343
Notes:
44-
- For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds.
4544
- For "directml", the "torch-directml" package is returned as part of the installation commands.
4645
- For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands.
4746
- For Windows CUDA, the function also installs "triton-windows" (for torch.compile and Triton kernels).
@@ -69,15 +68,7 @@ def get_install_commands(torch_platform, packages):
6968
return cmds
7069

7170
if torch_platform == "xpu":
72-
if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages):
73-
print(
74-
f"[WARNING] The preview build of 'xpu' on Windows currently only supports torch, not torchvision/torchaudio. "
75-
f"torchruntime will instead use the nightly build, to get the 'xpu' version of torchaudio and torchvision as well. "
76-
f"Please contact torchruntime if this is no longer accurate: {CONTACT_LINK}"
77-
)
78-
index_url = f"https://download.pytorch.org/whl/nightly/{torch_platform}"
79-
else:
80-
index_url = f"https://download.pytorch.org/whl/test/{torch_platform}"
71+
index_url = f"https://download.pytorch.org/whl/{torch_platform}"
8172

8273
cmds = [packages + ["--index-url", index_url]]
8374
if os_name == "Linux":

0 commit comments

Comments
 (0)