Remove custom layer from backprop graph

I am using the following class to implement tiling as in Caffe:

class Tiling(nn.Module):
    # Tiling layer from the original caffe-VPGNet
    def __init__(self, x_dim, tile_dim):
        super(Tiling, self).__init__()
        self.b, d, self.h, self.w = x_dim
        self.tile_dim = tile_dim
        self.out_d = int(d / (tile_dim ** 2))  # 320 / 8**2 = 5
        out_h = int(self.h * tile_dim)  # 15 * 8 = 120
        out_w = int(self.w * tile_dim)  # 20 * 8 = 160
        self.tiled_out = (
            torch.FloatTensor(self.b, self.out_d, out_h, out_w)
            .fill_(0.0)
            .to(torch.device("cuda:0"))
        )

    def forward(self, x):
        #print("x ",x.requires_grad)
        for ds in range(self.out_d):
            d_start = ds * (self.tile_dim ** 2)
            d_end = (ds + 1) * (self.tile_dim ** 2)
            for hs in range(self.h):
                for ws in range(self.w):
                    tile_select = x[:, d_start:d_end, hs, ws]
                    #          tile_select = x[:, ds*(self.tile_dim**2):(ds+1)*(self.tile_dim**2), hs, ws]
                    out_tile = tile_select.view(self.b, self.tile_dim, self.tile_dim)
                    h_start = hs * self.tile_dim
                    h_end = (1 + hs) * self.tile_dim
                    w_start = ws * self.tile_dim
                    w_end = (1 + ws) * self.tile_dim
                    self.tiled_out[:, ds, h_start:h_end, w_start:w_end] = out_tile

        #print("tiled_out ",self.tiled_out.requires_grad)

        return self.tiled_out

While trying to train the model with this layer at the head, I get the following error:

    total_loss.backward()
  File "/home/SandeepMenon/venv-e3d/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/SandeepMenon/venv-e3d/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

The operations in the Tiling layer should not be part of the computation graph for backprop; the layer should just be treated as a custom reshape.
I tried to detach() the tensor before the assignment by changing the following line
tile_select = x[:, d_start:d_end, hs, ws]
to
tile_select = x[:, d_start:d_end, hs, ws].detach()

However, I then get the following error:

    total_loss.backward()
  File "/home/SandeepMenon/venv-e3d/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/SandeepMenon/venv-e3d/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

How should I rewrite my Tiling layer to resolve this?

Thank you

In your current code snippet x_detach isn’t defined, so I’m not sure if you just assigned it to x or have indeed detached it.
Note that in the latter case the computation graph would be detached and previous layers will not get any gradients if they depend on x.

I’m not sure what your exact use case is, but are you getting the first error message by using x directly?
If so, could you post an executable code snippet using random input tensors to reproduce this issue?

@ptrblck
Thank you for the reply and apologies for my typo. I renamed x_detach to x.
My use case is to implement a tiling layer like the Caffe tiling layer mentioned in this repo.

And yes, I am getting the first error message by using x directly.
Code for the former case:

import torch
import torch.nn as nn

class Tiling(nn.Module):
    # Tiling layer from original caffe-VPGNet
    def __init__(self, x_dim, tile_dim):
        super(Tiling, self).__init__()
        self.b, d, self.h, self.w = x_dim
        self.tile_dim = tile_dim
        self.out_d = int(d / (tile_dim ** 2))  # 320 / 8**2 = 5
        self.out_h = int(self.h * tile_dim)  # 15 * 8 = 120
        self.out_w = int(self.w * tile_dim)  # 20 * 8 = 160
        self.tiled_out = (
            torch.FloatTensor(self.b, self.out_d, self.out_h, self.out_w)
            .fill_(0.0)
            .to(torch.device("cuda:0"))
        )

    def forward(self, x):
        # tiled_out = (
        #     torch.FloatTensor(self.b, self.out_d, self.out_h, self.out_w)
        #     .fill_(0.0)
        # ).to(torch.device("cuda:0"))

        for ds in range(self.out_d):
            d_start = ds * (self.tile_dim ** 2)
            d_end = (ds + 1) * (self.tile_dim ** 2)
            for hs in range(self.h):
                for ws in range(self.w):
                    tile_select = x[:, d_start:d_end, hs, ws]
                    out_tile = tile_select.view(self.b, self.tile_dim, self.tile_dim)
                    h_start = hs * self.tile_dim
                    h_end = (1 + hs) * self.tile_dim
                    w_start = ws * self.tile_dim
                    w_end = (1 + ws) * self.tile_dim
                    self.tiled_out[:, ds, h_start:h_end, w_start:w_end] = out_tile

        return self.tiled_out

class Test_Tiling(nn.Module):
    def __init__(self, batch_size=2):
        super().__init__()
        
        self.layer1 = nn.Conv2d(320, 320, kernel_size=1)
        tile_d = Tiling([batch_size, 320, 15, 20], 8)
        self.vpp = nn.Sequential(*[self.layer1, tile_d])

    def forward(self,x):
        out_1 = self.vpp(x)

        return out_1

if __name__ == "__main__":
    batch_size = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Test_Tiling(batch_size=batch_size).to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    optim_fn = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

    for epoch in range(5):
        inp = torch.FloatTensor(batch_size, 320, 480, 640).cuda()
        ground_truth = torch.FloatTensor(batch_size,5,120,160).cuda()
        y = model(inp)

        total_loss = loss_fn(y, ground_truth)

        # Zero optimizer
        optim_fn.zero_grad()

        # Backprop
        total_loss.backward()

        # Optimize weights
        optim_fn.step()

        print(f"Epoch:[{epoch}/{5}] loss:[{total_loss:.4f}] ")

Code for the latter case (change only in the Tiling class).
I added .detach() while reading from x, and I now create tiled_out inside the forward function so that the saved results are not reused, as per the error message (not sure whether this is what actually fixes it, but the first error goes away):

class Tiling(nn.Module):
    # Tiling layer from original caffe-VPGNet
    def __init__(self, x_dim, tile_dim):
        super(Tiling, self).__init__()
        self.b, d, self.h, self.w = x_dim
        self.tile_dim = tile_dim
        self.out_d = int(d / (tile_dim ** 2))  # 320 / 8**2 = 5
        self.out_h = int(self.h * tile_dim)  # 15 * 8 = 120
        self.out_w = int(self.w * tile_dim)  # 20 * 8 = 160
        # self.tiled_out = (
        #     torch.FloatTensor(self.b, self.out_d, self.out_h, self.out_w)
        #     .fill_(0.0)
        #     .to(torch.device("cuda:0"))
        # )

    def forward(self, x):
        tiled_out = (
            torch.FloatTensor(self.b, self.out_d, self.out_h, self.out_w)
            .fill_(0.0)
        ).to(torch.device("cuda:0"))

        for ds in range(self.out_d):
            d_start = ds * (self.tile_dim ** 2)
            d_end = (ds + 1) * (self.tile_dim ** 2)
            for hs in range(self.h):
                for ws in range(self.w):
                    tile_select = x[:, d_start:d_end, hs, ws].detach()
                    out_tile = tile_select.view(self.b, self.tile_dim, self.tile_dim)
                    h_start = hs * self.tile_dim
                    h_end = (1 + hs) * self.tile_dim
                    w_start = ws * self.tile_dim
                    w_end = (1 + ws) * self.tile_dim
                    tiled_out[:, ds, h_start:h_end, w_start:w_end] = out_tile

        return tiled_out

Thanks for the update.
The error in the first approach is raised since you are storing the self.tiled_out tensor inside the module and assigning activations to it.
If I understand the use case correctly, you are pre-allocating tiled_out only to store the results and don’t care about its initial values. If that’s the case, then the second approach looks correct.
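
To make the failure mode concrete, here is a minimal sketch (the toy module and names are made up for illustration, not taken from your code). Writing activations in-place into a tensor that lives on the module chains each iteration’s graph to the previous, already freed one:

import torch
import torch.nn as nn

class BufferedCopy(nn.Module):
    # toy module mimicking the first Tiling version:
    # the output buffer is created once in __init__ and reused in every forward
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.out_buf = torch.zeros(1, 4)

    def forward(self, x):
        # in-place write: out_buf now carries a grad_fn pointing into this graph
        self.out_buf[:] = self.lin(x)
        return self.out_buf

m = BufferedCopy()
for step in range(2):
    y = m(torch.randn(1, 4))
    # second iteration: out_buf's new grad_fn is still connected to the
    # (already freed) graph of the first iteration -> the "backward a second time" error
    y.sum().backward()

Allocating the output tensor inside forward (as in your second version) avoids this, since every call starts from a fresh tensor.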

@ptrblck
Thank you for the confirmation.
For the second approach, the code snippet above throws
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
when I have the line
tile_select = x[:, d_start:d_end, hs, ws].detach()
If I instead remove the detach,
tile_select = x[:, d_start:d_end, hs, ws]
it runs. Is that correct?

I cannot reproduce the mentioned error using:

class Tiling(nn.Module):
    # Tiling layer from original caffe-VPGNet
    def __init__(self, x_dim, tile_dim):
        super(Tiling, self).__init__()
        self.b, d, self.h, self.w = x_dim
        self.tile_dim = tile_dim
        self.out_d = int(d / (tile_dim ** 2))  # 320 / 8**2 = 5
        self.out_h = int(self.h * tile_dim)  # 15 * 8 = 120
        self.out_w = int(self.w * tile_dim)  # 20 * 8 = 160

    def forward(self, x):
        tiled_out = (
            torch.FloatTensor(self.b, self.out_d, self.out_h, self.out_w)
            .fill_(0.0)
        ).to(device)

        for ds in range(self.out_d):
            d_start = ds * (self.tile_dim ** 2)
            d_end = (ds + 1) * (self.tile_dim ** 2)
            for hs in range(self.h):
                for ws in range(self.w):
                    tile_select = x[:, d_start:d_end, hs, ws]
                    out_tile = tile_select.view(self.b, self.tile_dim, self.tile_dim)
                    h_start = hs * self.tile_dim
                    h_end = (1 + hs) * self.tile_dim
                    w_start = ws * self.tile_dim
                    w_end = (1 + ws) * self.tile_dim
                    tiled_out[:, ds, h_start:h_end, w_start:w_end] = out_tile

        return tiled_out

and can perform the backward pass:

Epoch:[0/5] loss:[0.6923]
Epoch:[1/5] loss:[-18526809096192.0000]
Epoch:[2/5] loss:[13040161335951413175386112.0000]
Epoch:[3/5] loss:[-16155749540553689246203904.0000]

The loss is blowing up, but that might be unrelated to the error you are seeing.
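
As a side note, torch.FloatTensor(batch_size, 320, 480, 640) returns uninitialized memory, so both inp and ground_truth contain arbitrary values, which alone could make the loss blow up. For the dummy data, something like this (just a sketch; targets in [0, 1] as BCEWithLogitsLoss expects) would be better behaved:

inp = torch.randn(batch_size, 320, 480, 640, device=device)
ground_truth = torch.rand(batch_size, 5, 120, 160, device=device)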

Yes, the code you posted runs without errors. I added the .detach() call while selecting the tile from x, thinking it was required; that is when the error happens. I wanted to understand why detaching does not work.
The code below throws RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn:

class Tiling(nn.Module):
    # Tiling layer from original caffe-VPGNet
    def __init__(self, x_dim, tile_dim):
        super(Tiling, self).__init__()
        self.b, d, self.h, self.w = x_dim
        self.tile_dim = tile_dim
        self.out_d = int(d / (tile_dim ** 2))  # 320 / 8**2 = 5
        self.out_h = int(self.h * tile_dim)  # 15 * 8 = 120
        self.out_w = int(self.w * tile_dim)  # 20 * 8 = 160

    def forward(self, x):
        tiled_out = (
            torch.FloatTensor(self.b, self.out_d, self.out_h, self.out_w)
            .fill_(0.0)
        ).to(device)

        for ds in range(self.out_d):
            d_start = ds * (self.tile_dim ** 2)
            d_end = (ds + 1) * (self.tile_dim ** 2)
            for hs in range(self.h):
                for ws in range(self.w):
                    tile_select = x[:, d_start:d_end, hs, ws].detach()
                    out_tile = tile_select.view(self.b, self.tile_dim, self.tile_dim)
                    h_start = hs * self.tile_dim
                    h_end = (1 + hs) * self.tile_dim
                    w_start = ws * self.tile_dim
                    w_end = (1 + ws) * self.tile_dim
                    tiled_out[:, ds, h_start:h_end, w_start:w_end] = out_tile

        return tiled_out

Thank you

@ptrblck
Were you able to reproduce the error with the latest snippet?

I haven’t executed the new code snippet, but it’s expected that detaching the tensor from the computation graph stops the gradient propagation at this point, so I’m unsure why you want to use it there.
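
To illustrate that on a tiny example (a sketch, not your code): once the source is detached, the copy into the output tensor is not recorded, so the output has no grad_fn and there is nothing for backward to differentiate:

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
x = lin(torch.randn(1, 4))

out = torch.zeros(1, 4)
out[:] = x.detach()           # detached source: the copy is not tracked by autograd
try:
    out.sum().backward()
except RuntimeError as e:
    print(e)                  # element 0 of tensors does not require grad ...

out = torch.zeros(1, 4)
out[:] = x                    # tracked copy: out gets a grad_fn, gradients reach lin
out.sum().backward()          # works

Also, if I’m reading the tiling loop correctly, the rearrangement (each group of tile_dim**2 channels becomes a tile_dim x tile_dim spatial block) is what nn.PixelShuffle does, so the layer could be written without Python loops, a pre-allocated buffer, or detach, and autograd would treat it as a pure reshape. A sketch, worth verifying against your loop version with torch.allclose:

b, d, h, w, r = 2, 320, 15, 20, 8
x = torch.randn(b, d, h, w)

tiled = nn.PixelShuffle(r)(x)  # (b, d // r**2, h * r, w * r) = (2, 5, 120, 160)

# the same rearrangement spelled out with view/permute/reshape
tiled2 = (x.view(b, d // r**2, r, r, h, w)
           .permute(0, 1, 4, 2, 5, 3)
           .reshape(b, d // r**2, h * r, w * r))
print(torch.allclose(tiled, tiled2))  # True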