Module within Module Backprop (NTM Implementation)

I’m a bit of a newbie, and I’m having what I think may be a fairly trivial problem. I’m trying to get my model to backprop properly, but it simply refuses to. Here’s a rough example of what my code does (the full code is here):

import torch
import torch.nn as nn
import torch.nn.functional as Funct
from torch.autograd import Variable
import torch.optim as optim

class EMM_NTM(nn.Module):
    def __init__(self, *params):
        # init hidden state, etc. Make sure all Variables have requires_grad=True
        pass

    def forward(self, h):
        # pass forward through the memory module (this one is really long)
        pass

class FeedForwardController(nn.Module):
    def __init__(self,
                 num_inputs,
                 num_hidden,
                 batch_size,
                 num_reads=1,
                 memory_dims=(128, 20)):

        super(FeedForwardController, self).__init__()

        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.batch_size = batch_size
        self.memory_dims = memory_dims

        self.in_to_hid = nn.Linear(self.num_inputs, self.num_hidden)
        self.read_to_hid = nn.Linear(self.memory_dims[1]*num_reads, self.num_hidden)

    def forward(self, x, read):

        x = x.contiguous()
        x = x.view(-1, num_flat_features(x))
        read = read.contiguous()
        read = read.view(-1, num_flat_features(read))

        x = Funct.relu(self.in_to_hid(x) + self.read_to_hid(read))

        return x
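
# num_flat_features isn't shown above - it lives in the full code. It's roughly
# the usual flatten-size helper (product of all dimensions except the batch dim):
def num_flat_features(x):
    size = x.size()[1:]  # all dimensions except the batch dimension
    num_features = 1
    for s in size:
        num_features *= s
    return num_features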

class NTM(nn.Module):
    def __init__(self,
                 num_inputs,
                 num_hidden,
                 num_outputs,
                 batch_size,
                 num_reads,
                 memory_dims=(128, 20)):
        super(NTM, self).__init__()

        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_outputs = num_outputs
        self.batch_size = batch_size
        self.num_reads = num_reads
        self.memory_dims = memory_dims

        self.hidden = Variable(torch.rand(batch_size, self.num_hidden), requires_grad=True)

        self.EMM = EMM_NTM(self.num_hidden, self.batch_size, num_reads=self.num_reads,
                           num_shifts=3, memory_dims=self.memory_dims)
        # self.EMM.register_backward_hook(print)  # <- an attempt to see what's happening, this doesn't print

        self.controller = FeedForwardController(self.num_inputs, self.num_hidden, self.batch_size,
                                                num_reads=self.num_reads, memory_dims=self.memory_dims)
        # self.controller.register_backward_hook(print)  # <- this doesn't print either

        self.hid_to_out = nn.Linear(self.num_hidden, self.num_outputs)

    def forward(self, x):

        x = x.permute(1, 0, 2, 3)

        def step(x_t):
            r_t = self.EMM(self.hidden)

            # r_t.register_hook(print)  # <- this one doesn't print

            h_t = self.controller(x_t, r_t)
            h_t = h_t.view(-1, num_flat_features(h_t))

            # self.hidden.register_hook(print)  # <- this one prints

            self.hidden = Variable(h_t.data, requires_grad=True)
            out = Funct.sigmoid(self.hid_to_out(self.hidden))
            return out

        outs = torch.stack(
            [
                step(x_t) for x_t in torch.unbind(x, 0)
            ], 0)

        outs = outs.permute(1, 0, 2)

        return outs

For some reason, when I call backward() it doesn't look like the gradients are getting propagated. I tried adding a bunch of backward hooks to see where the printing stops, and it looks like the backward calls just aren't happening in the child modules. Any idea how to fix this?

The reason I think backward is not getting called is that I checked the parameters at the beginning of training and again after 1000 inputs (and 1000 calls to loss.backward(), zeroing the gradients each time), and they are equal. I also printed a set of parameters for the first ~100 iterations and they didn't change at all.
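
For reference, a direct way to make that check is to look at the gradients right after backward (a rough sketch, using the names from the training loop further down):

# after one forward/backward pass, inspect the gradients directly
outputs = ntm(inputs)
loss = criterion(outputs, labels)
loss.backward()

for p in ntm.parameters():
    # p.grad stays None (or all zeros) if no gradient ever reached this parameter
    print(p.size(), None if p.grad is None else p.grad.data.abs().sum())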

I included the controller code because the same issue seems to be happening there as in the EMM_NTM code, so I think the same fix should apply to both. Any help would be great - I'm quite confused!


I see that in your code you unpack .data multiple times. Remember that every time you do that, that part of the computation isn't recorded by autograd, and it can break backprop. For example, if you meant _write_to_mem to be differentiable, then it isn't: you unpacked the Variable that contained the whole history and then repacked the raw tensor in a new Variable, and that won't backprop gradients to the graph that created the data. The problem is here.
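
To illustrate with a minimal sketch (not your exact code) - rebuilding the hidden state from .data at every timestep cuts the graph, while keeping the Variable itself preserves the history:

# this detaches the hidden state from the graph at every timestep, so no
# gradient can flow back into EMM_NTM or the controller:
#     self.hidden = Variable(h_t.data, requires_grad=True)

# keeping the Variable itself preserves the computation history:
self.hidden = h_t

If you do want to truncate backprop between sequences, repackage the data only once per sequence (e.g. at the start of forward), not at every timestep.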

Cheers - I took out all the .data calls and something changed: all the variables in EMM_NTM (memory, wr, ww) are now nan (which I'm guessing comes from trying to backprop, though I'm not sure). I registered a backward hook, and it looks like a bunch of gradients/parameters for the nn.Linear modules are coming out as nan, which is weird.

Still worrisome is that the state doesn't seem to change at all after 1-100 calls to optimizer.step(), even with a large learning rate. The way I'm testing this is to save the state_dict at the very beginning of training and then compare it with the state_dict later on during training.
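
(One thing I still need to rule out with that test: I believe state_dict() hands back references to the live parameter tensors rather than copies, so the saved state and a later state_dict may share storage and compare equal no matter what. Cloning the snapshot gives an independent copy, something like:)

# snapshot that does not share storage with the live parameters
state_before = {k: v.clone() for k, v in ntm.state_dict().items()}

# ... train for a while ...

for name, value in ntm.state_dict().items():
    # torch.equal is an exact element-wise comparison of two tensors
    if not torch.equal(value, state_before[name]):
        print(name, "changed")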

The core of my training loop looks like this:

ntm = NTM(num_inputs, num_hidden, num_inputs, batch, num_reads=1)

try:
    ntm.load_state_dict(torch.load("models/copy_seqlen_{}.dat".format(seq_len)))
except (FileNotFoundError, AttributeError):
    pass

ntm.train()

state = ntm.state_dict()

criterion = nn.MSELoss()
optimizer = optim.RMSprop(ntm.parameters(), lr=5e-3, weight_decay=0.0005)

max_seq_len = 20
for length in range(10, max_seq_len):

    test = CopyTask(length, [num_inputs, 1], num_samples=2e4)

    data_loader = DataLoader(test, batch_size=batch, shuffle=True, num_workers=4)

    for epoch in range(5):

        for i, data in enumerate(data_loader, 0):
            inputs, labels = data
            inputs = Variable(inputs, requires_grad=True)
            labels = Variable(labels)

            optimizer.zero_grad()
            ntm.zero_grad()
            outputs = ntm(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            assert not (ntm.state_dict()['hid_to_out.bias'] == state['hid_to_out.bias'])[0]  # this just breaks it on the first loop

# do stuff with the outputs, plot, running loss, etc.