"RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed" when training

I’ve been trying to implement a version of an hourglass model but during training I keep running into this error:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

I’ve read the other posts, but unless I’m mistaken none of them solve my problem.

Beyond a custom loss function , I don’t see any reason why it would behave like this since as far as I can tell I’m doing a regular training loop. Although I might have missed something since I’m pretty new to

I’ve been debugging for a couple days now and I haven’t been able make any headway, so someone here might be able to point me in the right direction?

A lot of the training and supplementary code including the loss function i got from here:

Training Code:

for epoch in range(epochs):
    with tch.no_grad():
        lr = adjust_learning_rate(optimizer, epoch, lr, [75, 100, 150], 0.34199518933)
    for i, (inputload, targetload, meta) in enumerate(training_loader):
        input, target = inputload.to(device), targetload.to(device, non_blocking=True)
        target_weight = meta['target_weight'].to(device, non_blocking=True)
        with tch.autograd.detect_anomaly():
            output = model(input)
        print("target", target.size())
        print("output", output.size())
        loss = criterion(output, target, target_weight)
        input.grad = None
        target.grad = None
        target_weight.grad = None
        output.grad = None

Loss function:

# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Wei Yang (platero.yang@gmail.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.nn as nn

class JointsMSELoss(nn.Module):
    def __init__(self, use_target_weight=True):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
        loss = 0

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                loss += 0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
        return loss / num_joints

Model Code:

import numpy as np 
import pandas as pd 
import torch as tch
import torch.nn as nn
import pickle

class ConvBlock(nn.Module):
    def __init__(self, n, m):
        super(ConvBlock, self).__init__()
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(n)
        print("conv", n, m)
        self.conv = nn.Conv2d(n, m, 3, padding=1)
    def forward(self, x):
        return y
class Gate(nn.Module):
    def __init__(self, n):
        super(Gate, self).__init__()
        self.module = nn.Linear(n, n, bias=False)
    def forward(self, x):
        x = x.permute(0,2,3,1)
        x = self.module(x)
        x = x.permute(0,3,1,2)
        return x
class Residual(nn.Module):
    def __init__(self,n):
        super(Residual, self).__init__()
        self.conv1 = ConvBlock(n,n//2)
        self.conv2 = ConvBlock(n//2,n//4)
        self.conv3 = ConvBlock(n//4,n//4)
        self.gate = Gate(n)
    def forward(self,x):
        y1 = self.conv1(x)
        y2 = self.conv2(y1)
        y3 = self.conv3(y2)
        print("res", y1.size(), y2.size(), y3.size())
        y_c = tch.cat((y1,y2,y3),1)
        y = tch.add(y_c, self.gate(x))
        return y
class ResUp(nn.Module):
    def __init__(self,n):
        super(ResUp, self).__init__()
        self.res = Residual(n)
    def forward(self, x, target):
        self.upsample = nn.Upsample(target, mode='nearest')
        return self.res(self.upsample(x))
class Merge(nn.Module):
    def __init__(self,n):
        super(Merge, self).__init__()
        self.conv = nn.Conv2d(2*n, n, 3, padding=1)
    def forward(self, x, y):
        print("merge", x.size(), y.size())
        return self.conv(tch.cat((x,y),1))
class Hourglass(nn.Module):
    def __init__(self,n):
        super(Hourglass, self).__init__()
        self.resdowns = nn.ModuleList([Residual(n) for i in range(4)])
        self.pool = nn.MaxPool2d(2)
        self.resups = nn.ModuleList([ResUp(n) for i in range(4)])
        self.ress = nn.ModuleList([Residual(n) for i in range(4)])
        self.merges = nn.ModuleList([Merge(n) for i in range(4)])
        self.skips = []
    def forward(self, x):
        print("input size", x.size())
        for i in range(4):
            x = self.resdowns[i](x)
            x = self.pool(x)
            print("skip in", self.skips[i].size(), i)
        for i in range(4):
            print("skip out", i)
            x = self.merges[i](x,self.pool(self.skips[3-i]))
            x = self.resups[i](x, self.skips[3-i].size(2))
            print("resup", x.size())
        print("output size", x.size())
        return x
class Model(nn.Module):
    def __init__(self,in_channels, hidden_channels, numjoints, num_hg):
        super(Model, self).__init__()
        self.num_hg = num_hg
        self.conv = nn.Conv2d(in_channels, numjoints, 1)
        self.modulelist = nn.ModuleList([nn.Sequential(
            nn.Conv2d(numjoints, hidden_channels, 1),
            nn.Conv2d(hidden_channels, numjoints, 1),
        ) for i in range(num_hg)])
    def forward(self, x):
        x = self.conv(x)
        for i in range(self.num_hg):
        return x

I guess appending to self.skips and using it afterwards could create the issue.
Could you re-initialize self.skips = [] inside the forward method for the sake of debugging and check, if this would be working?

That fixed it, thank you so much!