@@ -1196,6 +1196,112 @@ def forward(self, a, *args, **kwargs):
11961196 )
11971197 torch .export .export (model , args , kwargs = kwargs , dynamic_shapes = ds )
11981198
1199+ def test_remove_inputs_kwargs (self ):
1200+ """Test that remove_inputs removes a kwarg from the observer info."""
1201+
1202+ class Model (torch .nn .Module ):
1203+ def forward (self , x , y , z = None ):
1204+ r = x + y
1205+ if z is not None :
1206+ r += z
1207+ return r
1208+
1209+ inputs = [
1210+ dict (x = torch .randn ((5 , 6 )), y = torch .randn ((1 , 6 )), z = torch .randn ((5 , 6 ))),
1211+ dict (x = torch .randn ((7 , 7 )), y = torch .randn ((1 , 7 )), z = torch .randn ((7 , 7 ))),
1212+ dict (x = torch .randn ((7 , 8 )), y = torch .randn ((1 , 8 )), z = torch .randn ((7 , 8 ))),
1213+ ]
1214+
1215+ model = Model ()
1216+ observer = InputObserver ()
1217+ with observer (model ):
1218+ for kwargs in inputs :
1219+ model (** kwargs )
1220+ self .assertEqual (len (observer .info ), 3 )
1221+
1222+ cst = torch .export .Dim .DYNAMIC
1223+ ds = observer .infer_dynamic_shapes ()
1224+ self .assertIn ("z" , ds )
1225+ self .assertIn ("x" , ds )
1226+ self .assertIn ("y" , ds )
1227+
1228+ # Remove z input
1229+ observer .remove_inputs (["z" ])
1230+
1231+ ds_after = observer .infer_dynamic_shapes ()
1232+ self .assertNotIn ("z" , ds_after )
1233+ self .assertIn ("x" , ds_after )
1234+ self .assertIn ("y" , ds_after )
1235+ self .assertEqual (dict (x = {0 : cst , 1 : cst }, y = {1 : cst }), ds_after )
1236+
1237+ args_after = observer .infer_arguments ()
1238+ self .assertIsInstance (args_after , dict )
1239+ self .assertNotIn ("z" , args_after )
1240+ self .assertIn ("x" , args_after )
1241+ self .assertIn ("y" , args_after )
1242+
1243+ def test_remove_inputs_multiple_kwargs (self ):
1244+ """Test that remove_inputs removes multiple kwargs at once."""
1245+
1246+ class Model (torch .nn .Module ):
1247+ def forward (self , x , y , z = None , w = None ):
1248+ r = x + y
1249+ if z is not None :
1250+ r += z
1251+ if w is not None :
1252+ r += w
1253+ return r
1254+
1255+ inputs = [
1256+ dict (
1257+ x = torch .randn ((5 , 6 )),
1258+ y = torch .randn ((1 , 6 )),
1259+ z = torch .randn ((5 , 6 )),
1260+ w = torch .randn ((1 , 6 )),
1261+ ),
1262+ dict (
1263+ x = torch .randn ((6 , 7 )),
1264+ y = torch .randn ((1 , 7 )),
1265+ z = torch .randn ((6 , 7 )),
1266+ w = torch .randn ((1 , 7 )),
1267+ ),
1268+ dict (
1269+ x = torch .randn ((7 , 8 )),
1270+ y = torch .randn ((1 , 8 )),
1271+ z = torch .randn ((7 , 8 )),
1272+ w = torch .randn ((1 , 8 )),
1273+ ),
1274+ ]
1275+
1276+ model = Model ()
1277+ observer = InputObserver ()
1278+ with observer (model ):
1279+ for kwargs in inputs :
1280+ model (** kwargs )
1281+ self .assertEqual (len (observer .info ), 3 )
1282+
1283+ cst = torch .export .Dim .DYNAMIC
1284+ ds = observer .infer_dynamic_shapes ()
1285+ self .assertIn ("z" , ds )
1286+ self .assertIn ("w" , ds )
1287+
1288+ # Remove z and w inputs
1289+ observer .remove_inputs (["z" , "w" ])
1290+
1291+ ds_after = observer .infer_dynamic_shapes ()
1292+ self .assertNotIn ("z" , ds_after )
1293+ self .assertNotIn ("w" , ds_after )
1294+ self .assertIn ("x" , ds_after )
1295+ self .assertIn ("y" , ds_after )
1296+ self .assertEqual (dict (x = {0 : cst , 1 : cst }, y = {1 : cst }), ds_after )
1297+
1298+ args_after = observer .infer_arguments ()
1299+ self .assertIsInstance (args_after , dict )
1300+ self .assertNotIn ("z" , args_after )
1301+ self .assertNotIn ("w" , args_after )
1302+ self .assertIn ("x" , args_after )
1303+ self .assertIn ("y" , args_after )
1304+
11991305
12001306if __name__ == "__main__" :
12011307 unittest .main (verbosity = 2 )
0 commit comments