1+ import torch
2+ import torch .nn as nn
3+ from typing import List , Tuple , Dict
4+ from . import ConvWithNorms
5+
6+ SPLIT_BATCH_SIZE = 512
7+
class MMHeadDecoder(nn.Module):
    """Transformer-based per-point flow head.

    For each point, the pseudoimage features of its voxel from the before
    and after frames are concatenated and fused (via a TransformerDecoder)
    with an encoding of the point's within-voxel offset, then decoded to a
    3D flow vector.
    """

    def __init__(self, pseudoimage_channels: int = 64,
                 split_batch_size: int = 512):
        """
        Args:
            pseudoimage_channels: channels C of each input pseudoimage.
                The transformer runs at d_model = 2*C, so 2*C must be
                divisible by the number of attention heads (4).
            split_batch_size: number of points pushed through the
                transformer per chunk, to bound peak memory.
        """
        super().__init__()

        # The fused feature width is before + after voxel features.
        # (Previously hard-coded to 128, which only worked for C == 64.)
        feature_dim = pseudoimage_channels * 2

        # Encode the (x, y, z) within-voxel offset to the transformer width
        # so it can serve as the decoder's memory input.
        self.offset_encoder = nn.Linear(3, feature_dim)

        # FIXME: figure out how to set nhead and num_layers properly
        # ref: https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html
        #      https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoderLayer.html
        transform_decoder_layers = nn.TransformerDecoderLayer(
            d_model=feature_dim, nhead=4)
        self.pts_off_transformer = nn.TransformerDecoder(
            transform_decoder_layers, num_layers=4)

        self.decoder = nn.Sequential(
            nn.Linear(feature_dim, 32), nn.GELU(),
            nn.Linear(32, 3))

        self.split_batch_size = split_batch_size

    def forward_single(self, before_pseudoimage: torch.Tensor,
                       after_pseudoimage: torch.Tensor,
                       point_offsets: torch.Tensor,
                       voxel_coords: torch.Tensor) -> torch.Tensor:
        """Compute flow for the points of one batch element.

        Args:
            before_pseudoimage: (C, Y, X) feature map of the earlier frame.
            after_pseudoimage: (C, Y, X) feature map of the later frame.
            point_offsets: (N, 3) per-point offsets within their voxels.
            voxel_coords: (N, 3) voxel coordinates, ordered (Z, Y, X).

        Returns:
            (N, 3) flow vectors.
        """
        voxel_coords = voxel_coords.long()

        # Voxel coords are Z, Y, X, and the pseudoimage is Channel, Y, X.
        # Gathering then transposing yields one (2C,) row per point.
        after_voxel_vectors = after_pseudoimage[:, voxel_coords[:, 1],
                                                voxel_coords[:, 2]].T
        before_voxel_vectors = before_pseudoimage[:, voxel_coords[:, 1],
                                                  voxel_coords[:, 2]].T

        # [N, C] + [N, C] -> [N, 2C]
        concatenated_vectors = torch.cat(
            [before_voxel_vectors, after_voxel_vectors], dim=1)

        # [N, 2C] -> [N, 1, 2C]: each point is its own transformer entry.
        voxel_feature = concatenated_vectors.unsqueeze(1)
        point_offsets_feature = self.offset_encoder(point_offsets).unsqueeze(1)

        # Run the transformer in chunks to bound peak memory, then stitch
        # the outputs back together (no in-place writes into a buffer).
        chunks = []
        for start in range(0, voxel_feature.shape[0], self.split_batch_size):
            end = start + self.split_batch_size
            chunks.append(self.pts_off_transformer(
                voxel_feature[start:end],
                point_offsets_feature[start:end]))

        if chunks:
            fused_feature = torch.cat(chunks, dim=0)
        else:
            # No points in this sample: keep the (empty) feature tensor.
            fused_feature = voxel_feature

        flow = self.decoder(fused_feature.squeeze(1))
        return flow

    def forward(
            self, before_pseudoimages: torch.Tensor,
            after_pseudoimages: torch.Tensor,
            voxelizer_infos: List[Dict[str,
                                       torch.Tensor]]) -> List[torch.Tensor]:
        """Decode flow for every batch element.

        Args:
            before_pseudoimages: (B, C, Y, X) feature maps.
            after_pseudoimages: (B, C, Y, X) feature maps.
            voxelizer_infos: per-element dicts with "point_offsets" (N, 3)
                and "voxel_coords" (N, 3).

        Returns:
            List of B tensors, each (N_i, 3).
        """
        flow_results = []
        for before_pseudoimage, after_pseudoimage, voxelizer_info in zip(
                before_pseudoimages, after_pseudoimages, voxelizer_infos):
            point_offsets = voxelizer_info["point_offsets"]
            voxel_coords = voxelizer_info["voxel_coords"]
            flow = self.forward_single(before_pseudoimage, after_pseudoimage,
                                       point_offsets, voxel_coords)
            flow_results.append(flow)
        return flow_results
