There is a critical bug in the tracking module: the data is not aligned in the self-attention module of the tracking model, which causes training to fail when the batch size is greater than one per GPU.
At line 194 of "projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py", the code reads:
query = torch.cat([value[:bs], query], -1)
"value" is composed of two parts, prev_bev and bev_query, which is passed by severel modules and originally generated from BEVFormerEncoder(projects/mmdet3d_plugin/uniad/modules/encoder.py, line 202). Here is its definition:
prev_bev = torch.stack(
[prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1)
Since the two tensors are stacked on dimension 1 and then reshaped, the result is interleaved sample by sample: [num1_prev, num1_curr, num2_prev, num2_curr, num3_prev, num3_curr, ...]
If we set the batch size to 4, then value[:bs] is [num1_prev, num1_curr, num2_prev, num2_curr] rather than the intended [num1_prev, num2_prev, num3_prev, num4_prev], so the concatenation pairs each query with the wrong sample's previous BEV features, which is obviously a mistake.
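The misalignment can be reproduced without the full model. Below is a minimal pure-Python sketch in which string labels stand in for the real BEV feature tensors; the list comprehension mimics the interleaving produced by torch.stack([prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1):

```python
bs = 4  # hypothetical per-GPU batch size

# Labels standing in for each sample's BEV feature tensors.
prev_bev = [f"num{i}_prev" for i in range(1, bs + 1)]
bev_query = [f"num{i}_curr" for i in range(1, bs + 1)]

# Stacking on dim 1 and reshaping to (bs*2, ...) interleaves the
# two sequences sample by sample, exactly as in encoder.py line 202.
value = [x for pair in zip(prev_bev, bev_query) for x in pair]
print(value)
# ['num1_prev', 'num1_curr', 'num2_prev', 'num2_curr',
#  'num3_prev', 'num3_curr', 'num4_prev', 'num4_curr']

# temporal_self_attention.py line 194 assumes the first bs entries are
# all previous-frame features, but with bs > 1 they are not:
print(value[:bs])   # ['num1_prev', 'num1_curr', 'num2_prev', 'num2_curr']

# Under this layout, the previous-frame features actually live at the
# even indices, so the slice the concat needs would be:
print(value[0::2])  # ['num1_prev', 'num2_prev', 'num3_prev', 'num4_prev']
```

With bs = 1 the two slices coincide (value[:1] == value[0::2]), which is why the bug only surfaces when the per-GPU batch size exceeds one.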
As I don't have a GPU environment to debug the project in practice, I can only infer this from reading the code. If I'm wrong, feel free to close the issue.