@@ -538,7 +538,35 @@ def test_constant_of_shape(self):
538538 got = ref .run (None , {"X" : np .array ([2 , 3 ], dtype = np .int64 )})[0 ]
539539 self .assertEqualArray (np .zeros ((2 , 3 ), dtype = np .float32 ), got )
540540
541+ def test_constant_of_shape_value (self ):
542+ onx = (
543+ start ()
544+ .vin ("X" , TensorProto .INT64 , shape = [None , None ])
545+ .ConstantOfShape (value = np .array ([1 ], dtype = np .float32 ))
546+ .vout (shape = [])
547+ .to_onnx ()
548+ )
549+ ref = ReferenceEvaluator (onx )
550+ got = ref .run (None , {"X" : np .array ([2 , 3 ], dtype = np .int64 )})[0 ]
551+ self .assertEqualArray (np .ones ((2 , 3 ), dtype = np .float32 ), got )
552+
553+ def test_slice (self ):
554+ onx = (
555+ start (opset = 18 , ir_version = 9 )
556+ .cst (np .array ([1 ], dtype = np .int64 ), name = "one" )
557+ .cst (np .array ([2 ], dtype = np .int64 ), name = "two" )
558+ .vin ("X" , TensorProto .INT64 , shape = [None , None ])
559+ .ConstantOfShape (value = np .array ([1 ], dtype = np .float32 ))
560+ .rename ("CX" )
561+ .bring ("CX" , "one" , "two" , "one" )
562+ .Slice ()
563+ .vout (shape = [])
564+ .to_onnx ()
565+ )
566+ ref = ReferenceEvaluator (onx )
567+ got = ref .run (None , {"X" : np .array ([2 , 3 ], dtype = np .int64 )})[0 ]
568+ self .assertEqualArray (np .ones ((2 , 1 ), dtype = np .float32 ), got )
569+
541570
542571if __name__ == "__main__" :
543- TestLightApi ().test_add ()
544572 unittest .main (verbosity = 2 )
0 commit comments