Hello, something weird is happening when I try to train my small network: at batch #1518 of the first epoch, the batch dimension of the output suddenly drops to 12 when it should be 16.
RuntimeError: The size of tensor a (16) must match the size of tensor b (12) at non-singleton dimension 0
This error is raised when computing the loss, but everything works perfectly fine for all the previous batches, which should be identical in size (16)!
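To confirm where it happens, the batch shapes can be scanned without training anything (a quick check, assuming trainloader is a plain torch.utils.data.DataLoader yielding image tensors, as in my training loop below):

for i, data in enumerate(trainloader):
    if data.size(0) != 16:
        print(i, tuple(data.shape))  # reports batch #1518 with a batch dimension of 12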
This is even more annoying because every time I run the cell and hit this, I have to restart the whole Colab runtime; otherwise I get the annoyingly vague

RuntimeError: CUDA error: device-side assert triggered

which ‘locks’ the notebook.
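(Side note for anyone reproducing this: the only way I get a readable stack trace out of the device-side assert is to force synchronous CUDA execution before anything touches the GPU; this is standard PyTorch debugging advice, nothing specific to my code.)

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before the first CUDA call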
This is the training code:
import torch.optim as optim
from tqdm import trange

# device, lr, epochs and trainloader are defined earlier in the notebook
mod1 = Module1().to(device)
optimizer = optim.Adam(mod1.parameters(), lr=lr)

for epoch in trange(epochs):
    for i, data in enumerate(trainloader):
        xt = data.to(device)   # image patch
        xtk = data.to(device)  # 'predicted' patch

        optimizer.zero_grad()
        zt = mod1(xt)
        ztk = mod1(xtk)
        loss = infoNCE(zt, ztk)  # custom loss function, defined below
        loss.backward()
        optimizer.step()
My current network:
import torch.nn as nn
import torch.nn.functional as F

class Module1(nn.Module):
    def __init__(self):
        super(Module1, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, 1, 2)   # (in, out, kernel, stride, padding)
        self.conv2 = nn.Conv2d(32, 64, 5, 1, 2)
        self.conv3 = nn.Conv2d(64, 128, 5, 1, 2)
        self.conv4 = nn.Conv2d(128, 128, 5, 1, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(128 * 1 * 1, 3)      # 3 features (NORB)

    def forward(self, x):
        x = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
        x = self.pool(F.relu(self.conv4(F.relu(self.conv3(x)))))
        x = self.pool(self.pool(x))
        x = x.view(x.size(0), -1)  # flatten to (batch, 128)
        x = self.fc(x)
        return x
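The 128*1*1 feeding self.fc implies the patches are 16x16 and single-channel (halved by each of the four poolings: 16 → 8 → 4 → 2 → 1). That patch size is my inference rather than something stated above, but a dummy forward pass checks out:

import torch
x = torch.randn(16, 1, 16, 16)   # assumed patch size, implied by the 128*1*1 fc input
print(Module1()(x).shape)        # torch.Size([16, 3])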
My custom loss:
import numpy as np
import torch

def infoNCE(zt, ztk):
    ind = np.diag_indices(b_size)  # b_size is a global set to 16 (the DataLoader batch size)
    mul = ztk @ torch.t(zt)        # (batch, batch) matrix of pairwise similarities
    num = mul[ind]                 # positive pairs: the diagonal
    m, _ = torch.max(mul, dim=0)
    den = torch.log(torch.sum(torch.exp(mul - m), dim=0)) + m  # log-sum-exp, max subtracted for stability
    val = torch.mean(-num + den)
    return val
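The crash is easy to trigger in isolation, with no data loader involved; feeding the loss a ‘short’ batch of random embeddings while b_size is still 16 reproduces it:

import torch
zt = torch.randn(12, 3)   # 12 embeddings instead of the expected 16
ztk = torch.randn(12, 3)
infoNCE(zt, ztk)          # the 16 diagonal indices run past the 12x12 matrix and it errors out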
Note that everything was working with a different dataset (CIFAR-10), but now it breaks (NORB). The network inputs should have the same shape in both cases, though, so I don’t know what is wrong.
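One more data point I can add (assuming trainloader wraps a standard torch.utils.data.Dataset, so its length is available) is whether the dataset size is actually a multiple of the batch size:

n = len(trainloader.dataset)
print(n, n % 16)  # a non-zero remainder would mean the final batch comes up short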