I was training a Auto Encoder Image Reconstruction model for finding the anomaly in the surveillance images, I used EfficientNetv2-s as Encoder and Se Block + Residual Block + Transpose Convolution Layers containing Custom Decoder.
But what is the model is unable to train, the loss is stagnant from the 5th epoch to 45th epoch…
Do i miss anything That i need to do while training Auto Encoder model
Visualization Report : Visualization report : Weights & Biases
Codes :
Encoder
class Encoder(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = models.efficientnet_b3(include_top=False, pretrained=False)
self.model.classifier = nn.Identity()
def forward(self, x):
return self.model(x)
# Encoder
class EfficientNetv2Encoder(pl.LightningModule):
def __init__(self):
super(EfficientNetv2Encoder, self).__init__()
self.file_path = utils.ROOT_PATH + '/weights/EfficientNetv2Encoder'
self.example_input_array = torch.rand(1, 3, 256, 256)
self.example_output_array = torch.rand(1, 1280)
self.save_hyperparameters()
self.model = Encoder()
try:
self.model = torch.load(utils.ROOT_PATH + '/weights/EfficientNetv2Encoder.pt')
print("Encoder Model Found")
except Exception as e:
self.model = models.efficientnet_v2_s(include_top=False, weights='EfficientNet_V2_S_Weights.DEFAULT')
self.model.classifier = nn.Identity()
torch.save(self.model, utils.ROOT_PATH + '/weights/EfficientNetv2Encoder.pt')
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.CrossEntropyLoss()(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log('train_loss', loss)
self.log('train_acc', acc)
return loss
Decoder:
class SEAttention(nn.Module):
def __init__(self, in_channels, reduction=16):
super(SEAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(in_channels // reduction, in_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
# Initial representation
self.fc = nn.Linear(1280, 4*4*1024)
self.bn1d = nn.BatchNorm1d(4*4*1024)
self.gelu = nn.GELU()
# Decoder layers
self.conv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, output_padding=0)
self.bn1 = nn.BatchNorm2d(512)
self.relu1 = nn.ReLU()
self.conv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, output_padding=0)
self.bn2 = nn.BatchNorm2d(256)
self.relu2 = nn.ReLU()
self.conv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, output_padding=0)
self.bn3 = nn.BatchNorm2d(128)
self.relu3 = nn.ReLU()
self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0)
self.bn4 = nn.BatchNorm2d(64)
self.relu4 = nn.ReLU()
self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, output_padding=0)
self.bn5 = nn.BatchNorm2d(32)
self.relu5 = nn.ReLU()
self.conv6 = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1, output_padding=0)
self.bn6 = nn.BatchNorm2d(32)
self.relu6 = nn.ReLU()
# Residual blocks with SE attention
self.res2 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.Sigmoid(),
SEAttention(64),
nn.ReLU()
)
self.res1 = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.Sigmoid(),
SEAttention(256),
nn.ReLU()
)
self.conv7 = nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.fc(x)
x = self.bn1d(x)
x = self.gelu(x)
x = x.view(-1, 1024, 4, 4)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.res1(x) + x
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu4(x)
x = self.res2(x) + x
x = self.conv5(x)
x = self.bn5(x)
x = self.relu5(x)
x = self.conv6(x)
x = self.bn6(x)
x = self.relu6(x)
x = self.conv7(x)
x = self.sigmoid(x)
return x
class EfficientNetv2Decoder(pl.LightningModule):
def __init__(self):
super(EfficientNetv2Decoder, self).__init__()
self.model = Decoder()
try:
self.model = torch.load(utils.ROOT_PATH + '/weights/EfficientNetv2DecoderLarge.pt')
print("Decoder Weights Found")
except Exception as e:
torch.save(self.model, utils.ROOT_PATH + '/weights/EfficientNetv2DecoderLarge.pt')
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.MSELoss()(y_hat, y)
self.log('train_loss', loss)
return loss
Auto Encoder:
class AutoEncoder(pl.LightningModule):
def __init__(self,
) -> None:
super(AutoEncoder, self).__init__()
self.example_input_array = torch.zeros(1, 3, 256, 256)
self.save_hyperparameters()
self.encoder = EfficientNetv2Encoder()
self.decoder = EfficientNetv2Decoder()
def forward(self, x):
try:
x = self.encoder(x)
x = self.decoder(x)
except Exception as e:
print(e, "Error!")
x = torch.rand(128, 3, 256, 256)
x = self.encoder(x)
x = self.decoder(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(1), x.size(2), x.size(3), x.size(4)).half()
y = y.view(y.size(1), y.size(2), y.size(3), y.size(4)).half()
batch_size = x.shape[0]
y_hat = self(x)
loss = nn.MSELoss()(y_hat, y)
self.log('train_loss', loss)
return {"loss" : loss}
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
return [optimizer], [scheduler]