I have a pytorch class. I am processing a tensor of shape[(256, 256, 256, 1)]. While building the model, this 4d gets in role. But when training, during which the batch_size gets added, how does the actions like permut() work.wouldnt it cause mismatch in the tensors? Also, if you take a look at this class, 4d tensor is what got maintained till the end, but at the end point, 5d tensors are expected. Can somebody explain, how batch size works in model building and model training.
class VoxelMorph1(nn.Module):
def init(self, input_shape=(32, 32, 1), optimizer=‘adam’, loss=None,
metrics=None, loss_weights=None):
super(VoxelMorph1, self).init()
in_channels = 1
out_channels = 3
input_shape = input_shape + (in_channels,)
self.moving = nn.Parameter(torch.randn(input_shape), requires_grad=True)
self.static = nn.Parameter(torch.randn(input_shape), requires_grad=True)
self.static = nn.Parameter(self.static.unsqueeze(0), requires_grad=True)
self.moving = nn.Parameter(self.moving.unsqueeze(0), requires_grad=True)
x_in = torch.cat([self.static, self.moving], dim=-1)
x_in = x_in.permute(3, 0,1,2)
# encoder
x1 = nn.Conv3d(in_channels=2, out_channels=16, kernel_size=3, stride=2, padding=1)
x1 = nn.LeakyReLU(negative_slope=0.2)(x1(x_in)) # 16
print("x1",x1.shape)
x2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
x2 = nn.LeakyReLU(negative_slope=0.2)(x2(x1)) # 8
x3 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
x3 = nn.LeakyReLU(negative_slope=0.2)(x3(x2)) # 4
x4 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
x4 = nn.LeakyReLU(negative_slope=0.2)(x4(x3)) # 2
#x4 = x4.permute(3, 0,1,2)
# decoder [32, 32, 32, 32, 8, 8]
x = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
x = nn.LeakyReLU(negative_slope=0.2)(x(x4))
x= torch.unsqueeze(x, dim=1)
x = nn.Upsample(scale_factor=2, mode='nearest')(x)
x= torch.squeeze(x, dim=1)
x3 = torch.squeeze(x3, dim=1)
#x = x.permute(0, 2,3,1)
xd1 = torch.cat([x, x3], dim=0) # 4
x = nn.Conv3d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
x = nn.LeakyReLU(negative_slope=0.2)(x(xd1))
x= torch.unsqueeze(x, dim=1)
x = nn.Upsample(scale_factor=2, mode='nearest')(x) # 8
x= torch.squeeze(x, dim=1)
xd2 = torch.cat([x, x2], dim=0) # 8
x = nn.Conv3d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
x = nn.LeakyReLU(negative_slope=0.2)(x(xd2))
x= torch.unsqueeze(x, dim=1)
x = nn.Upsample(scale_factor=2, mode='nearest')(x) # 16
x= torch.squeeze(x, dim=1)
xd3 = torch.cat([x, x1], dim=0) # 16
x = nn.Conv3d(in_channels=48, out_channels=32, kernel_size=3, stride=1, padding=1)
xd4 = nn.LeakyReLU(negative_slope=0.2)(x(xd3))
x = nn.Conv3d(in_channels=32, out_channels=8, kernel_size=3, stride=1, padding=1)
x = nn.LeakyReLU(negative_slope=0.2)(x(xd4)) # 16
x= torch.unsqueeze(x, dim=1)
x = nn.Upsample(scale_factor=2, mode='nearest')(x) # 32
x= torch.squeeze(x, dim=1)
xd5 = torch.cat([x, x_in], dim=0)
x = nn.Conv3d(in_channels=10, out_channels=8, kernel_size=3, stride=1, padding=1)
x = nn.LeakyReLU(negative_slope=0.2)(x(xd5)) # 32
#torch.nn.init.normal_(nn.conv3d.weight, mean=0.0, std=1e-5)
con3d = nn.Conv3d(in_channels=8, out_channels=out_channels, kernel_size=3, stride=1,
padding=1, bias=True)
torch.nn.init.normal_(con3d.weight, mean=0.0, std=1e-5)
deformation = con3d(x)
print("deformation",deformation.shape)
nb, nd, nh, nw, nc = deformation.shape