I’m trying to implement a variant of capsule network where the matrix multiplication is replaced by element-wise multiplication with a vector. During training (mostly after the first backpropagation) the outputs become nan. I tried using gradient clipping, but it didn’ work. I’m working with MNIST dataset and I’m normalizing it before training.
## Training data loading and normalizing
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST('/mnist/', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,shuffle=True, num_workers=2, batch_size = 10)
## Model Architecture
class ConvLayer(nn.Module):
def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
super(ConvLayer, self).__init__()
self.conv = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1
)
def forward(self, x):
return F.relu(self.conv(x))
class PrimaryCaps(nn.Module):
def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):
super(PrimaryCaps, self).__init__()
self.num_routes = num_routes
self.capsules = nn.ModuleList([
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
for _ in range(num_capsules)])
def forward(self, x):
u = [capsule(x) for capsule in self.capsules]
u = torch.stack(u, dim=1)
u = u.view(x.size(0), self.num_routes, -1)
return self.squash(u)
def squash(self, input_tensor):
squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
return output_tensor
class DigitCaps(nn.Module):
def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
super(DigitCaps, self).__init__()
self.in_channels = in_channels
self.num_routes = num_routes
self.num_capsules = num_capsules
self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, in_channels, 1))
def forward(self, x):
batch_size = x.size(0)
x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)
W = torch.cat([self.W] * batch_size, dim=0)
u_hat = W * x
b_ij = torch.zeros(1, self.num_routes, self.num_capsules, 1)
if USE_CUDA:
b_ij = b_ij.cuda()
num_iterations = 3
for iteration in range(num_iterations):
c_ij = F.softmax(b_ij, dim=1)
c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
v_j = self.squash(s_j)
if iteration < num_iterations - 1:
a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)
return v_j.squeeze(1)
def squash(self, input_tensor):
squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
return output_tensor
class CapsNet(nn.Module):
def __init__(self, config=None):
super(CapsNet, self).__init__()
if config:
self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels, config.cnn_kernel_size)
self.primary_capsules = PrimaryCaps(config.pc_num_capsules, config.pc_in_channels, config.pc_out_channels,
config.pc_kernel_size, config.pc_num_routes)
self.digit_capsules = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels,
config.dc_out_channels)
else:
self.conv_layer = ConvLayer()
self.primary_capsules = PrimaryCaps()
self.digit_capsules = DigitCaps()
self.mse_loss = nn.MSELoss()
def forward(self, data):
output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
return output
class Config_1:
def __init__(self, in_channels):
# CNN (cnn)
self.cnn_in_channels = in_channels
self.cnn_out_channels = 256
self.cnn_kernel_size = 3
# Primary Capsule (pc)
self.pc_num_capsules = 8
self.pc_in_channels = 256
self.pc_out_channels = 32
self.pc_kernel_size = 3
self.pc_num_routes = 32 * 6 * 6
# Digit Capsule (dc)
self.dc_num_capsules = 10
self.dc_num_routes = 32 * 6 * 6
self.dc_in_channels = 18
self.dc_out_channels = 1
class Config_2:
def __init__(self, in_channels):
# CNN (cnn)
self.cnn_in_channels = in_channels
self.cnn_out_channels = 256
self.cnn_kernel_size = 3
# Primary Capsule (pc)
self.pc_num_capsules = 8
self.pc_in_channels = 256
self.pc_out_channels = 32
self.pc_kernel_size = 3
self.pc_num_routes = 32 * 6 * 6
# Digit Capsule (dc)
self.dc_num_capsules = 10
self.dc_num_routes = 32 * 6 * 6
self.dc_in_channels = 8
self.dc_out_channels = 1
class Config_3:
def __init__(self, in_channels):
# CNN (cnn)
self.cnn_in_channels = in_channels
self.cnn_out_channels = 256
self.cnn_kernel_size = 3
# Primary Capsule (pc)
self.pc_num_capsules = 8
self.pc_in_channels = 256
self.pc_out_channels = 32
self.pc_kernel_size = 3
self.pc_num_routes = 32 * 6 * 6
# Digit Capsule (dc)
self.dc_num_capsules = 10
self.dc_num_routes = 32 * 6 * 6
self.dc_in_channels = 2
self.dc_out_channels = 1
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(1,32,kernel_size = (3,3))
self.conv2 = nn.Conv2d(32,48,kernel_size = (3,3))
self.conv3 = nn.Conv2d(48,64,kernel_size = (3,3))
self.conv4 = nn.Conv2d(64,80,kernel_size = (3,3))
self.conv5 = nn.Conv2d(80,96,kernel_size = (3,3))
self.conv6 = nn.Conv2d(96,112,kernel_size = (3,3))
self.conv7 = nn.Conv2d(112,128,kernel_size = (3,3))
self.conv8 = nn.Conv2d(128,144,kernel_size = (3,3))
self.conv9 = nn.Conv2d(144,160,kernel_size = (3,3))
self.caps_a = CapsNet(Config_1(64))
self.caps_b = CapsNet(Config_2(112))
self.caps_c = CapsNet(Config_3(160))
self.merge_weight1= nn.Parameter(torch.randn(1))
self.merge_weight2= nn.Parameter(torch.randn(1))
self.merge_weight3= nn.Parameter(torch.randn(1))
self.relu = nn.ReLU()
def forward(self,x):
branch1 = self.conv1(x)
branch1 = self.relu(branch1)
branch1 = self.conv2(branch1)
branch1 = self.relu(branch1)
branch1 = self.conv3(branch1)
branch1 = self.relu(branch1)
branch2 = self.conv4(branch1)
branch2 = self.relu(branch2)
branch2 = self.conv5(branch2)
branch2 = self.relu(branch2)
branch2 = self.conv6(branch2)
branch2 = self.relu(branch2)
branch3 = self.conv7(branch2)
branch3 = self.relu(branch3)
branch3 = self.conv8(branch3)
branch3 = self.relu(branch3)
branch3 = self.conv9(branch3)
branch3 = self.relu(branch3)
branch1_out = self.caps_a(branch1)
branch2_out = self.caps_b(branch2)
branch3_out = self.caps_c(branch3)
branch1_out = torch.sqrt((branch1_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
branch2_out = torch.sqrt((branch2_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
branch3_out = torch.sqrt((branch3_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
out1 = self.merge_weight1 * branch1_out
out2 = self.merge_weight1 * branch2_out
out3 = self.merge_weight1 * branch3_out
stack = torch.stack([out1,out3,out3],dim = 0)
summed = torch.sum(stack,dim = 0)
return summed
net = Net()
if USE_CUDA:
net = net.cuda()
criterion = nn.CrossEntropyLoss()
lr = 0.01
optimizer = torch.optim.Adam(net.parameters(),lr = lr)
## Training of the network
for epoch in range(5):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
if USE_CUDA:
inputs = inputs.cuda()
labels = labels.cuda()
outputs = net(inputs)
loss = criterion(outputs, labels)
if torch.isnan(loss):
print("nan Loss")
running_loss = 0.0
continue
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(net.parameters(), 0.1)
optimizer.step()
running_loss += loss.item()
if i % 100 == 99:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
print('Finished Training')
Here’s the sample output
[1, 100] loss: 2.301
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
............
Any idea on what is causing this or any suggestions on how to fix it please?