PyTorch Model Ensembling

import torch
import torch.nn as nn
import torch.nn.functional as F

class SAM(nn.Module):
    """Spatial attention module (the spatial half of CBAM).

    Builds two single-channel descriptors from the input, the channel-wise max
    and the channel-wise mean. It concatenates them, runs a 7x7 convolution
    over the pair, squashes the result with a sigmoid, and rescales the input
    by that gate. The output has the same shape as the input.

    NOTE(review): the original version multiplied by the raw convolution
    output with no sigmoid, so the "attention" was unbounded and could blow
    activations up across stacked CBAM blocks. The parameters (and state_dict
    keys) are unchanged, but checkpoints trained without the sigmoid will
    behave differently and should be fine-tuned or retrained.

    Args:
        bias: Whether the 7x7 convolution has a bias term.
    """

    def __init__(self, bias=False):
        super().__init__()
        self.bias = bias
        # 2 input channels: one for the max descriptor, one for the mean.
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3, dilation=1, bias=self.bias)

    def forward(self, x):
        """Return ``x`` scaled by a per-pixel gate in (0, 1).

        Assumes ``x`` is (B, C, H, W), since channels are reduced over dim 1.
        """
        max_map = torch.max(x, dim=1, keepdim=True).values
        avg_map = torch.mean(x, dim=1, keepdim=True)
        descriptors = torch.cat((max_map, avg_map), dim=1)
        gate = torch.sigmoid(self.conv(descriptors))
        return gate * x

