I am reconstructing spatio-temporal cuboids with width and height 32 and depth 20 (20 frames), using Conv3d layers in my autoencoder architecture.
Each input cuboid is 32x32x20. The encoder reduces it to a code of size 2048 (512 x 1 x 2 x 2), and the decoder reconstructs it back to 32x32x20.
The MSE loss of the model converges nicely, yet the reconstruction is just noise.
My encoder architecture:
    ----------------------------------------------------
    Layer (type)    Output Shape                 Param #
    ====================================================
    Conv3d-1        [-1, 64, 20, 32, 32]           5,248
    AvgPool3d-2     [-1, 64, 10, 16, 16]               0
    Conv3d-3        [-1, 128, 10, 16, 16]        221,312
    AvgPool3d-4     [-1, 128, 5, 8, 8]                 0
    Conv3d-5        [-1, 256, 5, 8, 8]           884,992
    AvgPool3d-6     [-1, 256, 2, 4, 4]                 0
    Conv3d-7        [-1, 512, 2, 4, 4]         3,539,456
    AvgPool3d-8     [-1, 512, 1, 2, 2]                 0
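The shapes and parameter counts in this summary can be cross-checked with a little arithmetic. The sketch below assumes 3 input channels (as in the reconstruction code), 3x3x3 kernels with bias and size-preserving padding, and 2x2x2 average pooling that rounds down. None of these settings are stated in the post; they are simply the ones that reproduce the numbers above.

```python
def conv3d_params(c_in, c_out, k=3):
    """Weights plus biases of a Conv3d with a cubic k x k x k kernel (assumed)."""
    return c_in * c_out * k ** 3 + c_out


def pool_shape(shape, k=2):
    """Pooling with kernel = stride = k rounds each dimension down."""
    return tuple(d // k for d in shape)


channels = [3, 64, 128, 256, 512]  # 3 input channels is an assumption
shape = (20, 32, 32)               # (depth, height, width)
param_counts = []
for c_in, c_out in zip(channels, channels[1:]):
    param_counts.append(conv3d_params(c_in, c_out))  # conv keeps the shape
    shape = pool_shape(shape)                        # pooling halves it

latent_size = channels[-1] * shape[0] * shape[1] * shape[2]
print(param_counts)        # compare with the Param # column
print(shape, latent_size)  # final feature map and flattened code size
```

Under these assumptions the flattened code size is 512 x 1 x 2 x 2, which is the 2048 mentioned at the top.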
My decoder architecture:
    ----------------------------------------------------
    Layer (type)    Output Shape                 Param #
    ====================================================
    Interpolate-1   [-1, 512, 2, 4, 4]                 0
    Conv3d-2        [-1, 256, 4, 4, 4]         3,539,200
    Interpolate-3   [-1, 256, 8, 8, 8]                 0
    Conv3d-4        [-1, 128, 10, 8, 8]          884,864
    Interpolate-5   [-1, 128, 20, 16, 16]              0
    Conv3d-6        [-1, 64, 20, 16, 16]         221,248
    Interpolate-7   [-1, 64, 40, 32, 32]               0
    Conv3d-8        [-1, 3, 40, 32, 32]            5,187
    Conv3d-9        [-1, 3, 20, 32, 32]              246
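The depth dimension of the decoder does not simply mirror the encoder, so it can help to trace it with the standard convolution output-size formula, out = floor((in + 2 * padding - kernel) / stride) + 1. The sketch below shows one combination of padding and stride that reproduces the depth column of the summary. It assumes a kernel size of 3 along depth and that each Interpolate doubles the size; the actual padding and stride values are not given in the post, so the ones here are guesses.

```python
def conv_out(n, kernel=3, stride=1, padding=1):
    """Output length of a convolution along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1


# (padding, stride) along depth for Conv3d-2, -4, -6, -8: assumed values
depth_settings = [(2, 1), (2, 1), (1, 1), (1, 1)]

depth = 1  # depth of the encoder output
trace = []
for padding, stride in depth_settings:
    depth *= 2  # Interpolate, assumed scale factor 2
    trace.append(depth)
    depth = conv_out(depth, stride=stride, padding=padding)
    trace.append(depth)

# Conv3d-9: assumed stride 2 along depth to go from 40 back to 20
depth = conv_out(depth, stride=2, padding=1)
trace.append(depth)
print(trace)  # compare with the depth column of the decoder summary
```

With these values, Conv3d-2 and Conv3d-4 each add 2 to the depth through extra padding, and the final layer halves it.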
My reconstruction code:
    import matplotlib.pyplot as plt
    import numpy as np
    from mpl_toolkits.axes_grid1 import ImageGrid
    from skimage.color import rgb2gray  # assumed source of rgb2gray (scikit-image)

    # recon_batch is the last batch of the autoencoder output.
    # recon_batch has shape (batch_size, 3, 20, 32, 32)
    recon_batch = recon_batch.permute(0, 2, 3, 4, 1)  # new shape = (batch_size, 20, 32, 32, 3)
    recon_batch_numpy = recon_batch.detach().cpu().numpy()

    for k in range(3):  # 3 because I want to display 3 frames
        fig = plt.figure(figsize=(4., 4.))
        grid = ImageGrid(fig, 111,
                         nrows_ncols=(4, 4),
                         axes_pad=0.1,
                         )
        # 16 because the original image is 128x128; cutting it into
        # 32x32 patches gives 16 sub-frames
        images = [recon_batch_numpy[i][k] for i in range(16)]
        for ax, im in zip(grid, images):
            ax.imshow((rgb2gray(im) * 255).astype(np.uint8), cmap='gray', vmin=0, vmax=255)
        name = str(num_epochs) + 'th_figure_' + str(k)
        plt.savefig(name)
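The figure of 16 patches per frame follows from (128 / 32)^2 = 16. As a small sketch of how patch index i could map to a position in the 4x4 grid, assuming the patches were cut out in row-major order (the post does not say so, although ImageGrid also fills its axes row by row by default), with `patch_origin` being a hypothetical helper:

```python
def patch_origin(i, patch=32, frame=128):
    """Top-left (y, x) pixel of patch i, assuming row-major patch order."""
    per_row = frame // patch
    row, col = divmod(i, per_row)
    return row * patch, col * patch


patches_per_frame = (128 // 32) ** 2
origins = [patch_origin(i) for i in range(patches_per_frame)]
print(patches_per_frame)
print(origins[:5])  # first row of patches, then the start of the second row
```

If the patches were extracted in a different order, the grid would show them shuffled even when each individual patch is reconstructed correctly.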
Original input image (the actual cuboids of size 32x32x20; the image is assembled from the first frame, at depth index 0, of each cuboid):
The reconstructed output image:
The loss starts around 10,000 and converges around 200.