Hello everyone, I’ve been training a model to interpolate a new intermediate frame (1.5) from a previous frame (1) and a future frame (2).
For the past few days I’ve been struggling to improve my model’s accuracy, and I would like some suggestions. At the moment, on the Vimeo90k dataset, my model has a validation loss of 19.37% (80.63% accuracy) and a training loss of 19.94% (80.06% accuracy).
Previously I was training on my own dataset, which contained 30k training triplets and 3k validation triplets taken from about 50 videos in 720p. I trained the model until I hit diminishing returns: the validation loss decreased by only ~0.02% every 4-5 epochs, and the training loss didn’t go down significantly either. In other words, the model was improving painfully slowly. Because of this, I moved to Vimeo90k, a very diverse dataset with ~50k training triplets and ~3.7k validation triplets. However, I’ve noticed that I’m also hitting diminishing returns on Vimeo90k fairly early in training (around epoch 20 of 100).
My model is structured as follows:
import torch.nn as nn

class FrameInterpolationModel(nn.Module):
    def __init__(self):
        super().__init__()
        # (N, 16, H, W) -> (N, 256, H/4, W/4)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=7, padding=3, stride=1),
            nn.PReLU(),
            nn.MaxPool2d(2),  # halves H and W
            nn.PReLU(),
            nn.Conv2d(64, 128, kernel_size=5, padding=2, stride=1),
            nn.PReLU(),
            nn.MaxPool2d(2),  # halves H and W again
            nn.PReLU(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=1),
            nn.PReLU(),
        )
        # (N, 256, H/4, W/4) -> (N, 3, H, W)
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1, stride=1),
            nn.PReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),  # x2
            nn.PReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=7, stride=2, padding=3, output_padding=1),  # x2
            nn.PReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1, stride=1),  # RGB output
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
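Because the encoder pools twice and the decoder upsamples twice, the output only matches the input resolution when height and width are divisible by 4. The sketch below checks that arithmetic with the ConvTranspose2d output-size formula and cross-checks it against the actual decoder layers; the helper names are hypothetical. Vimeo90k’s 448×256 frames satisfy the constraint, but arbitrary crops might not.

```python
import torch
import torch.nn as nn

def transposed_out(size, kernel, stride, padding, output_padding):
    # Output size of nn.ConvTranspose2d with dilation=1
    return (size - 1) * stride - 2 * padding + kernel + output_padding

def model_output_size(size):
    # The stride-1 convolutions keep the size (padding = kernel // 2),
    # so only the two MaxPool2d(2) and the two ConvTranspose2d layers matter.
    size = size // 2 // 2                      # MaxPool2d floors odd sizes
    size = transposed_out(size, 5, 2, 2, 1)    # first upsampling, x2
    size = transposed_out(size, 7, 2, 3, 1)    # second upsampling, x2
    return size

# Cross-check the formula against the real decoder upsampling layers.
up = nn.Sequential(
    nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
    nn.ConvTranspose2d(64, 32, kernel_size=7, stride=2, padding=3, output_padding=1),
)
with torch.no_grad():
    y = up(torch.zeros(1, 128, 256 // 4, 448 // 4))

print(tuple(y.shape))          # (1, 32, 256, 448)
print(model_output_size(256))  # 256
print(model_output_size(250))  # 248: no longer matches the target size
```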
Currently I’m using three loss functions with the following weights:
Perceptual loss (VGG) - compares VGG feature maps of the predicted and target frames (weight = 0.6)
SSIM - captures the overall structure of the frame; in my tests, this loss teaches the model to produce images with much less ghosting than MSE does (weight = 0.3)
L1Loss - added to help SSIM with colors and contrast, since SSIM alone seemed to struggle to get colors right (weight = 0.1)
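A minimal sketch of the weighted loss described above, under several assumptions: images are in [0, 1]; the SSIM term enters the loss as 1 - SSIM; the perceptual term is an L1 distance between feature maps (the post doesn’t say whether L1 or MSE is used); and SSIM uses an 11×11 box window instead of the usual Gaussian window for brevity. `ssim` and `CombinedLoss` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def ssim(x, y, window=11, c1=0.01 ** 2, c2=0.03 ** 2):
    # Simplified SSIM with a uniform (box) window; x, y: (N, C, H, W) in [0, 1].
    def pool(t):
        return F.avg_pool2d(t, window, stride=1, padding=window // 2,
                            count_include_pad=False)
    mu_x, mu_y = pool(x), pool(y)
    var_x = pool(x * x) - mu_x * mu_x
    var_y = pool(y * y) - mu_y * mu_y
    cov_xy = pool(x * y) - mu_x * mu_y
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x * mu_x + mu_y * mu_y + c1) * (var_x + var_y + c2)
    return (num / den).mean()

class CombinedLoss(nn.Module):
    def __init__(self, feature_extractor, w_perc=0.6, w_ssim=0.3, w_l1=0.1):
        super().__init__()
        self.features = feature_extractor.eval()
        for p in self.features.parameters():
            p.requires_grad_(False)  # frozen: only used to compare features
        self.w_perc, self.w_ssim, self.w_l1 = w_perc, w_ssim, w_l1

    def forward(self, pred, target):
        perc = F.l1_loss(self.features(pred), self.features(target))
        ssim_term = 1.0 - ssim(pred, target)
        l1 = F.l1_loss(pred, target)
        return self.w_perc * perc + self.w_ssim * ssim_term + self.w_l1 * l1

torch.manual_seed(0)
# Stand-in feature extractor so the sketch runs without downloading VGG weights.
feat = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
criterion = CombinedLoss(feat)
pred = torch.rand(2, 3, 64, 64, requires_grad=True)
target = torch.rand(2, 3, 64, 64)
loss = criterion(pred, target)
loss.backward()
```

In practice the feature extractor would be a frozen, pretrained VGG slice, for example `torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[:16]` (up to relu3_3), with inputs normalized by the ImageNet mean and standard deviation; the small conv stand-in only keeps the example self-contained.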
I’m also using two techniques that are common in the literature to help the model reach a higher accuracy:
Optical flow - I’m using LiteFlowNet for this. I compute the flows f1->f2 and f3->f2 and feed them to the model as extra input channels of the first Conv2d layer.
Frame warping - Using flow_f1->f2 and flow_f3->f2, I warp both frames with the following code:
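import torch
import torch.nn.functional as F

def warp_image(frame, flow):
    # frame: (C, H, W); flow: (2, H, W), displacement in pixels.
    # grid_sample performs backward warping: output pixel p is read from
    # frame at position p + flow(p).
    device = frame.device
    C, H, W = frame.shape
    frame = frame.unsqueeze(0)
    flow = flow.unsqueeze(0)
    y_base, x_base = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing='ij'
    )
    x_base = x_base.float()
    y_base = y_base.float()
    # Assumed convention (LiteFlowNet and most flow networks): channel 0 is
    # horizontal (x) motion, channel 1 is vertical (y) motion.
    flow_x = flow[:, 0, :, :]
    flow_y = flow[:, 1, :, :]
    new_y = y_base + flow_y[0]
    new_x = x_base + flow_x[0]
    # Normalize pixel coordinates to [-1, 1] (align_corners=True convention)
    new_y = 2.0 * (new_y / (H - 1)) - 1.0
    new_x = 2.0 * (new_x / (W - 1)) - 1.0
    # grid_sample expects the last dimension ordered as (x, y)
    grid = torch.stack((new_x, new_y), dim=-1).unsqueeze(0)
    warped = F.grid_sample(frame, grid, mode='bilinear', padding_mode='border', align_corners=True)
    return warped.squeeze(0)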
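The function above handles one unbatched frame at a time. Below is a batched variant (the name `backward_warp` is hypothetical) that warps whole (N, C, H, W) tensors, plus a constant-flow sanity check that would catch swapped x/y flow channels. It assumes the flow stores horizontal displacement in channel 0 and vertical displacement in channel 1, the usual convention for flow networks; it is worth confirming this against the LiteFlowNet wrapper in use. Because grid_sample performs backward warping (output pixel p reads the source at p + flow(p)), the flow passed in should be defined on the pixel grid of the frame being warped towards.

```python
import torch
import torch.nn.functional as F

def backward_warp(frame, flow):
    # frame: (N, C, H, W); flow: (N, 2, H, W), channel 0 = x, channel 1 = y
    # displacement in pixels (assumed convention).
    n, _, h, w = frame.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, device=frame.device, dtype=frame.dtype),
        torch.arange(w, device=frame.device, dtype=frame.dtype),
        indexing="ij",
    )
    x = xs.unsqueeze(0) + flow[:, 0]  # (N, H, W) sampling positions
    y = ys.unsqueeze(0) + flow[:, 1]
    # Normalize to [-1, 1]; grid_sample wants (x, y) in the last dimension.
    grid = torch.stack((2 * x / (w - 1) - 1, 2 * y / (h - 1) - 1), dim=-1)
    return F.grid_sample(frame, grid, mode="bilinear",
                         padding_mode="border", align_corners=True)

# Sanity check with a constant one-pixel flow in each direction.
frame = torch.arange(30, dtype=torch.float32).reshape(1, 1, 5, 6)
flow_x = torch.zeros(1, 2, 5, 6)
flow_x[:, 0] = 1.0  # every output pixel reads one pixel to the right
flow_y = torch.zeros(1, 2, 5, 6)
flow_y[:, 1] = 1.0  # every output pixel reads one pixel below
warped_x = backward_warp(frame, flow_x)
warped_y = backward_warp(frame, flow_y)
print(torch.allclose(warped_x[..., :-1], frame[..., 1:], atol=1e-4))        # True
print(torch.allclose(warped_y[..., :-1, :], frame[..., 1:, :], atol=1e-4))  # True
```

If the channels were swapped, the first check would show a vertical shift instead of a horizontal one and fail.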
So, basically, I’m feeding frame1, frame3, flow_1->2, flow_3->2, warped1 and warped3 into the encoder, which adds up to the 16 input channels of its first layer.
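The 16 input channels only add up if two RGB frames are stacked with the two 2-channel flows and the two warped RGB frames: 3 + 3 + 2 + 2 + 3 + 3 = 16 (three RGB frames would give 19 channels, and `Conv2d(16, ...)` would raise a shape error). A sketch of that concatenation with placeholder tensors; the variable names and the Vimeo90k-sized batch are assumptions for illustration.

```python
import torch

n, h, w = 2, 256, 448  # assumed batch size and Vimeo90k frame size
frame1 = torch.rand(n, 3, h, w)
frame3 = torch.rand(n, 3, h, w)
flow_12 = torch.zeros(n, 2, h, w)  # (dx, dy) per pixel
flow_32 = torch.zeros(n, 2, h, w)
warped1 = torch.rand(n, 3, h, w)
warped3 = torch.rand(n, 3, h, w)

# Concatenate along the channel dimension: 3 + 3 + 2 + 2 + 3 + 3 = 16.
x = torch.cat([frame1, frame3, flow_12, flow_32, warped1, warped3], dim=1)
print(x.shape)  # torch.Size([2, 16, 256, 448])
```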
With that being said, I’ve noticed two main issues that are likely limiting the model’s accuracy. First, the model still struggles to interpolate some movements, which produces ghosting. I implemented optical flow and frame warping to mitigate this, but that led to the second issue: the warped frames sometimes contain noticeable inaccuracies, which directly degrade the quality of the interpolated frame.
Here are some examples of issues 1 and 2:
I’m also using Adam as my optimizer with the following parameters:
optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1e-5)
and my learning rate has been fixed at lr=3e-5 since the beginning.
The entire code can be found on my GitHub in “treinamento_v2_5_per_and_flownet.py” (please excuse the fact that everything is in a single file for now).
I’m open to any suggestions. My goal is to reach at least 85% accuracy, though 90% would be ideal.
Also, sorry for not including a plot of the validation and training loss. I’m waiting until the model has trained for at least 50 epochs on Vimeo90k before plotting it.