Skip to content

Commit 177e0f3

Browse files
authored
Fix device type in conftest.py (#231)
* Change DEVICE to be a torch.device instead of a string * Change cuda to cuda:0
1 parent 926ad3f commit 177e0f3

2 files changed

Lines changed: 9 additions & 7 deletions

File tree

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ with maintainers before implementing major changes.
3232

3333
- If you have access to a cuda-enabled GPU, you should also check that the unit tests pass on it:
3434
```bash
35-
CUBLAS_WORKSPACE_CONFIG=:4096:8 PYTEST_TORCH_DEVICE=cuda pdm run pytest tests/unit
35+
CUBLAS_WORKSPACE_CONFIG=:4096:8 PYTEST_TORCH_DEVICE=cuda:0 pdm run pytest tests/unit
3636
```
3737

3838
- To check that the usage examples from docstrings and `.rst` files are correct, we test their

tests/unit/conftest.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
from pytest import fixture
66

77
try:
8-
DEVICE = os.environ["PYTEST_TORCH_DEVICE"]
8+
_device_str = os.environ["PYTEST_TORCH_DEVICE"]
99
except KeyError:
10-
DEVICE = "cpu" # Default to cpu if environment variable not set
10+
_device_str = "cpu" # Default to cpu if environment variable not set
1111

12-
if DEVICE != "cuda" and DEVICE != "cpu":
13-
raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {DEVICE}")
12+
if _device_str != "cuda:0" and _device_str != "cpu":
13+
raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}")
1414

15-
if DEVICE == "cuda" and not torch.cuda.is_available():
16-
raise ValueError('Requested device "cuda" but cuda is not available.')
15+
if _device_str == "cuda:0" and not torch.cuda.is_available():
16+
raise ValueError('Requested device "cuda:0" but cuda is not available.')
17+
18+
DEVICE = torch.device(_device_str)
1719

1820

1921
@fixture(autouse=True)

0 commit comments

Comments
 (0)