diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 53b18ab67..dd8a44869 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -20,7 +20,7 @@ jobs: uses: actions/setup-python@v4 with: cache: 'pip' - python-version: "3.9" + python-version: "3.10" - name: Install tox-gh run: pip install tox-gh diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 2d905ec5b..40bd04f61 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -32,6 +32,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install tox-gh run: pip install tox-gh + - name: Install quantus with test dependencies + run: pip install .[tests] - name: Setup test environment run: tox run --notest - name: Test with PyTest diff --git a/pyproject.toml b/pyproject.toml index 2e829fa10..8d25ec57a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,10 +51,15 @@ dynamic = ["version"] # [project.optional-dependencies] torch = [ + # Mac/Windows or unknown platform fallback "torch>=1.13.1; sys_platform != 'linux'", "torchvision>=0.15.1; sys_platform != 'linux'", - "torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version <= '3.10'", - "torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version <= '3.10'", + + # Linux + Python 3.8–3.10 → use torch 1.x + "torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version < '3.11'", + "torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version < '3.11'", + + # Linux + Python 3.11+ → use torch 2.x "torch>=2.0.0; sys_platform == 'linux' and python_version >= '3.11'", "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'", ] @@ -62,7 +67,8 @@ tensorflow = [ # 2.16 is shipped without keras "tensorflow<2.16.0", # keras V3 broke everything - "keras<3" + "keras<3", + "numpy<2" ] captum = [ "quantus[torch]", @@ -78,10 +84,10 @@ zennit = [ ] transformers = [ "quantus[torch]", - 
"transformers>=4.38.2", + "transformers>=4.38.2", #"transformers<4.38.0" ] full = [ - "quantus[captum,tf-explain,zennit,transformers]" + "quantus[captum,tf-explain,zennit,transformers,torch]" ] tests = [ "coverage>=7.2.3", @@ -91,6 +97,7 @@ tests = [ "pytest-lazy-fixture>=0.6.3", "pytest-mock==3.10.0", "pytest_xdist", + "numpy<2", "quantus[full]" ] diff --git a/quantus/__init__.py b/quantus/__init__.py index 407f5c733..2f6ecd0d3 100644 --- a/quantus/__init__.py +++ b/quantus/__init__.py @@ -5,7 +5,7 @@ # Quantus project URL: . # Set the correct version. -__version__ = "0.5.3" +__version__ = "0.6.0" # Expose quantus.evaluate to the user. from quantus.evaluation import evaluate diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index d9197cbd7..544d926f0 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -375,6 +375,6 @@ def evaluate_batch( y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] results.append(y_pred_perturb) - results = np.stack(results, 1, dtype=np.float64) + results = np.stack(results, 1).astype(np.float64) return results diff --git "a/tests/assets/Icon\r" "b/tests/assets/Icon\r" deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/conftest.py b/tests/conftest.py index bde949ac9..11b7362f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -239,23 +239,36 @@ def load_mnist_model_softmax(): @pytest.fixture(scope="session", autouse=False) def load_hf_distilbert_sequence_classifier(): - """ - TODO - """ - DISTILBERT_BASE = "distilbert-base-uncased" - model = AutoModelForSequenceClassification.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/") - return model + try: + import torch + except ImportError: + pytest.skip("Skipping because torch is not available.") + + try: + from transformers import AutoModelForSequenceClassification + DISTILBERT_BASE = "distilbert-base-uncased" + return 
AutoModelForSequenceClassification.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/") + except Exception as e: + pytest.skip(f"Skipping because model loading failed: {e}") + @pytest.fixture(scope="session", autouse=False) def dummy_hf_tokenizer(): - """ - TODO - """ - DISTILBERT_BASE = "distilbert-base-uncased" - REFERENCE_TEXT = "The quick brown fox jumps over the lazy dog" - tokenizer = AutoTokenizer.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/", clean_up_tokenization_spaces=True) - return tokenizer(REFERENCE_TEXT, return_tensors="pt") + try: + import torch + except ImportError: + pytest.skip("Skipping because torch is not available.") + + try: + from transformers import AutoTokenizer + DISTILBERT_BASE = "distilbert-base-uncased" + REFERENCE_TEXT = "The quick brown fox jumps over the lazy dog" + tokenizer = AutoTokenizer.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/", clean_up_tokenization_spaces=True) + return tokenizer(REFERENCE_TEXT, return_tensors="pt") + except Exception as e: + pytest.skip(f"Skipping because tokenizer loading failed: {e}") + @pytest.fixture(scope="session", autouse=True) diff --git a/tests/functions/test_pytorch_model.py b/tests/functions/test_pytorch_model.py index caa64f501..590817609 100644 --- a/tests/functions/test_pytorch_model.py +++ b/tests/functions/test_pytorch_model.py @@ -11,6 +11,12 @@ from scipy.special import softmax from quantus.helpers.model.pytorch_model import PyTorchModel +def torch_available(): + try: + import torch + return True + except ImportError: + return False @pytest.fixture def mock_input_torch_array(): diff --git a/tox.ini b/tox.ini index aa678a730..341a2100d 100644 --- a/tox.ini +++ b/tox.ini @@ -12,12 +12,17 @@ deps = .[tests] pass_env = TF_XLA_FLAGS +commands_pre = + python -c "import sys, platform; print(f'Python: {sys.version_info.major}.{sys.version_info.minor}, Platform: {platform.system()}')" + python -c "import sys; import subprocess; v=sys.version_info; linux=sys.platform=='linux'; 
needs=(linux and (v.major==3 and v.minor in (9,10))); subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torch<2.0.0']) if needs else None" commands = pytest -s -v {posargs} + [testenv:coverage] description = Run the tests with coverage -base_python = py310 +base_python = + deps = {[testenv]deps} coverage @@ -27,14 +32,14 @@ commands = [testenv:build] description = Build environment -base_python = py310 deps = - . build twine commands = - python3 -m build . - python3 -m twine check ./dist/* --strict + python -c "import shutil; shutil.rmtree('dist', ignore_errors=True)" + python -m build . + python -m twine check ./dist/* --strict + [testenv:lint] description = Check the code style