@@ -11,35 +11,16 @@ def get_mask(self, x, selection_rate=0.0):
1111 num_keep = num_patches - num_mask
1212 noise_random = torch .rand (batch_size , num_patches , device = device )
1313 ids_shuffle = torch .argsort (noise_random , dim = 1 )
14- ids_restore = torch .argsort (ids_shuffle , dim = 1 )
1514 ids_keep = ids_shuffle [:, :num_keep ]
16- ids_mask = ids_shuffle [:, num_keep :]
17- mask = torch .ones ((batch_size , num_patches ), device = device , dtype = torch .bool )
18- mask .scatter_ (1 , ids_keep , False )
19- return {
20- 'mask' : mask ,
21- 'ids_keep' : ids_keep ,
22- 'ids_mask' : ids_mask ,
23- 'ids_shuffle' : ids_shuffle ,
24- 'ids_restore' : ids_restore
25- }
15+ return ids_keep
2616
27- def start_route (self , x , mask_info ):
28- ids_shuffle = mask_info ['ids_shuffle' ]
29- num_keep = mask_info ['ids_keep' ].size (1 )
30- x_shuffled = x .gather (1 , ids_shuffle .unsqueeze (- 1 ).expand (- 1 , - 1 , x .size (2 )))
31- masked_x = x_shuffled [:, :num_keep , :]
32- return masked_x
17+ def start_route (self , x , ids_keep ):
18+ x_masked = x .gather (1 , ids_keep .unsqueeze (- 1 ).expand (- 1 , - 1 , x .size (2 )))
19+ return x_masked
3320
34- def end_route (self , masked_x , mask_info , original_x ):
35- batch_size , num_patches = mask_info ['mask' ].shape
36- num_keep = masked_x .size (1 )
37- dim = masked_x .size (2 )
38- device = masked_x .device
39- ids_restore = mask_info ['ids_restore' ]
40- x_unshuffled = torch .empty ((batch_size , num_patches , dim ), device = device )
41- x_unshuffled [:, :num_keep , :] = masked_x
42- x_shuffled = original_x .gather (1 , mask_info ['ids_shuffle' ].unsqueeze (- 1 ).expand (- 1 , - 1 , dim ))
43- x_unshuffled [:, num_keep :, :] = x_shuffled [:, num_keep :, :]
44- x_unmasked = x_unshuffled .gather (1 , ids_restore .unsqueeze (- 1 ).expand (- 1 , - 1 , dim ))
45- return x_unmasked
21+ def end_route (self , masked_x , ids_keep , original_x ):
22+ # (jerry) scatter is out-of-place, so this is safe
23+ x_unmasked = original_x .scatter (
24+ 1 , ids_keep .unsqueeze (- 1 ).expand (- 1 , - 1 , original_x .size (2 )), masked_x
25+ )
26+ return x_unmasked
0 commit comments