Skip to content

Commit 2fe5ba1

Browse files
[feature] Add support for future objects in input dictionaries (#568)
* [feature] Add support for future objects in input dictionaries * extend plotting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add additional test for plotting * fix flux test * more fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9f232ff commit 2fe5ba1

5 files changed

Lines changed: 74 additions & 13 deletions

File tree

executorlib/interactive/shared.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,8 @@ def get_result(arg: Union[list[Future], Future]) -> Any:
483483
return arg.result()
484484
elif isinstance(arg, list):
485485
return [get_result(arg=el) for el in arg]
486+
elif isinstance(arg, dict):
487+
return {k: get_result(arg=v) for k, v in arg.items()}
486488
else:
487489
return arg
488490

@@ -510,6 +512,8 @@ def find_future_in_list(lst):
510512
future_lst.append(el)
511513
elif isinstance(el, list):
512514
find_future_in_list(lst=el)
515+
elif isinstance(el, dict):
516+
find_future_in_list(lst=el.values())
513517

514518
find_future_in_list(lst=task_dict["args"])
515519
find_future_in_list(lst=task_dict["kwargs"].values())

executorlib/standalone/plot.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,28 @@ def add_element(arg, link_to, label=""):
3939
"label": label,
4040
}
4141
)
42-
elif isinstance(arg, list) and all(isinstance(a, Future) for a in arg):
43-
for a in arg:
44-
add_element(arg=a, link_to=link_to, label=label)
42+
elif isinstance(arg, list) and any(isinstance(a, Future) for a in arg):
43+
lst_no_future = [a if not isinstance(a, Future) else "$" for a in arg]
44+
node_id = len(node_lst)
45+
node_lst.append(
46+
{"name": str(lst_no_future), "id": node_id, "shape": "circle"}
47+
)
48+
edge_lst.append({"start": node_id, "end": link_to, "label": label})
49+
for i, a in enumerate(arg):
50+
if isinstance(a, Future):
51+
add_element(arg=a, link_to=node_id, label="ind: " + str(i))
52+
elif isinstance(arg, dict) and any(isinstance(a, Future) for a in arg.values()):
53+
dict_no_future = {
54+
kt: vt if not isinstance(vt, Future) else "$" for kt, vt in arg.items()
55+
}
56+
node_id = len(node_lst)
57+
node_lst.append(
58+
{"name": str(dict_no_future), "id": node_id, "shape": "circle"}
59+
)
60+
edge_lst.append({"start": node_id, "end": link_to, "label": label})
61+
for kt, vt in arg.items():
62+
if isinstance(vt, Future):
63+
add_element(arg=vt, link_to=node_id, label="key: " + kt)
4564
else:
4665
node_id = len(node_lst)
4766
node_lst.append({"name": str(arg), "id": node_id, "shape": "circle"})
@@ -92,6 +111,11 @@ def convert_arg(arg, future_hash_inverse_dict):
92111
convert_arg(arg=a, future_hash_inverse_dict=future_hash_inverse_dict)
93112
for a in arg
94113
]
114+
elif isinstance(arg, dict):
115+
return {
116+
k: convert_arg(arg=v, future_hash_inverse_dict=future_hash_inverse_dict)
117+
for k, v in arg.items()
118+
}
95119
else:
96120
return arg
97121

tests/test_dependencies_executor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def merge(lst):
3838
return sum(lst)
3939

4040

41+
def return_input_dict(input_dict):
42+
return input_dict
43+
44+
4145
def raise_error():
4246
raise RuntimeError
4347

@@ -130,6 +134,14 @@ def test_many_to_one(self):
130134
)
131135
self.assertEqual(future_sum.result(), 15)
132136

137+
def test_future_input_dict(self):
138+
with SingleNodeExecutor() as exe:
139+
fs = exe.submit(
140+
return_input_dict,
141+
input_dict={"a": exe.submit(sum, [2, 2])},
142+
)
143+
self.assertEqual(fs.result()["a"], 4)
144+
133145

134146
class TestExecutorErrors(unittest.TestCase):
135147
def test_block_allocation_false_one_worker(self):

tests/test_plot_dependency.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def merge(lst):
3939
return sum(lst)
4040

4141

42+
def return_input_dict(input_dict):
43+
return input_dict
44+
45+
4246
@unittest.skipIf(
4347
skip_graphviz_test,
4448
"graphviz is not installed, so the plot_dependency_graph tests are skipped.",
@@ -124,8 +128,25 @@ def test_many_to_one_plot(self):
124128
v: k for k, v in exe._future_hash_dict.items()
125129
},
126130
)
127-
self.assertEqual(len(nodes), 18)
128-
self.assertEqual(len(edges), 21)
131+
self.assertEqual(len(nodes), 19)
132+
self.assertEqual(len(edges), 22)
133+
134+
def test_future_input_dict(self):
135+
with SingleNodeExecutor(plot_dependency_graph=True) as exe:
136+
exe.submit(
137+
return_input_dict,
138+
input_dict={"a": exe.submit(sum, [2, 2])},
139+
)
140+
self.assertEqual(len(exe._future_hash_dict), 2)
141+
self.assertEqual(len(exe._task_hash_dict), 2)
142+
nodes, edges = generate_nodes_and_edges(
143+
task_hash_dict=exe._task_hash_dict,
144+
future_hash_inverse_dict={
145+
v: k for k, v in exe._future_hash_dict.items()
146+
},
147+
)
148+
self.assertEqual(len(nodes), 4)
149+
self.assertEqual(len(edges), 3)
129150

130151

131152
@unittest.skipIf(
@@ -197,8 +218,8 @@ def test_many_to_one_plot(self):
197218
v: k for k, v in exe._future_hash_dict.items()
198219
},
199220
)
200-
self.assertEqual(len(nodes), 18)
201-
self.assertEqual(len(edges), 21)
221+
self.assertEqual(len(nodes), 19)
222+
self.assertEqual(len(edges), 22)
202223

203224

204225
@unittest.skipIf(
@@ -266,5 +287,5 @@ def test_many_to_one_plot(self):
266287
v: k for k, v in exe._future_hash_dict.items()
267288
},
268289
)
269-
self.assertEqual(len(nodes), 18)
270-
self.assertEqual(len(edges), 21)
290+
self.assertEqual(len(nodes), 19)
291+
self.assertEqual(len(edges), 22)

tests/test_plot_dependency_flux.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def test_many_to_one_plot(self):
106106
v: k for k, v in exe._future_hash_dict.items()
107107
},
108108
)
109-
self.assertEqual(len(nodes), 18)
110-
self.assertEqual(len(edges), 21)
109+
self.assertEqual(len(nodes), 19)
110+
self.assertEqual(len(edges), 22)
111111

112112

113113
@unittest.skipIf(
@@ -175,5 +175,5 @@ def test_many_to_one_plot(self):
175175
v: k for k, v in exe._future_hash_dict.items()
176176
},
177177
)
178-
self.assertEqual(len(nodes), 18)
179-
self.assertEqual(len(edges), 21)
178+
self.assertEqual(len(nodes), 19)
179+
self.assertEqual(len(edges), 22)

0 commit comments

Comments
 (0)