-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvert_to_onnx.py
More file actions
62 lines (51 loc) · 2.07 KB
/
convert_to_onnx.py
File metadata and controls
62 lines (51 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# convert_to_onnx.py
import os
from pathlib import Path

import mlflow
import mlflow.onnx
import onnx
import torch
from mlflow.models.signature import infer_signature
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers.onnx import FeaturesManager, export

# from azure.ai.ml import MLClient
# from azure.identity import DefaultAzureCredential
# ------------------------------------------------------------------
# USER SETTINGS – edit these before running
# ------------------------------------------------------------------
MODEL_ID = "ibm-esa-geospatial/TerraMind-1.0-base" # <-- your HF repo
# NOTE(review): TerraMind is a geospatial foundation model; main() loads it with
# AutoModelForSequenceClassification, which assumes the repo exposes a
# sequence-classification head — confirm, or swap in the appropriate Auto class.
OUTPUT_DIR = Path("onnx_model")  # local directory that will hold model.onnx
OPSET = 14  # ONNX opset version passed to the exporter
# ------------------------------------------------------------------
def main():
    """Export MODEL_ID to ONNX and log/register the artifact with MLflow.

    Steps: download tokenizer + model from the Hugging Face Hub, resolve the
    architecture's OnnxConfig, export to OUTPUT_DIR/model.onnx at opset OPSET,
    infer an MLflow signature from a dummy forward pass, then log and register
    the ONNX model under the name "terramind-onnx-model".

    Raises:
        ValueError: if the model architecture does not support the
            "sequence-classification" ONNX feature.
    """
    # Start an MLflow run so params, the artifact, and the registration are tracked.
    with mlflow.start_run() as run:
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

        # Log reproducibility parameters.
        mlflow.log_param("model_id", MODEL_ID)
        mlflow.log_param("opset", OPSET)

        # transformers.onnx.export() requires an OnnxConfig instance for
        # `config=` — passing the string "default" raises at call time.
        # Resolve the proper config class for this architecture/feature.
        _, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(
            model, feature="sequence-classification"
        )
        onnx_config = onnx_config_cls(model.config)

        # Export to ONNX.
        onnx_path = OUTPUT_DIR / "model.onnx"
        export(
            preprocessor=tokenizer,
            model=model,
            config=onnx_config,
            opset=OPSET,
            output=onnx_path,
        )
        print(f"ONNX model saved to {onnx_path}")

        # Run one dummy forward pass for signature inference. infer_signature
        # needs plain numpy arrays, not torch tensors / transformers ModelOutput,
        # so convert both sides explicitly.
        dummy_input = tokenizer("This is a test sentence.", return_tensors="pt")
        with torch.no_grad():
            dummy_output = model(**dummy_input)
        sig_inputs = {k: v.numpy() for k, v in dummy_input.items()}
        sig_outputs = {"logits": dummy_output.logits.numpy()}
        signature = infer_signature(sig_inputs, sig_outputs)

        # mlflow.onnx.log_model expects a loaded onnx.ModelProto, not a path string.
        mlflow.onnx.log_model(
            onnx_model=onnx.load(str(onnx_path)),
            artifact_path="onnx_model",
            signature=signature,
            registered_model_name="terramind-onnx-model",  # register for easy deployment
        )
        print(
            f"MLflow Model logged and registered as 'terramind-onnx-model' "
            f"in run {run.info.run_id}"
        )


if __name__ == "__main__":
    main()