class CAM(nn.Module):
    """Channel attention module (the channel half of CBAM).

    Global max- and average-pools the input to one value per channel. It feeds
    both vectors through a shared two-layer MLP, sums the two results, and
    rescales the input channels by the sigmoid of that sum. The output has the
    same shape as the input.

    Args:
        channels: Number of input channels.
        r: Reduction ratio of the MLP's hidden layer. The hidden width is
            ``channels // r``, clamped to at least 1 so that
            ``channels < r`` does not produce a zero-width layer.
    """

    def __init__(self, channels, r):
        super().__init__()
        self.channels = channels
        self.r = r
        hidden = max(1, self.channels // self.r)
        self.linear = nn.Sequential(
            nn.Linear(in_features=self.channels, out_features=hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden, out_features=self.channels, bias=True),
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by a gate in (0, 1).

        Assumes ``x`` is (B, C, H, W), matching the ``x.size()`` unpacking.
        """
        b, c, _, _ = x.size()
        max_pooled = F.adaptive_max_pool2d(x, output_size=1)
        avg_pooled = F.adaptive_avg_pool2d(x, output_size=1)
        # The same MLP weights are shared by both pooled descriptors.
        linear_max = self.linear(max_pooled.view(b, c)).view(b, c, 1, 1)
        linear_avg = self.linear(avg_pooled.view(b, c)).view(b, c, 1, 1)
        return torch.sigmoid(linear_max + linear_avg) * x

class CBAM(nn.Module):
    """Convolutional Block Attention Module with an identity shortcut.

    Applies channel attention (``CAM``), then spatial attention (``SAM``),
    and adds the input back: ``sam(cam(x)) + x``. The output has the same
    shape as ``x``.

    Args:
        channels: Number of input/output channels.
        r: Reduction ratio of the channel-attention MLP.
    """

    def __init__(self, channels, r):
        super().__init__()
        self.channels = channels
        self.r = r
        self.sam = SAM(bias=False)
        self.cam = CAM(channels=self.channels, r=self.r)

    def forward(self, x):
        """Return the attention-refined features plus the residual input."""
        attended = self.cam(x)
        attended = self.sam(attended)
        return attended + x

def conv_block(in_channels, out_channels):
    """Return a 3x3 same-padding Conv -> BatchNorm -> ReLU stack.

    The layers sit at Sequential indices 0, 1 and 2, and callers' state_dict
    keys depend on that order.

    Args:
        in_channels: Channels fed to the convolution.
        out_channels: Channels produced by the convolution and normalized by
            the BatchNorm.
    """
    conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
    )
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU()
    layers = [conv, norm, activation]
    return nn.Sequential(*layers)

class Encoder(nn.Module):
    """Stem: 3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool (halves H and W).

    The layers are kept in a single ``self.encoder`` Sequential (rather than
    reusing ``conv_block``) so existing checkpoint keys ``encoder.0.*`` and
    ``encoder.1.*`` still load.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Return the encoded, spatially halved feature map."""
        return self.encoder(x)

class Decoder(nn.Module):
    """Upsampling stage: 3x3 transposed conv (stride 2) -> BatchNorm -> ReLU.

    With ``stride=2, padding=1, output_padding=1`` the transposed convolution
    maps an H x W input to exactly 2H x 2W.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the decoded feature map at twice the spatial resolution."""
        return self.decoder(x)

class VGG19_CBAM(nn.Module):
    """VGG-style classifier with CBAM attention and a transposed-conv decoder.

    Pipeline, as wired in ``forward``:

    1. ``encoder``: conv + BN + ReLU, then a 2x max-pool.
    2. Five ``conv_block*`` stages, each ending in CBAM and a 2x max-pool.
    3. Four ``Decoder`` stages, each upsampling 2x. The last one projects back
       to ``in_channels`` feature maps.
    4. Adaptive average pooling to 7x7.
    5. Three fully connected layers producing ``out_channels`` logits.

    There are six 2x max-pools in total, so inputs must be at least 64x64 to
    leave a non-empty feature map. Despite the name, the feature extractor has
    12 3x3 convolutions, not the 16 of VGG-19.

    Args:
        in_channels: Channels of the input image. ``decoder4`` also outputs
            this many channels, so the head's input width is
            ``in_channels * 7 * 7``.
        out_channels: Number of output logits (classes).

    NOTE(review): the head only ever sees ``in_channels`` (e.g. 3) feature
    maps, each passed through BatchNorm + ReLU in ``decoder4``. If those few
    channels end up all-zero after the ReLU, every sample flattens to the same
    vector and the model emits identical logits for all inputs. That is one
    plausible cause of "same prediction for all data". It is worth inspecting
    the activations after ``decoder4`` with the model in ``eval()`` mode.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.encoder = Encoder(in_channels=self.in_channels, out_channels=64)

        # Module order inside each Sequential fixes the state_dict keys; keep it.
        self.conv_block1 = nn.Sequential(
            conv_block(64, 64),
            CBAM(64, r=4),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv_block2 = nn.Sequential(
            conv_block(64, 128),
            CBAM(128, r=4),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv_block3 = nn.Sequential(
            conv_block(128, 256),
            *[conv_block(256, 256) for _ in range(2)],
            CBAM(256, r=4),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv_block4 = nn.Sequential(
            conv_block(256, 512),
            *[conv_block(512, 512) for _ in range(2)],
            CBAM(512, r=4),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv_block5 = nn.Sequential(
            *[conv_block(512, 512) for _ in range(3)],
            CBAM(512, r=4),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Decoder: 512 -> 256 -> 128 -> 64 -> in_channels, each stage 2x upsampling.
        self.decoder1 = Decoder(in_channels=512, out_channels=256)
        self.decoder2 = Decoder(in_channels=256, out_channels=128)
        self.decoder3 = Decoder(in_channels=128, out_channels=64)
        self.decoder4 = Decoder(in_channels=64, out_channels=self.in_channels)

        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(7, 7))

        # Flattened width must track decoder4's output channels (in_channels),
        # not a literal 3, or any in_channels != 3 fails with a shape mismatch.
        self.linear1 = nn.Sequential(
            nn.Linear(in_features=self.in_channels * 7 * 7, out_features=4096, bias=True),
            nn.Dropout(0.5),
            nn.ReLU(),
        )
        self.linear2 = nn.Sequential(
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.Dropout(0.5),
            nn.ReLU(),
        )
        self.linear3 = nn.Linear(in_features=4096, out_features=self.out_channels, bias=True)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, out_channels)``."""
        x = self.encoder(x)
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)

        x = self.decoder1(x)
        x = self.decoder2(x)
        x = self.decoder3(x)
        x = self.decoder4(x)

        x = self.avg_pool(x)
        x = torch.flatten(x, 1)

        x = self.linear1(x)
        x = self.linear2(x)
        return self.linear3(x)

import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomLoss(nn.Module):
    """Weighted sum of cross-entropy, focal, soft-MCC and soft-F1 losses.

    Every term is computed from the softmax of the logits against
    label-smoothed one-hot targets:

    * CE: cross-entropy, averaged over the batch.
    * Focal: per-class ``alpha``-weighted focal loss with focusing parameter
      ``gamma``, averaged over the batch.
    * MCC: ``1 - mean_c(MCC_c)``, where ``MCC_c`` is a differentiable
      Matthews correlation built from soft TP/TN/FP/FN counts. It lies in
      [0, 2].
    * F1: ``1 - mean_c(F1_c)`` from the same soft counts.

    Args:
        alpha: Optional per-class focal weights (sequence of floats). ``None``
            means weight 1 for every class, sized from the logits at call time
            so any number of classes works.
        gamma: Focal loss focusing parameter.
        ce_weight: Mixing weight of the CE term.
        focal_weight: Mixing weight of the focal term.
        mcc_weight: Mixing weight of the MCC term.
        f1_weight: Mixing weight of the F1 term.
        epsilon: Small constant guarding the divisions.
        label_smoothing: Probability mass moved from the true class to a
            uniform distribution over all classes.
    """

    def __init__(self, alpha=None, gamma=2.0, ce_weight=0.5, focal_weight=0.3, mcc_weight=0.1, f1_weight=0.1, epsilon=1e-7, label_smoothing=0.1):
        super().__init__()
        self.gamma = gamma
        self.ce_weight = ce_weight
        self.focal_weight = focal_weight
        self.mcc_weight = mcc_weight
        self.f1_weight = f1_weight
        self.epsilon = epsilon
        self.label_smoothing = label_smoothing
        # None -> uniform class weights, materialised per call with the right size.
        self.alpha = None if alpha is None else torch.tensor(alpha, dtype=torch.float32)

    def forward(self, y_pred, labels, weights=None):
        """Return the scalar combined loss.

        Args:
            y_pred: Raw logits of shape ``(batch, num_classes)``.
            labels: Integer class indices of shape ``(batch,)``.
            weights: Optional per-sample weights of shape ``(batch,)``. They
                only scale the soft confusion counts used by the MCC and F1
                terms; CE and focal are unweighted.
        """
        num_classes = y_pred.size(1)
        y_true = F.one_hot(labels, num_classes=num_classes).float()
        y_true = y_true * (1 - self.label_smoothing) + self.label_smoothing / num_classes

        if self.alpha is None:
            alpha = y_pred.new_ones(num_classes)
        else:
            alpha = self.alpha.to(y_pred.device)

        # log_softmax instead of log(clamp(softmax)): clamping zeroes the
        # gradient wherever the softmax saturates, which can leave a collapsed
        # model stuck predicting one class with no signal to recover.
        log_probs = F.log_softmax(y_pred, dim=1)
        probs = log_probs.exp()

        ce_loss = -torch.sum(y_true * log_probs, dim=1)
        focal_loss = -torch.sum(alpha * (1 - probs) ** self.gamma * y_true * log_probs, dim=1)

        if weights is None:
            weights = y_pred.new_ones(y_true.shape[0])
        else:
            weights = weights.to(y_pred.device)
        w = weights[:, None]

        # Soft per-class confusion counts (sums over the batch dimension).
        tp = torch.sum(w * y_true * probs, dim=0)
        tn = torch.sum(w * (1 - y_true) * (1 - probs), dim=0)
        fp = torch.sum(w * (1 - y_true) * probs, dim=0)
        fn = torch.sum(w * y_true * (1 - probs), dim=0)

        numerator = tp * tn - fp * fn
        denominator = torch.sqrt(
            (tp + fp + self.epsilon) * (tp + fn + self.epsilon)
            * (tn + fp + self.epsilon) * (tn + fn + self.epsilon)
        )
        mcc = numerator / (denominator + self.epsilon)
        # Average (not sum) over classes so MCC stays in [-1, 1], consistent with F1.
        mcc_loss = 1.0 - mcc.mean()

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)
        f1_score = 2 * (precision * recall) / (precision + recall + self.epsilon)
        f1_loss = 1.0 - f1_score.mean()

        return (self.ce_weight * ce_loss.mean()
                + self.focal_weight * focal_loss.mean()
                + self.mcc_weight * mcc_loss
                + self.f1_weight * f1_loss)

import torch.optim as optim

# Fine-tune VGG19_CBAM from a saved checkpoint and report per-epoch
# loss / accuracy / macro sensitivity / macro specificity on the training set.
# NOTE(review): `train_loader` is not defined in the visible code; it is
# assumed to yield (images, integer labels in [0, NUM_CLASSES)) batches.

NUM_CLASSES = 6
CHECKPOINT_PATH = "VGG_CBAM/vgg_cbam_2024-08-07 04:54:01.475616.pth"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = VGG19_CBAM(in_channels=3, out_channels=NUM_CLASSES)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)
# Move the model to its device before constructing the optimizer.
model = model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.005)
criterion = CustomLoss()
total_epochs = 40

# NOTE(review): when evaluating or ensembling the loaded model, call
# model.eval() and run inference under torch.no_grad(). In train() mode the
# BatchNorm layers normalize with per-batch statistics and Dropout is active.
# Also make sure inference preprocessing matches training, because BatchNorm's
# running statistics were estimated on the training inputs. A few train-mode
# steps overwrite those running statistics, which would explain why "training
# briefly" appears to restore accuracy.

print("Training Begin!")
print()

best_accuracy = 0

for epoch in range(total_epochs):
    model.train()

    epoch_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    # confusion[t, p] counts samples of true class t predicted as p. It is kept
    # on the device so per-class counting needs no extra host syncs.
    confusion = torch.zeros(NUM_CLASSES, NUM_CLASSES, dtype=torch.long, device=device)

    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * images.size(0)

        predicted = outputs.argmax(dim=1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)
        confusion += torch.bincount(
            labels * NUM_CLASSES + predicted,
            minlength=NUM_CLASSES * NUM_CLASSES,
        ).view(NUM_CLASSES, NUM_CLASSES)

        if (step + 1) % 1000 == 0:
            print('Epoch: [{}/{}] | Step: [{}/{}] | Loss: {:.4f}'.format(epoch + 1, total_epochs, step + 1, len(train_loader), loss.item()))

    # Guard against an empty loader instead of dividing by zero.
    seen = max(total_samples, 1)
    epoch_loss /= seen
    train_acc = correct_predictions / seen

    true_positives = confusion.diag().float()
    false_negatives = confusion.sum(dim=1).float() - true_positives  # true i, predicted other
    false_positives = confusion.sum(dim=0).float() - true_positives  # predicted i, true other
    true_negatives = confusion.sum().float() - true_positives - false_negatives - false_positives

    sensitivities = true_positives / (true_positives + false_negatives + 1e-7)
    specificities = true_negatives / (true_negatives + false_positives + 1e-7)
    train_sensitivity = sensitivities.mean().item()
    train_specificity = specificities.mean().item()

    print(f'Epoch: [{epoch + 1}/{total_epochs}] | Loss: {epoch_loss:.4f} | Accuracy: {train_acc:.4f} | Sensitivity: {train_sensitivity:.4f} | Specificity: {train_specificity:.4f}')

print('Train Finished!')

This is the model architecture. When we load the saved state dict, the model predicts the same output for every input. But when we start from the loaded weights and train for a few steps, the accuracy comes back. Please help us understand the problem. The same problem occurs when we ensemble this model with two other models.