Commit 9890b77

add ultralight_vm_unet
1 parent 4625b0e commit 9890b77

2 files changed: 286 additions & 0 deletions

File tree

pymic/net/net2d/unet2d_vm_light.py

Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division

import math
import torch
from torch import nn
import torch.nn.functional as F

from timm.models.layers import trunc_normal_
from mamba_ssm import Mamba

class PVMLayer(nn.Module):
    def __init__(self, input_dim, output_dim, d_state = 16, d_conv = 4, expand = 2):
        super().__init__()
        self.input_dim  = input_dim
        self.output_dim = output_dim
        self.norm  = nn.LayerNorm(input_dim)
        self.mamba = Mamba(
            d_model = input_dim // 4,  # model dimension of each parallel branch
            d_state = d_state,         # SSM state expansion factor
            d_conv  = d_conv,          # local convolution width
            expand  = expand,          # block expansion factor
        )
        self.proj = nn.Linear(input_dim, output_dim)
        self.skip_scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.input_dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_norm = self.norm(x_flat)

        # split the channels into four groups and pass each through a shared
        # Mamba block with a learnable scaled skip connection; this is the
        # parallel design that keeps d_model at only input_dim // 4
        x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2)
        x_mamba1 = self.mamba(x1) + self.skip_scale * x1
        x_mamba2 = self.mamba(x2) + self.skip_scale * x2
        x_mamba3 = self.mamba(x3) + self.skip_scale * x3
        x_mamba4 = self.mamba(x4) + self.skip_scale * x4
        x_mamba  = torch.cat([x_mamba1, x_mamba2, x_mamba3, x_mamba4], dim=2)

        x_mamba = self.norm(x_mamba)
        x_mamba = self.proj(x_mamba)
        out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims)
        return out

class Channel_Att_Bridge(nn.Module):
    def __init__(self, c_list, split_att='fc'):
        super().__init__()
        c_list_sum = sum(c_list) - c_list[-1]
        self.split_att = split_att
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
        self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1)
        self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1)
        self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1)
        self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1)
        self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, t1, t2, t3, t4, t5):
        att = torch.cat((self.avgpool(t1),
                         self.avgpool(t2),
                         self.avgpool(t3),
                         self.avgpool(t4),
                         self.avgpool(t5)), dim=1)
        att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
        if self.split_att != 'fc':
            att = att.transpose(-1, -2)
        att1 = self.sigmoid(self.att1(att))
        att2 = self.sigmoid(self.att2(att))
        att3 = self.sigmoid(self.att3(att))
        att4 = self.sigmoid(self.att4(att))
        att5 = self.sigmoid(self.att5(att))
        if self.split_att == 'fc':
            att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1)
            att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2)
            att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3)
            att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4)
            att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5)
        else:
            att1 = att1.unsqueeze(-1).expand_as(t1)
            att2 = att2.unsqueeze(-1).expand_as(t2)
            att3 = att3.unsqueeze(-1).expand_as(t3)
            att4 = att4.unsqueeze(-1).expand_as(t4)
            att5 = att5.unsqueeze(-1).expand_as(t5)

        return att1, att2, att3, att4, att5

class Spatial_Att_Bridge(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3),
                                           nn.Sigmoid())

    def forward(self, t1, t2, t3, t4, t5):
        t_list = [t1, t2, t3, t4, t5]
        att_list = []
        for t in t_list:
            avg_out = torch.mean(t, dim=1, keepdim=True)
            max_out, _ = torch.max(t, dim=1, keepdim=True)
            att = torch.cat([avg_out, max_out], dim=1)
            att = self.shared_conv2d(att)
            att_list.append(att)
        return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4]

class SC_Att_Bridge(nn.Module):
    def __init__(self, c_list, split_att='fc'):
        super().__init__()

        self.catt = Channel_Att_Bridge(c_list, split_att=split_att)
        self.satt = Spatial_Att_Bridge()

    def forward(self, t1, t2, t3, t4, t5):
        r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5

        # spatial attention with a residual connection
        satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5)
        t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5

        r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5
        t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5

        # channel attention with a residual connection
        catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5)
        t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5

        return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_

class UltraLight_VM_UNet(nn.Module):
    def __init__(self, params):
        """
        UltraLight_VM_UNet, a lightweight segmentation model based on CNN and Mamba.

        * Reference: Renkai Wu, Yinghao Liu, Pengchen Liang, Qing Chang.
          UltraLight VM-UNet: Parallel Vision Mamba Significantly Reduces Parameters
          for Skin Lesion Segmentation. arXiv preprint arXiv:2403.20035, 2024.

        The implementation is based on the code at:
        https://github.com/wurenkai/UltraLight-VM-UNet.

        The parameters for the backbone should be given in the `params` dictionary.

        :param in_chns: (int) Input channel number.
        :param class_num: (int) The class number for the segmentation task.
        :param feature_chns: (list) Feature channel number for each resolution level.
            The length should be 6; by default it is [8, 16, 24, 32, 48, 64].
        :param bridge: (bool) Whether the bridge based on spatial and channel
            attention is used. By default it is True.
        """
        super(UltraLight_VM_UNet, self).__init__()

        input_channels = params['in_chns']
        num_classes    = params['class_num']
        c_list         = params.get('feature_chns', [8, 16, 24, 32, 48, 64])
        self.bridge    = params.get('bridge', True)
        split_att      = 'fc'

        # the first three encoder stages are plain convolutions,
        # the last three are parallel vision Mamba (PVM) layers
        self.encoder1 = nn.Sequential(
            nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1),
        )
        self.encoder2 = nn.Sequential(
            nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1),
        )
        self.encoder3 = nn.Sequential(
            nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1),
        )
        self.encoder4 = nn.Sequential(
            PVMLayer(input_dim=c_list[2], output_dim=c_list[3])
        )
        self.encoder5 = nn.Sequential(
            PVMLayer(input_dim=c_list[3], output_dim=c_list[4])
        )
        self.encoder6 = nn.Sequential(
            PVMLayer(input_dim=c_list[4], output_dim=c_list[5])
        )

        if self.bridge:
            self.scab = SC_Att_Bridge(c_list, split_att)
            print('SC_Att_Bridge was used')

        self.decoder1 = nn.Sequential(
            PVMLayer(input_dim=c_list[5], output_dim=c_list[4])
        )
        self.decoder2 = nn.Sequential(
            PVMLayer(input_dim=c_list[4], output_dim=c_list[3])
        )
        self.decoder3 = nn.Sequential(
            PVMLayer(input_dim=c_list[3], output_dim=c_list[2])
        )
        self.decoder4 = nn.Sequential(
            nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1),
        )
        self.decoder5 = nn.Sequential(
            nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1),
        )
        self.ebn1 = nn.GroupNorm(4, c_list[0])
        self.ebn2 = nn.GroupNorm(4, c_list[1])
        self.ebn3 = nn.GroupNorm(4, c_list[2])
        self.ebn4 = nn.GroupNorm(4, c_list[3])
        self.ebn5 = nn.GroupNorm(4, c_list[4])
        self.dbn1 = nn.GroupNorm(4, c_list[4])
        self.dbn2 = nn.GroupNorm(4, c_list[3])
        self.dbn3 = nn.GroupNorm(4, c_list[2])
        self.dbn4 = nn.GroupNorm(4, c_list[1])
        self.dbn5 = nn.GroupNorm(4, c_list[0])

        self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv1d):
            n = m.kernel_size[0] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # a 5D input [N, C, D, H, W] is treated as a batch of N*D 2D slices
        x_shape = list(x.shape)
        if len(x_shape) == 5:
            [N, C, D, H, W] = x_shape
            new_shape = [N * D, C, H, W]
            x = torch.transpose(x, 1, 2)
            x = torch.reshape(x, new_shape)

        out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)), 2, 2))
        t1 = out  # b, c0, H/2, W/2

        out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)), 2, 2))
        t2 = out  # b, c1, H/4, W/4

        out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)), 2, 2))
        t3 = out  # b, c2, H/8, W/8

        out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)), 2, 2))
        t4 = out  # b, c3, H/16, W/16

        out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)), 2, 2))
        t5 = out  # b, c4, H/32, W/32

        if self.bridge:
            t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5)

        out = F.gelu(self.encoder6(out))  # b, c5, H/32, W/32

        out5 = F.gelu(self.dbn1(self.decoder1(out)))  # b, c4, H/32, W/32
        out5 = torch.add(out5, t5)  # b, c4, H/32, W/32

        out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)), scale_factor=(2, 2),
                      mode='bilinear', align_corners=True))  # b, c3, H/16, W/16
        out4 = torch.add(out4, t4)  # b, c3, H/16, W/16

        out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)), scale_factor=(2, 2),
                      mode='bilinear', align_corners=True))  # b, c2, H/8, W/8
        out3 = torch.add(out3, t3)  # b, c2, H/8, W/8

        out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)), scale_factor=(2, 2),
                      mode='bilinear', align_corners=True))  # b, c1, H/4, W/4
        out2 = torch.add(out2, t2)  # b, c1, H/4, W/4

        out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)), scale_factor=(2, 2),
                      mode='bilinear', align_corners=True))  # b, c0, H/2, W/2
        out1 = torch.add(out1, t1)  # b, c0, H/2, W/2

        out0 = F.interpolate(self.final(out1), scale_factor=(2, 2),
                             mode='bilinear', align_corners=True)  # b, num_class, H, W

        if len(x_shape) == 5:
            # reshape the stacked slices back to a 5D [N, C, D, H, W] output
            new_shape = [N, D] + list(out0.shape)[1:]
            out0 = torch.transpose(torch.reshape(out0, new_shape), 1, 2)
        return out0
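
As a quick sanity check of the new file, a minimal usage sketch is given below. It is not part of the commit; the channel, class, and input-size settings are illustrative assumptions, and mamba_ssm requires a CUDA device.

import torch
from pymic.net.net2d.unet2d_vm_light import UltraLight_VM_UNet

# hypothetical settings for a smoke test; any H and W divisible by 32 work
params = {'in_chns'      : 3,
          'class_num'    : 2,
          'feature_chns' : [8, 16, 24, 32, 48, 64],  # the default channel list
          'bridge'       : True}
net = UltraLight_VM_UNet(params).cuda()
x   = torch.rand(2, 3, 256, 256).cuda()  # [N, C, H, W]; a 5D [N, C, D, H, W] input also works
y   = net(x)
print(y.shape)  # expected: torch.Size([2, 2, 256, 256])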

pymic/net/net_dict_seg.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,7 @@
 from pymic.net.net2d.trans2d.transunet import TransUNet
 from pymic.net.net2d.trans2d.swinunet import SwinUNet
 from pymic.net.net2d.umamba import UMambaBot, UMambaEnc
+from pymic.net.net2d.unet2d_vm_light import UltraLight_VM_UNet
 from pymic.net.net3d.unet2d5 import UNet2D5
 from pymic.net.net3d.unet3d import UNet3D
 from pymic.net.net3d.grunet import GRUNet
@@ -65,6 +66,7 @@
     'UNet2D_ScSE': UNet2D_ScSE,
     'UMambaBot': UMambaBot,
     'UMambaEnc': UMambaEnc,
+    'UltraLight_VM_UNet': UltraLight_VM_UNet,
     'TransUNet': TransUNet,
     'SwinUNet': SwinUNet,
     'UNet2D5': UNet2D5,
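
With the entry above registered, the network can be selected by name. A minimal sketch, assuming the dictionary defined in net_dict_seg.py is named SegNetDict as elsewhere in PyMIC:

from pymic.net.net_dict_seg import SegNetDict

# resolve the registered name to the class and build the network
net_class = SegNetDict['UltraLight_VM_UNet']
net = net_class({'in_chns': 3, 'class_num': 2})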
