Re-implementation of a time-distributed CNN from Keras in PyTorch does not converge

Hi there,

I’ve trained a neural network (a transfer-learned DenseNet161) to categorise individual video frames, which it does with a validation accuracy of ~93%.

I’ve then saved the first 12 frames of each video as .npz files (exactly the same resolution, with pixel values in the range 0 to 1), and am now training a time-distributed network that takes the final feature maps of the CNN and passes them through a 1D CNN to categorise videos (so 12 time steps, one per frame, for each feature; the feature maps are average-pooled beforehand).

In Keras this works well, and it is done by:

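from keras.layers import Dense, Flatten, MaxPool1D, SeparableConv1D, TimeDistributed
from keras.losses import categorical_crossentropy
from keras.models import Model, Sequential, load_model

# self.* attributes come from the surrounding class (not shown)
cnn = load_model(cnn_model_path)
cnn = Model(inputs=cnn.layers[0].input, outputs=cnn.layers[-2].output)  # drop the final classification layer
for layer in cnn.layers:
    layer.trainable = False  # freeze the frame-level CNN

model = Sequential(name='spatial')
model.add(TimeDistributed(cnn, input_shape=(self.frames, self.height, self.width, self.channels), name='cnn'))
model.add(SeparableConv1D(256, kernel_size=3))
model.add(MaxPool1D(2))
model.add(SeparableConv1D(128, kernel_size=3))
model.add(Flatten())
model.add(Dense(self.generator.n_classes, activation='softmax', name='output'))
model.compile(loss=categorical_crossentropy, optimizer='adam', metrics=['accuracy'])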
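
For reference, the temporal length after each layer can be worked out by hand. A minimal sketch, assuming 'valid' padding and stride 1 for the convolutions (the Keras defaults for SeparableConv1D) and non-overlapping pooling for MaxPool1D(2); the helper names are hypothetical:

```python
def conv1d_len(length: int, kernel_size: int, stride: int = 1) -> int:
    """Output length of a 1D convolution with 'valid' (no) padding."""
    return (length - kernel_size) // stride + 1


def pool1d_len(length: int, pool_size: int) -> int:
    """Output length of max pooling with stride == pool_size and 'valid' padding."""
    return length // pool_size


def flattened_size(timesteps: int = 12, final_channels: int = 128) -> int:
    t = conv1d_len(timesteps, 3)  # 12 -> 10
    t = pool1d_len(t, 2)          # 10 -> 5
    t = conv1d_len(t, 3)          # 5 -> 3
    return t * final_channels     # 3 * 128 = 384


print(flattened_size())  # 384
```

This is where the `3 * 128` input size of the PyTorch classifier comes from, and it matches the `[-1, 256, 3]` and `[-1, 128, 3]` shapes in the summary further down.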

So I’ve been trying to re-implement this in PyTorch, but it plateaus at 16% validation accuracy (16% being the prevalence of the largest class, so it’s no better than always guessing that class):

Here’s my network class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class TD_CNN(nn.Module):
    def __init__(self, inner_model, n_classes):
        super(TD_CNN, self).__init__()
        self.inner_model = inner_model
        self.inner_model.classifier = Identity() # Remove classifier

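        # Separable conv block 1: depthwise (groups=in_channels) then pointwise (kernel_size=1)
        self.cnn1d1depth = nn.Conv1d(2208, 2208, kernel_size=3, groups=2208)
        self.cnn1d1point = nn.Conv1d(2208, 256, kernel_size=1)

        # Separable conv block 2
        self.cnn1d2depth = nn.Conv1d(256, 256, kernel_size=3, groups=256)
        self.cnn1d2point = nn.Conv1d(256, 128, kernel_size=1)

        self.classifier = nn.Linear(3 * 128, n_classes)  # 3 remaining timesteps x 128 channels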

    def forward(self, x):
        batch_size, timesteps, channels, height, width = x.size()
        inner_model_in = x.view(batch_size * timesteps, channels, height, width)
        inner_model_out = self.inner_model(inner_model_in) # Yields (batch_size * timesteps) x 2208
        x = inner_model_out.view(batch_size, 2208, timesteps)

        x = self.cnn1d1depth(x)
        x = F.relu(self.cnn1d1point(x))
        x = F.max_pool1d(x, 2)

        x = self.cnn1d2depth(x)
        x = F.relu(self.cnn1d2point(x))

        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)
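
One step that may be worth double-checking is the reshape after the inner model. `inner_model_out` has shape `(batch_size * timesteps, 2208)`, with each sample's timesteps stored one after another, and `view` only reinterprets memory, so `view(batch_size, 2208, timesteps)` is not the same as swapping the time and feature axes. A small sketch with made-up sizes (B=2, T=3, F=4 standing in for 2208) compares it with `view(B, T, F)` followed by `permute(0, 2, 1)`:

```python
import torch

B, T, F = 2, 3, 4  # hypothetical small sizes; F stands in for 2208

# Each value encodes its origin: 100*sample + 10*timestep + feature.
b = torch.arange(B).view(B, 1, 1)
t = torch.arange(T).view(1, T, 1)
f = torch.arange(F).view(1, 1, F)
feats = (100 * b + 10 * t + f).view(B * T, F)  # like inner_model_out: (B*T, F)

# Reinterpreting memory: the "feature" axis now mixes timesteps and features.
wrong = feats.view(B, F, T)

# Restore (B, T, F) first, then swap axes to get (B, F, T) as Conv1d expects.
right = feats.view(B, T, F).permute(0, 2, 1)

print(wrong[0, 0].tolist())  # [0, 1, 2]: features 0..2 of timestep 0
print(right[0, 0].tolist())  # [0, 10, 20]: feature 0 across timesteps 0..2
```

If this turns out to be the issue, the corresponding line in `forward` would be something like `x = inner_model_out.view(batch_size, timesteps, -1).permute(0, 2, 1)`; treat this as something to check rather than a confirmed fix.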

And here’s how I use it:

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import models

INNER_MODEL_TYPE = models.densenet161
INNER_MODEL_WEIGHTS = "DenseNet161_nonorm_e005_a0.936.model"

...

inner_model = INNER_MODEL_TYPE(pretrained=True)
n_features = inner_model.classifier.in_features
inner_model.classifier = nn.Linear(n_features, N_CLASSES) # Remnant from training the CNN, not relevant for the TD, but allows us to load the weights
if INNER_MODEL_WEIGHTS:
    # Strip the 'module.' prefix that nn.DataParallel adds to saved parameter names
    inner_model.load_state_dict({k.split('module.',1)[1]:v for k,v in torch.load(INNER_MODEL_WEIGHTS).items()})
for param in inner_model.parameters():
    param.requires_grad = False
model = TD_CNN(inner_model=inner_model, n_classes=N_CLASSES)

model = model.to(DEVICE)
summary(model, input_size=(N_FRAMES, 3, 299, 299))
if MULTI_GPU:
    model = nn.DataParallel(model)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(get_trainable(model.parameters()))
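
On the freezing step: `requires_grad = False` only stops the optimiser from updating the weights. BatchNorm layers in the frozen DenseNet still update their `running_mean`/`running_var` buffers whenever the module is in training mode. A minimal sketch with a standalone `nn.BatchNorm2d` (not the DenseNet itself) shows the effect; whether it matters for this model is an open question:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad = False  # "frozen", as in the training script

x = torch.randn(8, 3, 4, 4) + 5.0  # a batch whose mean is far from 0

bn.train()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)

bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.equal(before, bn.running_mean)

print(changed_in_train, changed_in_eval)  # True False
```

One common way to keep a frozen backbone's statistics fixed is to call `model.inner_model.eval()` after every `model.train()` call (or `model.module.inner_model.eval()` once the model is wrapped in `nn.DataParallel`).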

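On the loss: `nn.CrossEntropyLoss` already applies `log_softmax` internally, so returning `F.log_softmax(x, dim=1)` from `forward` applies it twice. Because `log_softmax` is idempotent (its outputs already log-sum-exp to zero), this should give the same loss as passing raw logits, so it is probably redundant rather than the cause of the plateau. A quick numerical check with random logits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 14)            # 4 samples, 14 classes as in the summary
targets = torch.tensor([0, 3, 7, 13])

criterion = nn.CrossEntropyLoss()
loss_raw = criterion(logits, targets)                          # raw logits in
loss_twice = criterion(F.log_softmax(logits, dim=1), targets)  # log-probs in, as in TD_CNN
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)        # the intended pairing for log-probs

print(torch.allclose(loss_raw, loss_twice), torch.allclose(loss_raw, nll))  # True True
```

If the model keeps returning log-probabilities, `nn.NLLLoss` is the matching criterion; otherwise `forward` can return the raw logits.
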
And here’s the DatasetFolder subclass I’ve used for loading the npz files, which are stored as rootdir/<label_name>/sample1.npz:

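import numpy as np
import torch
from torchvision.datasets import DatasetFolder

class npyDataset(DatasetFolder):
    def __init__(self, root):
        super(npyDataset, self).__init__(root, self.npz_loader, (".npz",))

    def npz_loader(self, path):
        npy = np.load(path)
        try:
            npy = npy[npy.files[0]]  # An .npz archive is dict-like: fetch the first stored array by its key
        except AttributeError:
            pass  # A plain .npy file loads directly as an ndarray (no .files attribute)
        npy = npy[:12]  # keep the first 12 frames
        x = torch.from_numpy(npy.reshape((-1, 3, 299, 299))).float()
        return x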
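
The `reshape((-1, 3, 299, 299))` in `npz_loader` only yields a correct `(frames, channels, height, width)` tensor if the arrays were saved channels-first. If the .npz files were written channels-last, `(frames, 299, 299, 3)`, as the Keras `input_shape=(self.frames, self.height, self.width, self.channels)` suggests, then `reshape` silently scrambles the pixels and a `transpose` is needed instead. A small sketch with a made-up 2x2 RGB frame (the storage layout is an assumption, not something the post confirms):

```python
import numpy as np

# One hypothetical channels-last frame: (frames, H, W, C) = (1, 2, 2, 3).
frame = np.zeros((1, 2, 2, 3))
frame[..., 0] = 1.0  # red
frame[..., 1] = 2.0  # green
frame[..., 2] = 3.0  # blue

reshaped = frame.reshape((-1, 3, 2, 2))   # what npz_loader does
transposed = frame.transpose(0, 3, 1, 2)  # reorder axes to (frames, C, H, W)

print(reshaped[0, 0].ravel().tolist())    # [1.0, 2.0, 3.0, 1.0]: channels mixed together
print(transposed[0, 0].ravel().tolist())  # [1.0, 1.0, 1.0, 1.0]: the pure red channel
```

Under that assumption, the loader line would become `torch.from_numpy(np.ascontiguousarray(npy.transpose(0, 3, 1, 2))).float()`.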

Here’s the PyTorch summary (using the torchsummary module), followed by the start of the training output:

Using device cuda:0
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 96, 150, 150]          14,112
       BatchNorm2d-2         [-1, 96, 150, 150]             192
              ReLU-3         [-1, 96, 150, 150]               0
         MaxPool2d-4           [-1, 96, 75, 75]               0
       BatchNorm2d-5           [-1, 96, 75, 75]             192
              ReLU-6           [-1, 96, 75, 75]               0
            Conv2d-7          [-1, 192, 75, 75]          18,432
       BatchNorm2d-8          [-1, 192, 75, 75]             384
              ReLU-9          [-1, 192, 75, 75]               0
           Conv2d-10           [-1, 48, 75, 75]          82,944
      BatchNorm2d-11          [-1, 144, 75, 75]             288
             ReLU-12          [-1, 144, 75, 75]               0
           Conv2d-13          [-1, 192, 75, 75]          27,648
      BatchNorm2d-14          [-1, 192, 75, 75]             384
             ReLU-15          [-1, 192, 75, 75]               0
...
            ReLU-483            [-1, 192, 9, 9]               0
          Conv2d-484             [-1, 48, 9, 9]          82,944
     BatchNorm2d-485           [-1, 2208, 9, 9]           4,416
        Identity-486                 [-1, 2208]               0
        DenseNet-487                 [-1, 2208]               0
          Conv1d-488             [-1, 2208, 10]           8,832
          Conv1d-489              [-1, 256, 10]         565,504
          Conv1d-490               [-1, 256, 3]           1,024
          Conv1d-491               [-1, 128, 3]          32,896
          Linear-492                   [-1, 14]           5,390
================================================================
Total params: 27,085,646
Trainable params: 613,646
Non-trainable params: 26,472,000
----------------------------------------------------------------
Input size (MB): 12.28
Forward/backward pass size (MB): 927.82
Params size (MB): 103.32
Estimated Total Size (MB): 1043.42
----------------------------------------------------------------
Failed to save model graph: tuple appears in op that does not forward tuples (VisitNode at ..\torch\csrc\jit\passes\lower_tuples.cpp:109)
(no backtrace available)
C:\Users\James\Miniconda3\envs\pytorch\lib\site-packages\torch\cuda\nccl.py:24: UserWarning: PyTorch is not compiled with NCCL support
  warnings.warn('PyTorch is not compiled with NCCL support')
EPOCH: 001 | BATCH: 1092 of 9996 | LOSS: 2.180 (2.478) | (0.17 s/it; ETA 0:25:23)
Process finished with exit code 0

I’m not sure what the warning message means, but I assume it’s just saying the model doesn’t play well with TensorboardX?

It works SO well in Keras and not at all in PyTorch, which makes me think I’ve made a fundamental mistake, but I can’t think what.

Thanks!

PS: The full code for the training file is at https://pastebin.com/ba522gta, if that helps.

Hi, I’m stuck with the same problem. Did you solve it?