@@ -156,6 +156,7 @@ def __init__(
156156 optimization_options : Optional [OptimizationOptions ] = None ,
157157 args : Optional [List [Any ]] = None ,
158158 verbose : int = 0 ,
159+ ir_version : Optional [int ] = None ,
159160 ):
160161 self .optimization_options = optimization_options or OptimizationOptions ()
161162 self .as_function = as_function
@@ -170,6 +171,7 @@ def __init__(
170171 if isinstance (target_opset_or_existing_proto , int )
171172 else target_opset_or_existing_proto
172173 )
174+ self .ir_version = ir_version
173175 self .nodes = []
174176 self .initializers_dict = {}
175177 self .inputs = []
@@ -186,6 +188,7 @@ def __init__(
186188 ), "input_names must be empty if the input is an existing model."
187189 proto = target_opset_or_existing_proto
188190 self .opsets = {d .domain : d .version for d in proto .opset_import }
191+ self .ir_version = ir_version or target_opset_or_existing_proto .ir_version
189192 self .nodes = list (proto .graph .node )
190193 self .initializers_dict = {i .name : i for i in proto .graph .initializer }
191194 self .initializers_dict .update (
@@ -674,6 +677,8 @@ def to_onnx(
674677 if self .verbose :
675678 print ("[GraphBuilder] onh.make_model" )
676679 model = oh .make_model (graph , opset_imports = opsets )
680+ if self .ir_version :
681+ model .ir_version = self .ir_version
677682 return model
678683
679684 def _check_order_node (self , ind : int , node : NodeProto , existing : Set [str ]):
0 commit comments