I am working with a localizer network followed by a regression network (essentially extending the spatial transformer network to work with 3D volumes).
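For context, the sampling step this feeds into works on 5D tensors in PyTorch, with theta shaped (N, 3, 4). Below is a minimal sketch with an identity transform (my network further down only predicts theta and does not do the sampling yet):

import torch
import torch.nn.functional as F

# For 3D volumes, affine_grid expects theta of shape (N, 3, 4)
theta = torch.eye(3, 4).unsqueeze(0)  # identity affine, N = 1
vol = torch.randn(1, 1, 64, 64, 64)   # (N, C, D, H, W)
grid = F.affine_grid(theta, vol.size(), align_corners=False)
warped = F.grid_sample(vol, grid, align_corners=False)
print(warped.shape)  # torch.Size([1, 1, 64, 64, 64])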
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.localization = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=7),
            nn.MaxPool3d(2, stride=2),
            nn.ReLU(True),
            nn.Conv3d(8, 10 * 3 * 6, kernel_size=5),
            nn.MaxPool3d(2, stride=2),
            nn.ReLU(True)
        )
        # Regressor for the 18 transformation parameters
        # (the first 12 entries are the row-major 3 x 4 identity affine)
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 6, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 6)
        )
        # Initialize the weights/bias with the identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor(
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            dtype=torch.double))

    def forward(self, x):
        xs = self.localization(x)
        print('before squeeze', xs.shape)
        xs = xs.view(-1, 10 * 3 * 6)
        print('xs shape', xs.shape)
        theta = self.fc_loc(xs)
        print('theta shape', theta.shape)
        theta = theta.view(-1, 18)
        return theta
model = Net().to(device)
print(model)

# optimizer (plain SGD with an assumed learning rate; any optimizer works here)
optimizer = optim.SGD(model.parameters(), lr=0.01)
def train(epoch):
    model.train()
    for batch_idx, sample_batched in enumerate(dataloader):
        data = sample_batched['volume'].to(device)
        target = sample_batched['label'].to(device)
        print('the data size is ')
        print(data.size())
        print('the target size is ')
        print(target.size())
        optimizer.zero_grad()
        output = model(data.float())
        # The targets are 18 regression values, so an elementwise loss such as
        # MSE fits here (nll_loss expects class-index targets, not float vectors)
        loss = F.mse_loss(output, target.float())
        loss.backward()
        optimizer.step()
        if batch_idx % 2 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(dataloader.dataset),
                100. * batch_idx / len(dataloader), loss.item()))
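To reproduce the printed shapes without my real data, a stand-in loader with the same dict keys works (DummyVolumes below is hypothetical; only the 'volume'/'label' keys and the shapes match my actual dataset):

from torch.utils.data import DataLoader, Dataset

class DummyVolumes(Dataset):
    # hypothetical stand-in: random volumes and 18 regression targets
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return {'volume': torch.randn(1, 64, 64, 64),  # (C, D, H, W)
                'label': torch.randn(18)}              # 18 transformation parameters

dataloader = DataLoader(DummyVolumes(), batch_size=1)
train(1)  # reproduces the shape printout below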
Output:

the data size is
torch.Size([1, 1, 64, 64, 64])
the target size is
torch.Size([1, 18])
before squeeze torch.Size([1, 180, 12, 12, 12])
xs shape torch.Size([1728, 180])
theta shape torch.Size([1728, 18])
I understand that for the model to work, my output and my target should be of the same shape. I was hoping to achieve this by having the last linear layer in fc_loc output 18 features, but it didn't work: as the log above shows, theta comes out as [1728, 18] while the target is [1, 18].
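Tracing the shapes by hand shows where the 1728 comes from (a minimal sketch of just the localization stack, assuming the same [1, 1, 64, 64, 64] input as in the log):

import torch
import torch.nn as nn

loc = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=7),    # 64 -> 58
    nn.MaxPool3d(2, stride=2),         # 58 -> 29
    nn.ReLU(True),
    nn.Conv3d(8, 180, kernel_size=5),  # 29 -> 25; 10*3*6 = 180 channels
    nn.MaxPool3d(2, stride=2),         # 25 -> 12
    nn.ReLU(True),
)
xs = loc(torch.randn(1, 1, 64, 64, 64))
print(xs.shape)                # torch.Size([1, 180, 12, 12, 12])
print(xs.numel())              # 180 * 12**3 = 311040 elements
print(xs.view(-1, 180).shape)  # torch.Size([1728, 180]), since 311040 / 180 = 1728

So xs.view(-1, 10 * 3 * 6) does not drop the 12 x 12 x 12 spatial grid; it redistributes all 311040 elements into rows of 180, which is why theta ends up with 1728 rows instead of 1.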