Skip to content

Commit 8586a74

Browse files
authored
Merge pull request #9 from jxiong21029/simple-routing
Simplify router implementation
2 parents b1f5cc8 + 43a15dc commit 8586a74

2 files changed

Lines changed: 13 additions & 32 deletions

File tree

dit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -
243243
for idx, block in enumerate(self.blocks):
244244
if use_routing and idx == self.routes[route_count]['start_layer_idx']:
245245
x_D_last = x.clone()
246-
mask_info = self.router.get_mask(x, mask_ratio=self.routes[route_count]['selection_ratio'] if overwrite_selection_ratio is None else overwrite_selection_ratio)
247-
x = self.router.start_route(x, mask_info)
246+
ids_keep = self.router.get_mask(x, selection_rate=self.routes[route_count]['selection_ratio'] if overwrite_selection_ratio is None else overwrite_selection_ratio)
247+
x = self.router.start_route(x, ids_keep)
248248

249249
if fp32_next:
250250
with torch.amp.autocast(device_type="cuda", enabled=False):
@@ -254,7 +254,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -
254254
x = block(x, c)
255255

256256
if use_routing and idx == self.routes[route_count]['end_layer_idx']:
257-
x = self.router.end_route(x, mask_info, original_x=x_D_last)
257+
x = self.router.end_route(x, ids_keep, original_x=x_D_last)
258258
fp32_next = True
259259
if route_count < len(self.routes) - 1:
260260
route_count += 1

routing_module.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,16 @@ def get_mask(self, x, selection_rate=0.0):
1111
num_keep = num_patches - num_mask
1212
noise_random = torch.rand(batch_size, num_patches, device=device)
1313
ids_shuffle = torch.argsort(noise_random, dim=1)
14-
ids_restore = torch.argsort(ids_shuffle, dim=1)
1514
ids_keep = ids_shuffle[:, :num_keep]
16-
ids_mask = ids_shuffle[:, num_keep:]
17-
mask = torch.ones((batch_size, num_patches), device=device, dtype=torch.bool)
18-
mask.scatter_(1, ids_keep, False)
19-
return {
20-
'mask': mask,
21-
'ids_keep': ids_keep,
22-
'ids_mask': ids_mask,
23-
'ids_shuffle': ids_shuffle,
24-
'ids_restore': ids_restore
25-
}
15+
return ids_keep
2616

27-
def start_route(self, x, mask_info):
28-
ids_shuffle = mask_info['ids_shuffle']
29-
num_keep = mask_info['ids_keep'].size(1)
30-
x_shuffled = x.gather(1, ids_shuffle.unsqueeze(-1).expand(-1, -1, x.size(2)))
31-
masked_x = x_shuffled[:, :num_keep, :]
32-
return masked_x
17+
def start_route(self, x, ids_keep):
18+
x_masked = x.gather(1, ids_keep.unsqueeze(-1).expand(-1, -1, x.size(2)))
19+
return x_masked
3320

34-
def end_route(self, masked_x, mask_info, original_x):
35-
batch_size, num_patches = mask_info['mask'].shape
36-
num_keep = masked_x.size(1)
37-
dim = masked_x.size(2)
38-
device = masked_x.device
39-
ids_restore = mask_info['ids_restore']
40-
x_unshuffled = torch.empty((batch_size, num_patches, dim), device=device)
41-
x_unshuffled[:, :num_keep, :] = masked_x
42-
x_shuffled = original_x.gather(1, mask_info['ids_shuffle'].unsqueeze(-1).expand(-1, -1, dim))
43-
x_unshuffled[:, num_keep:, :] = x_shuffled[:, num_keep:, :]
44-
x_unmasked = x_unshuffled.gather(1, ids_restore.unsqueeze(-1).expand(-1, -1, dim))
45-
return x_unmasked
21+
def end_route(self, masked_x, ids_keep, original_x):
22+
# (jerry) scatter is out-of-place, so this is safe
23+
x_unmasked = original_x.scatter(
24+
1, ids_keep.unsqueeze(-1).expand(-1, -1, original_x.size(2)), masked_x
25+
)
26+
return x_unmasked

0 commit comments

Comments (0)