Skip to content

Commit e1f0c88

Browse files
Authored by xadupre (with Copilot)
add remove_inputs to InputObserver (#422)
* add remove_inputs to InputObserver * changes * mypy * Add unit tests for `remove_inputs` in `InputObserver` (#423) * Initial plan * Add unit tests for remove_inputs in InputObserver Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> * a few fixes * Update onnx_diagnostic/investigate/input_observer.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix * style * update doc * fix documentation --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 53e8b87 commit e1f0c88

10 files changed

Lines changed: 321 additions & 9 deletions

File tree

CHANGELOGS.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Change Logs
44
0.9.3
55
+++++
66

7+
* :pr:`422`: add remove_inputs to InputObserver
8+
* :pr:`421`: fix a few patches for MoE
9+
710
0.9.2
811
+++++
912

_doc/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def linkcode_resolve(domain, info):
212212

213213
if int(os.environ.get("UNITTEST_GOING", "0")):
214214
sphinx_gallery_conf["ignore_pattern"] = (
215-
".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)).*"
215+
".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)|(export_with_modelbuilder)).*"
216216
)
217217
elif pv.Version(torch.__version__) < pv.Version("2.8"):
218218
sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*"
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""
2+
.. _l-plot-export-model-builder:
3+
4+
Export with ModelBuilder
5+
========================
6+
7+
"""
8+
9+
import sys
10+
import os
11+
import pandas
12+
import torch
13+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
14+
from onnx_diagnostic import doc
15+
from onnx_diagnostic.investigate.input_observer import InputObserver
16+
from onnx_diagnostic.helpers.rt_helper import onnx_generate
17+
from onnx_diagnostic.torch_export_patches import (
18+
register_additional_serialization_functions,
19+
torch_export_patches,
20+
)
21+
from onnx_diagnostic.export.api import to_onnx
22+
23+
24+
def generate_text(
25+
prompt,
26+
model,
27+
tokenizer,
28+
max_length=50,
29+
temperature=0.01,
30+
top_k=50,
31+
top_p=0.95,
32+
do_sample=True,
33+
device="cpu",
34+
):
35+
inputs = tokenizer(prompt, return_tensors="pt")
36+
input_ids = inputs["input_ids"].to(device)
37+
attention_mask = inputs["attention_mask"].to(device)
38+
39+
outputs = model.generate(
40+
input_ids=input_ids,
41+
attention_mask=attention_mask,
42+
max_length=max_length,
43+
temperature=temperature,
44+
top_k=top_k,
45+
top_p=top_p,
46+
do_sample=do_sample,
47+
)
48+
49+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
50+
return generated_text
51+
52+
53+
# %%
54+
# filename for the model
55+
MODEL_NAME = sys.argv[1] if sys.argv and len(sys.argv) > 1 else "arnir0/Tiny-LLM"
56+
cache_dir = "dump_modelbuilder"
57+
os.makedirs(cache_dir, exist_ok=True)
58+
name = MODEL_NAME.replace("/", "_")
59+
filename = os.path.join(cache_dir, f"plot_export_with_modelbuilder_{name}.onnx")
60+
61+
62+
# %%
63+
# Creating the model
64+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
66+
if not os.path.exists(filename):
67+
print(f"-- creating... on {device} into {filename!r}")
68+
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
69+
model = model.to(device)
70+
config = model.config
71+
else:
72+
config = AutoConfig.from_pretrained(MODEL_NAME)
73+
74+
75+
# %%
76+
# Capturing inputs/outputs to infer dynamic shapes and arguments
77+
print("-- capturing...")
78+
prompt = "Continue: it rains, what should I do?"
79+
if not os.path.exists(filename):
80+
observer = InputObserver()
81+
with register_additional_serialization_functions(patch_transformers=True), observer(model):
82+
generate_text(prompt, model, tokenizer, device=device)
83+
84+
85+
# %%
86+
# Exporting.
87+
if not os.path.exists(filename):
88+
print("-- exporting...")
89+
observer.remove_inputs(["cache_position", "logits_to_keep", "position_ids"])
90+
ds = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
91+
kwargs = observer.infer_arguments()
92+
93+
with torch_export_patches(patch_transformers=True):
94+
to_onnx(
95+
model,
96+
filename=filename,
97+
kwargs=kwargs,
98+
dynamic_shapes=ds,
99+
exporter="modelbuilder",
100+
)
101+
102+
data = observer.check_discrepancies(filename, progress_bar=True)
103+
print(pandas.DataFrame(data))
104+
105+
# %%
106+
# ONNX Prompt
107+
# +++++++++++
108+
print("-- ONNX prompts...")
109+
inputs = tokenizer(prompt, return_tensors="pt")
110+
input_ids = inputs["input_ids"].to(device)
111+
attention_mask = inputs["attention_mask"].to(device)
112+
113+
onnx_tokens = onnx_generate(
114+
filename,
115+
input_ids=input_ids,
116+
attention_mask=attention_mask,
117+
eos_token_id=config.eos_token_id,
118+
max_new_tokens=50,
119+
)
120+
onnx_generated_text = tokenizer.decode(onnx_tokens, skip_special_tokens=True)
121+
122+
print("-----------------")
123+
print("\n".join(onnx_generated_text))
124+
print("-----------------")
125+
126+
# %%
127+
if os.stat(filename).st_size < 2**14:
128+
    doc.save_fig(doc.plot_dot(filename), f"plot_export_with_modelbuilder_{name}.png", dpi=400)

_doc/technical/plot_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
tokenizer = AutoTokenizer.from_pretrained(model_id)
4848
else:
4949
model_id = "microsoft/phi-1_5"
50-
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
50+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
5151
tokenizer = AutoTokenizer.from_pretrained(model_id)
5252
config = get_pretrained_config(model_id)
5353
task = task = task_from_id(model_id)

_unittests/ut_investigate/test_input_observer.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,112 @@ def forward(self, a, *args, **kwargs):
11961196
)
11971197
torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds)
11981198

1199+
def test_remove_inputs_kwargs(self):
1200+
"""Test that remove_inputs removes a kwarg from the observer info."""
1201+
1202+
class Model(torch.nn.Module):
1203+
def forward(self, x, y, z=None):
1204+
r = x + y
1205+
if z is not None:
1206+
r += z
1207+
return r
1208+
1209+
inputs = [
1210+
dict(x=torch.randn((5, 6)), y=torch.randn((1, 6)), z=torch.randn((5, 6))),
1211+
dict(x=torch.randn((7, 7)), y=torch.randn((1, 7)), z=torch.randn((7, 7))),
1212+
dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)), z=torch.randn((7, 8))),
1213+
]
1214+
1215+
model = Model()
1216+
observer = InputObserver()
1217+
with observer(model):
1218+
for kwargs in inputs:
1219+
model(**kwargs)
1220+
self.assertEqual(len(observer.info), 3)
1221+
1222+
cst = torch.export.Dim.DYNAMIC
1223+
ds = observer.infer_dynamic_shapes()
1224+
self.assertIn("z", ds)
1225+
self.assertIn("x", ds)
1226+
self.assertIn("y", ds)
1227+
1228+
# Remove z input
1229+
observer.remove_inputs(["z"])
1230+
1231+
ds_after = observer.infer_dynamic_shapes()
1232+
self.assertNotIn("z", ds_after)
1233+
self.assertIn("x", ds_after)
1234+
self.assertIn("y", ds_after)
1235+
self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after)
1236+
1237+
args_after = observer.infer_arguments()
1238+
self.assertIsInstance(args_after, dict)
1239+
self.assertNotIn("z", args_after)
1240+
self.assertIn("x", args_after)
1241+
self.assertIn("y", args_after)
1242+
1243+
def test_remove_inputs_multiple_kwargs(self):
1244+
"""Test that remove_inputs removes multiple kwargs at once."""
1245+
1246+
class Model(torch.nn.Module):
1247+
def forward(self, x, y, z=None, w=None):
1248+
r = x + y
1249+
if z is not None:
1250+
r += z
1251+
if w is not None:
1252+
r += w
1253+
return r
1254+
1255+
inputs = [
1256+
dict(
1257+
x=torch.randn((5, 6)),
1258+
y=torch.randn((1, 6)),
1259+
z=torch.randn((5, 6)),
1260+
w=torch.randn((1, 6)),
1261+
),
1262+
dict(
1263+
x=torch.randn((6, 7)),
1264+
y=torch.randn((1, 7)),
1265+
z=torch.randn((6, 7)),
1266+
w=torch.randn((1, 7)),
1267+
),
1268+
dict(
1269+
x=torch.randn((7, 8)),
1270+
y=torch.randn((1, 8)),
1271+
z=torch.randn((7, 8)),
1272+
w=torch.randn((1, 8)),
1273+
),
1274+
]
1275+
1276+
model = Model()
1277+
observer = InputObserver()
1278+
with observer(model):
1279+
for kwargs in inputs:
1280+
model(**kwargs)
1281+
self.assertEqual(len(observer.info), 3)
1282+
1283+
cst = torch.export.Dim.DYNAMIC
1284+
ds = observer.infer_dynamic_shapes()
1285+
self.assertIn("z", ds)
1286+
self.assertIn("w", ds)
1287+
1288+
# Remove z and w inputs
1289+
observer.remove_inputs(["z", "w"])
1290+
1291+
ds_after = observer.infer_dynamic_shapes()
1292+
self.assertNotIn("z", ds_after)
1293+
self.assertNotIn("w", ds_after)
1294+
self.assertIn("x", ds_after)
1295+
self.assertIn("y", ds_after)
1296+
self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after)
1297+
1298+
args_after = observer.infer_arguments()
1299+
self.assertIsInstance(args_after, dict)
1300+
self.assertNotIn("z", args_after)
1301+
self.assertNotIn("w", args_after)
1302+
self.assertIn("x", args_after)
1303+
self.assertIn("y", args_after)
1304+
11991305

12001306
if __name__ == "__main__":
12011307
unittest.main(verbosity=2)

_unittests/ut_tasks/try_tasks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def test_text_generation_phi4_moe(self):
263263
model = AutoModelForCausalLM.from_pretrained(
264264
model_path,
265265
device_map="cuda",
266-
torch_dtype="auto",
266+
dtype="auto",
267267
trust_remote_code=True,
268268
# if you do not use Ampere or later GPUs, change attention to "eager"
269269
# _attn_implementation='flash_attention_2',
@@ -352,7 +352,7 @@ def test_imagetext2text_generation_idefics(self):
352352
mid = "HuggingFaceM4/tiny-random-idefics"
353353
processor = AutoProcessor.from_pretrained(mid)
354354
model = IdeficsForVisionText2Text.from_pretrained(
355-
mid, torch_dtype=torch.bfloat16, device_map="auto"
355+
mid, dtype=torch.bfloat16, device_map="auto"
356356
)
357357

358358
prompt = [
@@ -699,7 +699,7 @@ def test_falcon_mamba_dev(self):
699699
"text-generation",
700700
model=model,
701701
tokenizer=tokenizer,
702-
torch_dtype=torch.bfloat16,
702+
dtype=torch.bfloat16,
703703
trust_remote_code=True,
704704
device_map="auto",
705705
)
@@ -736,7 +736,7 @@ def test_falcon_mamba_7b(self):
736736
"text-generation",
737737
model=model,
738738
tokenizer=tokenizer,
739-
torch_dtype=torch.bfloat16,
739+
dtype=torch.bfloat16,
740740
trust_remote_code=True,
741741
device_map="auto",
742742
)
@@ -802,7 +802,7 @@ def test_text_to_image(self):
802802
from diffusers import StableDiffusionPipeline
803803

804804
model_id = "diffusers/tiny-torch-full-checker" # "stabilityai/stable-diffusion-2"
805-
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(
805+
pipe = StableDiffusionPipeline.from_pretrained(model_id, dtype=torch.float16).to(
806806
"cuda"
807807
)
808808

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def add_test_methods(cls):
8585

8686
# transformers
8787

88+
if not reason and name in {"plot_export_with_modelbuilder.py"}:
89+
reason = "downloading"
90+
8891
if (
8992
not reason
9093
and name in {"plot_export_tiny_llm.py", "plot_export_tiny_llm_patched.py"}

onnx_diagnostic/ci_models/export_phi4_mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ def main(
794794
model_id,
795795
config=config,
796796
trust_remote_code=True,
797-
torch_dtype=torch_dtype,
797+
dtype=torch_dtype,
798798
device_map=device,
799799
attn_implementation="sdpa",
800800
).eval()

onnx_diagnostic/export/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from collections.abc import Mapping, Iterable
66
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
7+
import onnx
78
import torch
89
from .dynamic_shapes import ModelInputs
910
from .onnx_plug import EagerDirectReplacementWithOnnx
@@ -312,10 +313,14 @@ def to_onnx(
312313
mod,
313314
precision=str(first_float[0].dtype).split(".")[-1],
314315
execution_provider="cuda" if first.is_cuda else "cpu",
315-
cache_dir=os.path.dirname(filename),
316+
cache_dir=os.path.dirname(filename) or ".",
316317
**(exporter_kwargs or {}),
317318
)
318319
save_model_builder(onx, os.path.dirname(filename))
320+
temp_filename = os.path.join(os.path.dirname(filename), "model.onnx")
321+
# renaming
322+
onx = onnx.load(temp_filename, load_external_data=True)
323+
onnx.save(onx, filename, save_as_external_data=True)
319324
return onx
320325

321326
raise ValueError(f"Unknown exporter={exporter!r}")

0 commit comments

Comments
 (0)