71+
class LinearDecoder(nn.Module):
    """MLP flow head.

    Concatenates each point's before/after voxel features with an encoding
    of its within-voxel offset and decodes the result to a 3D flow vector.
    """

    def __init__(self, pseudoimage_channels: int = 64):
        """
        Args:
            pseudoimage_channels: channels C of each input pseudoimage.
        """
        super().__init__()

        # Encode the (x, y, z) offset to 2*C features so the decoder input
        # is exactly 4*C (2*C voxel features + 2*C offset features).
        # (Previously hard-coded to 128, which only matched the decoder's
        # 4*C input when C == 64.)
        self.offset_encoder = nn.Linear(3, pseudoimage_channels * 2)

        self.decoder = nn.Sequential(
            nn.Linear(pseudoimage_channels * 4, 32), nn.GELU(),
            nn.Linear(32, 3))

    def forward_single(self, before_pseudoimage: torch.Tensor,
                       after_pseudoimage: torch.Tensor,
                       point_offsets: torch.Tensor,
                       voxel_coords: torch.Tensor) -> torch.Tensor:
        """Compute flow for the points of one batch element.

        Args:
            before_pseudoimage: (C, Y, X) feature map of the earlier frame.
            after_pseudoimage: (C, Y, X) feature map of the later frame.
            point_offsets: (N, 3) per-point offsets within their voxels.
            voxel_coords: (N, 3) voxel coordinates, ordered (Z, Y, X).

        Returns:
            (N, 3) flow vectors.
        """
        voxel_coords = voxel_coords.long()

        # Voxel coords are Z, Y, X, and the pseudoimage is Channel, Y, X.
        # Gathering then transposing yields one (C,) row per point.
        after_voxel_vectors = after_pseudoimage[:, voxel_coords[:, 1],
                                                voxel_coords[:, 2]].T
        before_voxel_vectors = before_pseudoimage[:, voxel_coords[:, 1],
                                                  voxel_coords[:, 2]].T

        # [N, C] + [N, C] -> [N, 2C]
        concatenated_vectors = torch.cat(
            [before_voxel_vectors, after_voxel_vectors], dim=1)

        # [N, 3] -> [N, 2C]
        point_offsets_feature = self.offset_encoder(point_offsets)

        flow = self.decoder(
            torch.cat([concatenated_vectors, point_offsets_feature], dim=1))
        return flow

    def forward(
            self, before_pseudoimages: torch.Tensor,
            after_pseudoimages: torch.Tensor,
            voxelizer_infos: List[Dict[str,
                                       torch.Tensor]]) -> List[torch.Tensor]:
        """Decode flow for every batch element.

        Args:
            before_pseudoimages: (B, C, Y, X) feature maps.
            after_pseudoimages: (B, C, Y, X) feature maps.
            voxelizer_infos: per-element dicts with "point_offsets" (N, 3)
                and "voxel_coords" (N, 3).

        Returns:
            List of B tensors, each (N_i, 3).
        """
        flow_results = []
        for before_pseudoimage, after_pseudoimage, voxelizer_info in zip(
                before_pseudoimages, after_pseudoimages, voxelizer_infos):
            point_offsets = voxelizer_info["point_offsets"]
            voxel_coords = voxelizer_info["voxel_coords"]
            flow = self.forward_single(before_pseudoimage, after_pseudoimage,
                                       point_offsets, voxel_coords)
            flow_results.append(flow)
        return flow_results
