I applied CTC loss to a continuous sign-language recognition task, using the model architecture proposed in SubUNets.
I first used ResNet18 as the CNN for spatial feature extraction. That works without problems, but ResNet18 consumes a lot of memory.
So I changed the CNN to AlexNet to get a lighter-weight model.
However, when I use AlexNet as the spatial feature extractor, the CTC loss becomes 'NaN' after a few epochs.
I'm not sure where the problem is, because I used the same parameters for both models.
The only difference between the two setups is the CNN model itself.
Please check my code for both models.
This is the code for AlexNet:
class CnnEncode(nn.Module):
    """Per-frame spatial feature extractor built on an AlexNet backbone.

    Only the first two children of torchvision's AlexNet are kept, and the
    first convolution is replaced so the network accepts 1-channel input.
    Weights are randomly initialised (``pretrained=False``).

    Args:
        cnn_embed_dim: Stored on the instance; not used inside this class.
        hidden: Three values stored as ``h1``, ``h2`` and ``h3``; not used
            inside this class.
        drop_p: Stored on the instance; not used inside this class.
    """

    def __init__(self, cnn_embed_dim=128, hidden=(128, 128, 128), drop_p=0.3):
        super().__init__()
        self.cnn_embed_dim = cnn_embed_dim
        self.h1, self.h2, self.h3 = hidden
        self.drop_p = drop_p

        backbone = models.alexnet(pretrained=False)
        # Swap the stock first convolution for one that takes 1 input channel.
        backbone.features[0] = nn.Conv2d(
            1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)
        )
        # Keep the first two children only; any later child is dropped.
        feature_layers, pooling_layer = list(backbone.children())[:2]
        self.alexnet = nn.Sequential(feature_layers, pooling_layer)

    def forward(self, x_3d):
        """Encode each slice along dim 1 of ``x_3d`` independently.

        Args:
            x_3d: 5-D tensor. Dim 1 is iterated, and each slice
                ``x_3d[:, t]`` is fed to the CNN, so it must have 1 channel.
                Presumably (batch, time, 1, H, W) -- confirm with the caller.

        Returns:
            Tensor of shape (x_3d.size(1), batch, features), located on the
            module-level ``cpu`` device.
        """
        print(x_3d.shape)
        frame_features = []
        for step in range(x_3d.size(1)):
            # `device` and `cpu` are module-level names defined elsewhere.
            frame = x_3d[:, step, :, :, :].to(device=device)
            activated = F.relu(self.alexnet(frame))
            flattened = activated.view(activated.size(0), -1)
            frame_features.append(flattened.to(device=cpu))
        # Stack to (batch, features, steps), then reorder to
        # (steps, batch, features).
        stacked = torch.stack(frame_features, dim=2)
        return stacked.permute(2, 0, 1)
This is the code for ResNet18:
class CnnEncode(nn.Module):
    """Per-frame spatial feature extractor built on a ResNet18 backbone.

    The first convolution is replaced so the network accepts 1-channel
    input, and the last child of torchvision's ResNet18 (the final fc
    layer) is removed. Weights are randomly initialised
    (``pretrained=False``).

    Args:
        cnn_embed_dim: Stored on the instance; not used inside this class.
        hidden: Three values stored as ``h1``, ``h2`` and ``h3``; not used
            inside this class.
        drop_p: Stored on the instance; not used inside this class.
    """

    def __init__(self, cnn_embed_dim=128, hidden=(128, 128, 128), drop_p=0.3):
        super().__init__()
        self.cnn_embed_dim = cnn_embed_dim
        self.h1, self.h2, self.h3 = hidden
        self.drop_p = drop_p

        backbone = models.resnet18(pretrained=False)
        # Swap the stock first convolution for one that takes 1 input channel.
        backbone.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # Everything except the last child (the fc layer) is retained.
        kept_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*kept_layers)

    def forward(self, x_3d):
        """Encode each slice along dim 1 of ``x_3d`` independently.

        Args:
            x_3d: 5-D tensor. Dim 1 is iterated, and each slice
                ``x_3d[:, t]`` is fed to the CNN, so it must have 1 channel.
                Presumably (batch, time, 1, H, W) -- confirm with the caller.

        Returns:
            Tensor of shape (x_3d.size(1), batch, features), located on the
            module-level ``cpu`` device.
        """
        print(x_3d.shape)
        frame_features = []
        for step in range(x_3d.size(1)):
            # `device` and `cpu` are module-level names defined elsewhere.
            frame = x_3d[:, step, :, :, :].to(device=device)
            activated = F.relu(self.resnet(frame))
            flattened = activated.view(activated.size(0), -1)
            frame_features.append(flattened.to(device=cpu))
        # Stack to (batch, features, steps), then reorder to
        # (steps, batch, features).
        stacked = torch.stack(frame_features, dim=2)
        return stacked.permute(2, 0, 1)