Skip to content

Commit c749923

Browse files
authored
Cache: Support classes without __name__ attribute (#865)
* Cache: Support classes without __name__ attribute * fix test * test with cloudpickle * improve test coverage * fixes
1 parent 5e417d4 commit c749923

3 files changed

Lines changed: 49 additions & 1 deletion

File tree

src/executorlib/standalone/serialize.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def serialize_funct(
7575
"kwargs": fn_kwargs,
7676
}
7777
)
78-
task_key = fn.__name__ + _get_hash(binary=binary_all)
78+
task_key = _get_function_name(fn=fn) + _get_hash(binary=binary_all)
7979
data = {
8080
"fn": fn,
8181
"args": fn_args,
@@ -99,3 +99,10 @@ def _get_hash(binary: bytes) -> str:
9999
# Remove specification of jupyter kernel from hash to be deterministic
100100
binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary)
101101
return str(hashlib.md5(binary_no_ipykernel).hexdigest())
102+
103+
104+
def _get_function_name(fn: Callable) -> str:
105+
if hasattr(fn, "__name__"):
106+
return fn.__name__
107+
else:
108+
return str(fn).split()[0].split(".")[-1]

tests/test_singlenodeexecutor_cache.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ def get_error(a):
1717
raise ValueError(a)
1818

1919

20+
class AddClass:
21+
def __call__(self, a, b):
22+
return a+b
23+
24+
2025
@unittest.skipIf(
2126
skip_h5py_test, "h5py is not installed, so the h5io tests are skipped."
2227
)
@@ -34,6 +39,21 @@ def test_cache_data(self):
3439
sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst)
3540
)
3641

42+
def test_cache_data_class(self):
43+
cache_directory = os.path.abspath("executorlib_cache")
44+
with SingleNodeExecutor(cache_directory=cache_directory) as exe:
45+
self.assertTrue(exe)
46+
cloudpickle_register(ind=1)
47+
add_instance = AddClass()
48+
future_lst = [exe.submit(add_instance, a=i, b=i) for i in range(1, 4)]
49+
result_lst = [f.result() for f in future_lst]
50+
51+
cache_lst = get_cache_data(cache_directory=cache_directory)
52+
self.assertEqual(sum([c["output"] for c in cache_lst]), sum(result_lst))
53+
self.assertEqual(
54+
sum([sum([c["input_kwargs"]["a"], c["input_kwargs"]["b"]]) for c in cache_lst]), sum(result_lst)
55+
)
56+
3757
def test_cache_key(self):
3858
cache_directory = os.path.abspath("executorlib_cache")
3959
with SingleNodeExecutor(cache_directory=cache_directory) as exe:

tests/test_standalone_serialize.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
from executorlib.standalone.serialize import _get_function_name
3+
4+
5+
def my_function(a: int, b: int) -> int:
6+
return a + b
7+
8+
9+
class MyClass:
10+
def __call__(self, a: int, b: int) -> int:
11+
return a + b
12+
13+
14+
class TestSerialization(unittest.TestCase):
15+
def test_serialization(self):
16+
fn = _get_function_name(fn=my_function)
17+
self.assertEqual(fn, "my_function")
18+
fn = _get_function_name(fn=MyClass())
19+
self.assertEqual(fn, "MyClass")
20+
fn = _get_function_name(fn=None)
21+
self.assertEqual(fn, "None")

0 commit comments

Comments
 (0)