Simple RCNN module at the end of the net: Inplace Modification Error

I am trying to implement a model that processes sequences of images. The model is quite simple: every image in the sequence first passes through the same deep CNN, and the CNN output for each image then passes through an RCNN module that produces the final output. Since the RCNN module is only the last layer and the deep CNN has a lot of parameters, I want to avoid updating the parameters only after processing all the images of the sequence (too much memory usage). Instead, I tried to back-propagate after processing every frame, saving the RCNN hidden state for the next input. Here is the code of my program (some parts are pseudocode for clarity).

class ConvBatchAct(nn.Module):
    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super(ConvBatchAct, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.act = get_activation(activation)

    def forward(self, x):
        out = self.conv(x)
        x_out = self.batchnorm(out)
        return self.act(x_out)

class DCNN(nn.Module):
    # pseudocode: a stack of ConvBatchAct layers
    ...

class RCNNlayer(nn.Module):
    def __init__(self, in_ch, out_ch, h_ch, activation_y='ReLU', activation_h='ReLU'):
        super(RCNNlayer, self).__init__()
        self.conv1 = nn.Conv2d(in_ch,h_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(h_ch,h_ch, kernel_size=3, padding=1)
        self.conv3 = ConvBatchAct(h_ch,out_ch, activation= activation_y)

        self.batchnorm = nn.BatchNorm2d(out_ch)
        self.batchnorm_h = nn.BatchNorm2d(h_ch)

        self.act_h = get_activation(activation_h,inplace_flag =False)

    def forward(self, x, h=None):
        # h is the hidden state 
        x_out = self.conv1(x)
        if h is not None:
            h_out = self.conv2(h)
            h_next = self.batchnorm_h(self.act_h(x_out+h_out))
        else:
            # first frame: no hidden state yet
            h_next = self.batchnorm_h(self.act_h(x_out))
        y_out = self.conv3(h_next)
        return y_out, h_next

class SeqNet(nn.Module):
    def __init__(self, in_ch, out_ch, h_ch):
        super(SeqNet, self).__init__()
        self.cnn = DCNN()
        self.rcnn = RCNNlayer(in_ch, out_ch, h_ch)
    def forward(self, x, h):
        out1 = self.cnn(x)
        out2 = self.rcnn(out1,h)
        return out2

def train(loader, model, device, optimizer):
    for i, sample in enumerate(loader, 1):
        img_sq, mask_sq = sample 
        num_frame = img_sq.shape[1]
        total_loss = 0
        h_pre = None
        for f in range(num_frame):
            print('frame {}'.format(f))
            img_m = torch.unsqueeze(img_sq[:, f], 1)
            img_m = to_var(img_m, device)
            mask_m = torch.unsqueeze(mask_sq[:, f], 1)
            mask_m = to_var(mask_m, device)
            # frame forward 
            out, h_pre_t = model(img_m,h_pre)
            # frame backward
            frame_loss = DICE_loss(out,mask_m)
            h_pre = h_pre_t.detach().clone().requires_grad_(False)
            total_loss +=

loader = DataLoader()  # custom function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SeqNet(in_ch, out_ch, h_ch).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train(loader, model, device, optimizer)
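
For comparison, here is a minimal runnable sketch of the per-frame update scheme described above. It is not the code from my program: TinyRecurrent is a hypothetical stand-in for SeqNet, nn.MSELoss is a placeholder for DICE_loss, and the tensors are random. It assumes each frame's loss is back-propagated on its own and that only a detached hidden state is carried to the next frame.

```python
import torch
import torch.nn as nn


class TinyRecurrent(nn.Module):
    """Hypothetical stand-in for SeqNet: one conv on the input, one on the hidden state."""

    def __init__(self, in_ch=1, h_ch=4, out_ch=1):
        super().__init__()
        self.conv_x = nn.Conv2d(in_ch, h_ch, kernel_size=3, padding=1)
        self.conv_h = nn.Conv2d(h_ch, h_ch, kernel_size=3, padding=1)
        self.conv_y = nn.Conv2d(h_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x, h=None):
        pre = self.conv_x(x)
        if h is not None:
            pre = pre + self.conv_h(h)
        h_next = torch.relu(pre)
        return self.conv_y(h_next), h_next


torch.manual_seed(0)
model = TinyRecurrent()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # placeholder for DICE_loss

img_sq = torch.randn(2, 5, 8, 8)  # (batch, frames, H, W), random data
mask_sq = torch.rand(2, 5, 8, 8)

h_pre = None
total_loss = 0.0
for f in range(img_sq.shape[1]):
    img_m = img_sq[:, f].unsqueeze(1)
    mask_m = mask_sq[:, f].unsqueeze(1)

    out, h_next = model(img_m, h_pre)
    frame_loss = criterion(out, mask_m)

    optimizer.zero_grad()
    frame_loss.backward()  # back-propagates through this frame's graph only
    optimizer.step()

    h_pre = h_next.detach()          # cut the graph before the next frame
    total_loss += frame_loss.item()  # Python float, keeps no graph alive
```

In this sketch nothing that still references an earlier frame's graph survives into the next iteration, so retain_graph is never needed.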

The error is raised in the forward of RCNNlayer, at the line

y_out = self.conv3(h_next)

in particular in the forward of ConvBatchAct

x_out = self.batchnorm(out)

At first I called backward with retain_graph=False, but I got the freed-buffers error:

Trying to backward through the graph a second time, but the buffers have already been freed.
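
For reference, autograd reports this message whenever backward() runs a second time over a graph whose saved tensors were already released. A minimal sketch that is unrelated to the model above (the tensor x is purely illustrative, and the exact wording of the message varies between PyTorch versions):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()  # the multiplication saves x for the backward pass

y.backward()  # first pass frees the saved tensors (retain_graph defaults to False)

second_pass_error = None
try:
    y.backward()  # second pass over the same graph
except RuntimeError as err:
    second_pass_error = str(err)

print(second_pass_error)
```

In a frame loop, this typically means that something carried over from a previous iteration, such as a hidden state that was not detached or a running sum of loss tensors, still references the earlier graph.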

Then I changed retain_graph to True and got another error:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2]] is at version 2; expected version 1 instead.
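
One common way to trigger this message is to back-propagate through a graph that was built before optimizer.step(). I am not certain that this is the cause here. The step updates the parameters in place, so the tensors saved by the earlier forward pass no longer match their recorded version. A minimal sketch, assuming a hypothetical two-layer toy model and plain SGD:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 2), nn.Linear(2, 1))  # hypothetical toy model
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

loss = net(torch.randn(4, 3)).sum()
loss.backward(retain_graph=True)  # keep the graph alive after the first pass
optimizer.step()                  # in-place parameter update bumps version counters

stale_graph_error = None
try:
    loss.backward()  # reuses the pre-step graph, which saved the old weights
except RuntimeError as err:
    stale_graph_error = str(err)

print(stale_graph_error)
```

The shape printed in the message (here [2] in my case) is the shape of the tensor that was modified, which can help locate the parameter or buffer involved.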

ConvBatchAct is the main building block of the DCNN part, where it raises no error. In the RCNN layer, however, it raises this error and I don’t know how to solve it. I also tried setting nn.ReLU(inplace=False), but the error persists. I am quite confused and don’t know what to do.

If you have any advice regarding my code or my method of backpropagation, please don’t hesitate to reply! Thank you in advance!