121+
# from https://github.com/weiyithu/PV-RAFT/blob/main/model/update.py
class ConvGRU(nn.Module):
    """Gated recurrent unit whose gates are 1x1 Conv1d layers.

    Operates on (batch, channels, length) tensors: `h` is the hidden state
    with `hidden_dim` channels, `x` is the input with `input_dim` channels.
    """

    def __init__(self, input_dim=64, hidden_dim=128):
        super().__init__()
        combined_dim = input_dim + hidden_dim
        # z: update gate, r: reset gate, q: candidate hidden state.
        self.convz = nn.Conv1d(combined_dim, hidden_dim, 1)
        self.convr = nn.Conv1d(combined_dim, hidden_dim, 1)
        self.convq = nn.Conv1d(combined_dim, hidden_dim, 1)

    def forward(self, h, x):
        """Return the updated hidden state given state `h` and input `x`."""
        combined = torch.cat([h, x], dim=1)

        update_gate = torch.sigmoid(self.convz(combined))
        reset_gate = torch.sigmoid(self.convr(combined))
        candidate = torch.tanh(
            self.convq(torch.cat([reset_gate * h, x], dim=1)))

        # Standard GRU blend of old state and candidate.
        return (1 - update_gate) * h + update_gate * candidate
140+
class ConvGRUDecoder(nn.Module):
    """Iterative GRU flow head.

    The concatenated before/after voxel features serve as the GRU hidden
    state and the encoded within-voxel point offset as its input (per the
    paper's Fig. 3); after `num_iters` refinement steps the hidden state is
    decoded, together with the offset encoding, to a 3D flow vector.
    """

    def __init__(self, pseudoimage_channels: int = 64, num_iters: int = 4):
        super().__init__()

        self.offset_encoder = nn.Linear(3, pseudoimage_channels)

        # NOTE: voxel feature is the hidden state, point offset is the
        # input — check paper's Fig. 3.
        self.gru = ConvGRU(input_dim=pseudoimage_channels,
                           hidden_dim=pseudoimage_channels * 2)

        self.decoder = nn.Sequential(
            nn.Linear(pseudoimage_channels * 3, pseudoimage_channels // 2),
            nn.GELU(),
            nn.Linear(pseudoimage_channels // 2, 3))
        self.num_iters = num_iters

    def forward_single(self, before_pseudoimage: torch.Tensor,
                       after_pseudoimage: torch.Tensor,
                       point_offsets: torch.Tensor,
                       voxel_coords: torch.Tensor) -> torch.Tensor:
        """Compute (N, 3) flow for the points of one batch element."""
        idx = voxel_coords.long()

        # Voxel coords are Z, Y, X, and the pseudoimage is Channel, Y, X;
        # gather each point's column and transpose to (N, C) rows.
        before_feats = before_pseudoimage[:, idx[:, 1], idx[:, 2]].T
        after_feats = after_pseudoimage[:, idx[:, 1], idx[:, 2]].T

        # (N, 2C, 1): initial hidden state for the GRU.
        hidden = torch.cat([before_feats, after_feats], dim=1).unsqueeze(2)

        # (N, 3) -> (N, C); also reused as a decoder input below.
        offset_feats = self.offset_encoder(point_offsets)
        gru_input = offset_feats.unsqueeze(2)

        for _ in range(self.num_iters):
            hidden = self.gru(hidden, gru_input)

        decoder_input = torch.cat([hidden.squeeze(2), offset_feats], dim=1)
        return self.decoder(decoder_input)

    def forward(
            self, before_pseudoimages: torch.Tensor,
            after_pseudoimages: torch.Tensor,
            voxelizer_infos: List[Dict[str,
                                       torch.Tensor]]) -> List[torch.Tensor]:
        """Decode flow for every batch element; returns a list of (N_i, 3)."""
        return [
            self.forward_single(before, after,
                                info["point_offsets"], info["voxel_coords"])
            for before, after, info in zip(before_pseudoimages,
                                           after_pseudoimages,
                                           voxelizer_infos)
        ]
200+
201+
class ConvWithNorms(nn.Module):
    """Conv2d followed by BatchNorm2d and a GELU nonlinearity.

    NOTE(review): this definition shadows the ``ConvWithNorms`` imported at
    the top of the file — confirm which one callers are meant to get.
    """

    def __init__(self, in_num_channels: int, out_num_channels: int,
                 kernel_size: int, stride: int, padding: int):
        super().__init__()
        self.conv = nn.Conv2d(in_num_channels, out_num_channels, kernel_size,
                              stride, padding)
        self.batchnorm = nn.BatchNorm2d(out_num_channels)
        self.nonlinearity = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.conv(x)
        # HACK kept from the original: batchnorm is bypassed whenever the
        # spatial output is 1x1 (presumably to avoid train-mode batchnorm
        # failing when there is a single value per channel — TODO confirm,
        # since with batch size > 1 batchnorm handles (N, C, 1, 1) fine).
        is_spatial_singleton = features.shape[2] == 1 and features.shape[3] == 1
        normed = features if is_spatial_singleton else self.batchnorm(features)
        return self.nonlinearity(normed)