Hello everyone,
I am working on a 3D DenseNet121 model, pre-trained on RadImageNet, for a multi-class classification problem. To reuse the 2D weights, I followed the I3D-style inflation process used by Conv3dConverter, where each pre-trained 2D kernel is repeated along the depth axis and then normalized by the repeat count.
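The inflation step itself is conceptually simple; here is a minimal sketch of the repeat-and-normalize idea (my own illustration, not the converter's actual code):

import torch

def inflate_kernel(w2d: torch.Tensor, depth: int) -> torch.Tensor:
    # (C_out, C_in, kH, kW) -> (C_out, C_in, depth, kH, kW).
    # Dividing by the repeat count keeps the activation scale of the
    # inflated 3D conv comparable to the original 2D conv.
    return w2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth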
As an initial test, I froze all layers of the model except the final fully connected (FC) layer responsible for classification (fc = nn.Linear(1024, 3) in this case).
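The freezing is the usual requires_grad toggle (using the backbone/classifier modules defined in the code further down):

for p in backbone.parameters():
    p.requires_grad = False  # freeze the pre-trained encoder
for p in classifier.parameters():
    p.requires_grad = True   # train only the final FC layer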
Issue:
During training the model behaves as expected: the outputs stay in a reasonable range and the training loss decreases. During validation, however, certain batches (batches 1, 5, 27, 33, and 47) produce abnormally large activations at the input of the FC layer (and therefore abnormal logits).
For example, in validation batch 47:
Shape of input to backbone: torch.Size([4, 3, 160, 256, 256]) | Min: 0.0 Max: 1.0 Mean: 0.079 (input image)
Shape after conv0: torch.Size([4, 64, 80, 128, 128]) | Min: -52.88 Max: 39.00 Mean: 0.129
Shape after norm0: torch.Size([4, 64, 80, 128, 128]) | Min: -38.135 Max: 38.438 Mean: 0.384
Shape after relu0: torch.Size([4, 64, 80, 128, 128]) | Min: 0.0 Max: 38.438 Mean: 0.638
Shape after pool0: torch.Size([4, 64, 40, 64, 64]) | Min: 0.0 Max: 38.43 Mean: 0.819
Shape after denseblock1: torch.Size([4, 256, 40, 64, 64]) | Min: -1750.29 Max: 1917.981 Mean: -11.4197
Shape after transition1: torch.Size([4, 128, 20, 32, 32]) | Min: -140.671 Max: 147.688 Mean: -2.203
Shape after denseblock2: torch.Size([4, 512, 20, 32, 32]) | Min: -934.444 Max: 898.603 Mean: -3.426
Shape after transition2: torch.Size([4, 256, 10, 16, 16]) | Min: -145.087 Max: 136.140 Mean: -1.081
Shape after denseblock3: torch.Size([4, 1024, 10, 16, 16]) | Min: -23304.38 Max: 162375.984 Mean: 18.142
Shape after transition3: torch.Size([4, 512, 5, 8, 8]) | Min: -47275996.0 Max: 52069472.0 Mean: 25939.146
Shape after denseblock4: torch.Size([4, 1024, 5, 8, 8]) | Min: -531637568.0 Max: 437007552.0 Mean: -103217.226
Shape after norm5: torch.Size([4, 1024, 5, 8, 8]) | Min: -10179955.0 Max: 8197048.0 Mean: -1067.138
As you can see, the values start to blow up at denseblock3 and then explode through transition3 and denseblock4 (reaching magnitudes around 1e8), and I’m unsure why this happens.
I’ve checked the input images and confirmed they are normalized correctly (values in the range [0, 1]). I’ve also checked whether these batches contain classes the model is seeing for the first time, but the same classes appear in earlier batches where the values are normal.
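For reference, the per-layer stats above can be reproduced with forward hooks, roughly like this sketch (my own illustration); registering it on denseblock3 itself would narrow down the first dense layer that blows up:

import torch
import torch.nn as nn

def add_stat_hooks(module: nn.Module):
    # Print shape / min / max / mean of every sub-module's output.
    handles = []
    for name, m in module.named_modules():
        def hook(mod, inp, out, name=name):
            if isinstance(out, torch.Tensor):
                print(f"Shape after {name}: {tuple(out.shape)} | "
                      f"Min: {out.min().item():.3f} Max: {out.max().item():.3f} "
                      f"Mean: {out.mean().item():.3f}")
        handles.append(m.register_forward_hook(hook))
    return handles  # call h.remove() on each handle when done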
Here’s a brief snippet of the last layer of denseblock3:

(denselayer24): _DenseLayer(
  (norm1): BatchNorm3d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace=True)
  (conv1): Conv3d(992, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (norm2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU(inplace=True)
  (conv2): Conv3d(128, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
)
And here’s a snippet of my conversion code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import densenet121
from acsconv.converters import Conv3dConverter  # 2D->3D converter (ACSConv package)

class Classifier(nn.Module):
    def __init__(self, num_class):
        super().__init__()
        self.fc = nn.Linear(1024, num_class)

    def forward(self, x):
        return self.fc(x)

class Backbone(nn.Module):
    def __init__(self, path):
        super().__init__()
        base_model = densenet121(weights=None)
        # Drop the final classifier; keep only the feature extractor.
        encoder_layers = list(base_model.children())
        self.backbone = nn.Sequential(*encoder_layers[:-1])
        # Strip the 9-character prefix (e.g. "backbone.") from the checkpoint keys.
        state_dict = torch.load(path)
        new_state_dict = {k[9:]: v for k, v in state_dict.items()}
        print(self.backbone.load_state_dict(new_state_dict))  # <All keys matched successfully>

    def forward(self, x):
        features = self.backbone(x)
        features = F.relu(features, inplace=True)
        # Global average pool to a flat 1024-d feature vector per sample.
        features = F.adaptive_avg_pool3d(features, output_size=1).view(features.size(0), -1)
        return features

backbone = Backbone(radimagenet_pretrain_path)
classifier = Classifier(num_class=3)
model = nn.Sequential(backbone, classifier)
# Inflate every 2D module to 3D, repeating kernels along the depth axis (I3D-style).
model_3d = Conv3dConverter(model, i3d_repeat_axis=-3)
Has anyone encountered similar issues? I suspect it might be related to BatchNorm3d, but I’m not sure. Any advice on what might be causing this behavior or suggestions on how to debug it would be greatly appreciated.
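One BatchNorm-related detail I’m aware of: running statistics still update in train() mode even when requires_grad is False, so the “frozen” BN layers can drift away from the pre-trained (inflated) statistics that validation then uses in eval() mode. A quick inspection / workaround sketch:

import torch.nn as nn

# Inspect the running stats of every BatchNorm3d in the backbone:
for name, m in backbone.named_modules():
    if isinstance(m, nn.BatchNorm3d):
        print(f"{name}: |running_mean| max={m.running_mean.abs().max().item():.3f} "
              f"running_var max={m.running_var.max().item():.3f}")

# Workaround: force BN layers back into eval mode after every model.train()
# call so their running stats stay fixed during training.
def freeze_bn_stats(model: nn.Module):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()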
Thank you in advance for your help!