Recursive network doesn't work

I tried to make a recursive network. It is composed of a convolutional part and a fully connected part, and it feeds its previous output back in as an extra input. The model looks like this:

RecursiveModel(
  (conv_part): Sequential(
    (conv_0): Conv1d(1, 4, kernel_size=(128,), stride=(1,), padding=(32,))
    (conv_1): Conv1d(4, 4, kernel_size=(128,), stride=(1,), padding=(32,))
  )
  (full_part): Sequential(
    (full_conn_0): Linear(in_features=648, out_features=377, bias=True)
    (acti_lrelu_0): LeakyReLU(negative_slope=0.01)
    (full_conn_1): Linear(in_features=377, out_features=219, bias=True)
    (acti_lrelu_1): LeakyReLU(negative_slope=0.01)
    (full_conn_2): Linear(in_features=219, out_features=128, bias=True)
    (acti_lrelu_2): LeakyReLU(negative_slope=0.01)
  )
)

The forward part is:

def forward(self, x, prev_output):
    # x has shape (batch=1, channel=1, frame_size)
    assert x.dim() == 3
    assert x.size(0) == 1  # batch == 1
    assert x.size(1) == 1  # channel == 1
    assert x.size(2) == self.frame_size

    # call submodules directly instead of .forward() so hooks still run
    conv_output = self.conv_part(x)
    conv_output_plain = conv_output.view(1, self.conv_output_shape[0] * self.conv_output_shape[1])

    # append the previous output to the flattened convolution features
    full_input = torch.cat((conv_output_plain, prev_output), 1)
    result = self.full_part(full_input)
    return result

Here the previous output is concatenated with the flattened convolution output and fed into the fully connected part.
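To make the shapes concrete, here is a self-contained sketch that rebuilds the printed model. The post does not state `frame_size`; 256 is an assumption inferred from `in_features=648`: each `Conv1d` with `kernel_size=128, padding=32` shortens the signal by 128 - 2*32 - 1 = 63 samples, so 256 -> 193 -> 130, and 4 * 130 + 128 (the fed-back output) = 648. The `out_size` parameter and the use of `flatten` instead of `view` are illustrative choices, not the author's code.

```python
import torch
import torch.nn as nn
from collections import OrderedDict


class RecursiveModel(nn.Module):
    """Sketch of the printed model; frame_size=256 is inferred, not given in the post."""

    def __init__(self, frame_size=256, out_size=128):
        super().__init__()
        self.frame_size = frame_size
        self.conv_part = nn.Sequential(OrderedDict([
            ("conv_0", nn.Conv1d(1, 4, kernel_size=128, padding=32)),
            ("conv_1", nn.Conv1d(4, 4, kernel_size=128, padding=32)),
        ]))
        # Each conv: L_out = L_in + 2*32 - 128 + 1 = L_in - 63
        conv_len = frame_size - 2 * 63
        self.conv_output_shape = (4, conv_len)
        self.full_part = nn.Sequential(OrderedDict([
            ("full_conn_0", nn.Linear(4 * conv_len + out_size, 377)),
            ("acti_lrelu_0", nn.LeakyReLU(0.01)),
            ("full_conn_1", nn.Linear(377, 219)),
            ("acti_lrelu_1", nn.LeakyReLU(0.01)),
            ("full_conn_2", nn.Linear(219, out_size)),
            ("acti_lrelu_2", nn.LeakyReLU(0.01)),
        ]))

    def forward(self, x, prev_output):
        conv_output = self.conv_part(x)              # (1, 4, conv_len)
        flat = conv_output.flatten(start_dim=1)      # (1, 4 * conv_len)
        full_input = torch.cat((flat, prev_output), dim=1)
        return self.full_part(full_input)            # (1, out_size)


model = RecursiveModel()
out = model(torch.randn(1, 1, 256), torch.zeros(1, 128))
print(model.full_part.full_conn_0.in_features, tuple(out.shape))  # 648 (1, 128)
```

Naming the `Sequential` children through an `OrderedDict` reproduces names like `conv_0` and `full_conn_0` from the printout.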

However, it raises an error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I also tried detaching the previous output:

full_input = torch.cat( (conv_output_plain, prev_output.detach()), 1 )

This runs, but the loss never converges.

Finally, here is the training code. I cut each input sequence into segments and run one segment per iteration:

def do_training():
    i_iter = 0

    while True:
        try:
            i_dataset = random.randrange(0, num_dset)
            dset_input = all_input[i_dataset]
            dset_refout = all_refout[i_dataset]

            num_frame = len(dset_input)
            assert(num_frame == len(dset_refout))
            prev_output = zero_output
            
            for i_frame_start in range(0, num_frame, options.batch_sz):
                
                i_iter += 1
                optimizer.zero_grad()
                loss = 0

                for j in range(options.batch_sz):
                    i_frame = i_frame_start + j
                    if i_frame >= num_frame:
                        break
                    pseudo_batch_input[0][0] = dset_input[i_frame] / 60.0 + 1.0 # normalize to -1 ~ 1
                    pseudo_batch_refout[0] = dset_refout[i_frame]

                    curr_output = model.forward(pseudo_batch_input.to(device), prev_output)
                    prev_output = curr_output
                    
                    curr_loss = cri( curr_output, pseudo_batch_refout.to(device) )
                    if math.isnan(float(curr_loss)):
                        pdb.set_trace()
                        continue
                    loss += curr_loss
                    curr_loss.backward()

                
                # run backward for each batch step
                loss_stat_re = loss_stat.record(float(loss))
                if loss_stat_re is not None:
                    line = str(i_iter) + "\t" + "\t".join(map(str, loss_stat_re))
                    logger.info(line)

                #loss.backward()
                optimizer.step()

                if i_iter > options.n_iter:
                    return
        except KeyboardInterrupt:
            print("early out by int at iter %d" % i_iter)
            return

Hi,

The problem is that prev_output is still attached to the graph of the previous forward pass. So when you call backward after the new forward, autograd tries to backpropagate through the previous forward as well, but that graph's buffers were already freed by the previous backward call. Hence the error you see.
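The failure can be reproduced without the real model. In this minimal sketch, a single weight vector `w` stands in for the network (the names `run`, `w`, and `steps` are hypothetical); each step consumes the previous step's output and calls `backward()` once, just like the per-frame `curr_loss.backward()` in `do_training`.

```python
import torch


def run(detach: bool, steps: int = 3):
    """Return None on success, or the RuntimeError raised by backward()."""
    torch.manual_seed(0)
    w = torch.randn(3, requires_grad=True)   # stand-in for the model's weights
    prev = torch.zeros(3)                    # stand-in for zero_output
    for _ in range(steps):
        out = torch.tanh(w * prev + 1.0)     # each step consumes the previous output
        try:
            out.sum().backward()             # frees this step's saved buffers
        except RuntimeError as err:
            return err
        # Without detach, `prev` still points into the graph that was just freed,
        # so the next backward tries to traverse it again.
        prev = out.detach() if detach else out
    return None


print(run(detach=False) is not None)  # True: the second backward hits freed buffers
print(run(detach=True))               # None
```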

You can pass retain_graph=True when you call .backward() to avoid this error, but then every backward will go all the way back to the first forward you made (and you will most likely run out of memory quite quickly).
If you don't want gradients to flow back into the previous forward, call .detach() on prev_output to stop the gradient.
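A common middle ground between these two options is truncated backpropagation through time (a standard technique, not something stated in this thread): let gradients flow through prev_output within a segment, call backward once per segment, and detach only at segment boundaries. The sketch below mirrors the structure of `do_training`, replacing the per-frame `curr_loss.backward()` with one `loss.backward()` per segment. `TinyRecurrent`, `train_tbptt`, the sizes, and the random data are hypothetical stand-ins, not the original model or dataset.

```python
import math
import torch
import torch.nn as nn


class TinyRecurrent(nn.Module):
    """Hypothetical stand-in for RecursiveModel with the same forward(x, prev_output)."""

    def __init__(self, in_size: int = 8, out_size: int = 4):
        super().__init__()
        self.fc = nn.Linear(in_size + out_size, out_size)

    def forward(self, x, prev_output):
        return torch.tanh(self.fc(torch.cat((x, prev_output), dim=1)))


def train_tbptt(seq_len: int = 20, seg_len: int = 5):
    """Truncated BPTT over one synthetic sequence; returns the per-segment losses."""
    torch.manual_seed(0)
    model = TinyRecurrent()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    cri = nn.MSELoss()
    inputs = torch.randn(seq_len, 1, 8)      # synthetic data: (T, batch=1, features)
    targets = torch.randn(seq_len, 1, 4)

    prev_output = torch.zeros(1, 4)
    seg_losses = []
    for start in range(0, seq_len, seg_len):
        optimizer.zero_grad()
        # Cut the graph at the segment boundary so this backward never
        # reaches into a segment whose buffers were already freed.
        prev_output = prev_output.detach()
        loss = 0.0
        for t in range(start, min(start + seg_len, seq_len)):
            curr_output = model(inputs[t], prev_output)
            loss = loss + cri(curr_output, targets[t])
            prev_output = curr_output        # gradients still flow within the segment
        loss.backward()                      # exactly one backward per segment
        optimizer.step()
        seg_losses.append(loss.item())
    return seg_losses


losses = train_tbptt()
print(len(losses), all(math.isfinite(v) for v in losses))  # 4 True
```

Memory here is bounded by the segment length rather than the whole sequence. One caveat with the retain_graph=True route: in recent PyTorch versions, backpropagating through a retained graph after optimizer.step() has updated the weights in place typically fails with an in-place modification error, because layers such as nn.Linear save their weights for the backward pass.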

Thanks! I will try retain_graph=True, since this seems to be what I want, and keep the segments short to limit memory use.