Skip to content

Commit 5592104

Browse files
committed
Merge remote-tracking branch 'origin/master' into parallel_runner
2 parents 0225f42 + 04176b3 commit 5592104

6 files changed

Lines changed: 16 additions & 11 deletions

File tree

doc/requirements.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ matplotlib-inline==0.1.7 ; python_version >= "3.9" and python_version < "3.15"
3434
mistune==3.1.2 ; python_version >= "3.9" and python_version < "3.15"
3535
natsort==8.4.0 ; python_version >= "3.9" and python_version < "3.15"
3636
nbclient==0.10.2 ; python_version >= "3.9" and python_version < "3.15"
37-
nbconvert==7.16.6 ; python_version >= "3.9" and python_version < "3.15"
37+
nbconvert==7.17.0 ; python_version >= "3.9" and python_version < "3.15"
3838
nbformat==5.10.4 ; python_version >= "3.9" and python_version < "3.15"
3939
nbsphinx==0.9.7 ; python_version >= "3.9" and python_version < "3.15"
4040
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.15"
@@ -49,15 +49,15 @@ prompt-toolkit==3.0.50 ; python_version >= "3.9" and python_version < "3.15"
4949
ptyprocess==0.7.0 ; python_version >= "3.9" and python_version < "3.15" and sys_platform != "win32"
5050
pure-eval==0.2.3 ; python_version >= "3.9" and python_version < "3.15"
5151
pycparser==2.22 ; python_version >= "3.9" and python_version < "3.15" and implementation_name == "pypy"
52-
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.15"
52+
pygments==2.20.0 ; python_version >= "3.9" and python_version < "3.15"
5353
pytest==8.3.5 ; python_version >= "3.9" and python_version < "3.15"
5454
python-constraint2==2.1.0 ; python_version >= "3.9" and python_version < "3.15"
5555
python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.15"
5656
pytz==2025.1 ; python_version >= "3.9" and python_version < "3.15"
5757
pywin32==308 ; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "3.15"
5858
pyzmq==26.2.1 ; python_version >= "3.9" and python_version < "3.15"
5959
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.15"
60-
requests==2.32.4 ; python_version >= "3.9" and python_version < "3.15"
60+
requests==2.33.0 ; python_version >= "3.9" and python_version < "3.15"
6161
rpds-py==0.23.1 ; python_version >= "3.9" and python_version < "3.15"
6262
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.15"
6363
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.15"
@@ -78,7 +78,7 @@ stack-data==0.6.3 ; python_version >= "3.9" and python_version < "3.15"
7878
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.15"
7979
tinycss2==1.4.0 ; python_version >= "3.9" and python_version < "3.15"
8080
tomli==2.2.1 ; python_version >= "3.9" and python_version < "3.15"
81-
tornado==6.5.1 ; python_version >= "3.9" and python_version < "3.15"
81+
tornado==6.5.5 ; python_version >= "3.9" and python_version < "3.15"
8282
traitlets==5.14.3 ; python_version >= "3.9" and python_version < "3.15"
8383
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.15"
8484
tzdata==2025.1 ; python_version >= "3.9" and python_version < "3.15"

doc/requirements_test.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,9 @@ ptyprocess==0.7.0 ; python_version >= "3.10" and python_version < "4" and (os_na
311311
pure-eval==0.2.3 ; python_version >= "3.10" and python_version < "4" \
312312
--hash=sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0 \
313313
--hash=sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42
314-
pygments==2.19.1 ; python_version >= "3.10" and python_version < "4" \
315-
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
316-
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
314+
pygments==2.20.0 ; python_version >= "3.10" and python_version < "4" \
315+
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
316+
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
317317
pyproject-hooks==1.2.0 ; python_version >= "3.10" and python_version < "4" \
318318
--hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \
319319
--hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913

kernel_tuner/accuracy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _find_bfloat16_if_available():
9696
+ "please install either the package `ml_dtypes`, `jax`, or `tensorflow`"
9797
)
9898

99-
return None
99+
return dtype
100100

101101

102102
def _to_float_dtype(x: str) -> np.dtype:

kernel_tuner/strategies/dual_annealing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""The strategy that uses the dual annealing optimization method."""
22
import scipy.optimize
3+
import numpy as np
34

45
from kernel_tuner.util import StopCriterionReached
56
from kernel_tuner.searchspace import Searchspace
@@ -16,7 +17,7 @@ def tune(searchspace: Searchspace, runner, tuning_options):
1617
method, max_fevals = common.get_options(tuning_options.strategy_options, _options)
1718

1819
#scale variables in x to make 'eps' relevant for multiple variables
19-
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=True)
20+
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=True, invalid_value=np.inf)
2021

2122
bounds, x0, _ = cost_func.get_bounds_x0_eps()
2223

kernel_tuner/strategies/simulated_annealing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def acceptance_prob(old_cost, new_cost, T):
9696
# if start pos is not valid, always move
9797
if isinstance(old_cost, ErrorConfig):
9898
res = 1.0
99-
# if we have found a valid ps before, never move to nonvalid pos
99+
# if we have found a valid pos before, never move to nonvalid pos
100100
elif isinstance(new_cost, ErrorConfig):
101101
res = 0.0
102102
# always move if new cost is better
@@ -108,7 +108,7 @@ def acceptance_prob(old_cost, new_cost, T):
108108
abs_diff = old_cost - new_cost
109109

110110
# relative to abs(old_cost), as the cost might be negative
111-
rel_diff = abs_diff / np.abs(old_cost)
111+
rel_diff = abs_diff / (np.abs(old_cost) if old_cost != 0.0 else 1e-20)
112112

113113
# exponential decay
114114
res = np.exp(rel_diff / T)

kernel_tuner/utils/nvcuda.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def cuda_error_check(error):
5656
if error != nvrtc.nvrtcResult.NVRTC_SUCCESS:
5757
_, desc = nvrtc.nvrtcGetErrorString(error)
5858
raise RuntimeError(f"NVRTC error: {desc.decode()}")
59+
elif isinstance(error, tuple) and len(error) > 0:
60+
cuda_error_check(error[0])
61+
else:
62+
raise RuntimeError(f"unknown error type returned by CUDA: {error!r} (type: {type(error).__name__})")
5963

6064

6165
def to_valid_nvrtc_gpu_arch_cc(compute_capability: str) -> str:

0 commit comments

Comments
 (0)