PyTorch Gradient Computation Fails When Not Using Entire Input Tensor

I have a model that takes an input tensor inputs, along with two other inputs, k and D. The model outputs several tensors, including cs_hat. When I compute gradients of cs_hat with respect to the first slice of inputs (inputs[:,:,0]), the call fails. It only succeeds if I differentiate with respect to the entire tensor inputs instead of just the slice.

Here is a simplified version of my code that illustrates the problem:

import torch
from torch import nn

class MyModel(torch.nn.Module):

    def __init__(self, input_size=3, ffn_size=15, ffn_layers=2, res_block_size=15, res_block_layers=2):
        super(MyModel, self).__init__()

        self.input_size = input_size
        self.activation = nn.LeakyReLU()

        self.ffn_size = ffn_size
        self.ffn_layers = ffn_layers
        self.res_block_size = res_block_size
        self.res_block_layers = res_block_layers

        self.linear_block_0 = self._make_linear_block(self.ffn_size, self.ffn_layers, input_size=self.input_size)

        self.final_layer_a = nn.Linear(self.res_block_size, 1, bias=False)
        self.final_layer_b = nn.Linear(self.res_block_size, 1, bias=False)
        self.final_layer_c = nn.Linear(self.res_block_size, 1, bias=False)
        self.final_layer_d = nn.Linear(self.res_block_size, 1, bias=False)




    def _make_linear_block(self, width, depth, input_size = None):

        if input_size is None:
            linear_block = nn.ModuleList([nn.Linear(width, width), self.activation])
        else:
            linear_block = nn.ModuleList([nn.Linear(input_size, width), self.activation])

        for _ in range(depth - 1):
            linear_block.append(nn.Linear(width, width))
            linear_block.append(self.activation)

        linear_block_ = nn.Sequential(*linear_block)

        return linear_block_


    def forward(self, inputs,k,D):

        t = inputs[:,:,0]
        x = inputs[:,:,1]

        input_t = torch.cat([t,k.view(-1,1),D.view(-1,1)],dim = -1)

        z0 = self.linear_block_0(input_t)

        a = self.final_layer_a(z0)
        b = self.final_layer_b(z0)
        c = self.final_layer_c(z0)
        d = self.final_layer_d(z0)

        return a,b,c,d


#Main

model = MyModel()

inputs = torch.tensor([[[0.4521, 0.5205]], [[0.3066, 0.6816]], [[0.0547, 0.9297]], [[0.3936, 0.9229]]], requires_grad=True) # shape (batch_size, 1, 2): channel 0 is t, channel 1 is x

batch_size = 4

k = torch.randn(batch_size, requires_grad=True)
D = torch.randn(batch_size, requires_grad=True)

# Forward pass
outputs = model(inputs, k, D)
cs_hat = outputs[2]  # Assuming cs_hat is the third output

# Gradient computation that works
cs_dt = torch.autograd.grad(cs_hat, inputs, grad_outputs=torch.ones_like(cs_hat), create_graph=True)[0]
# Gradient computation that fails
cs_dt = torch.autograd.grad(cs_hat, inputs[:,:,0], grad_outputs=torch.ones_like(cs_hat), create_graph=True)[0]

My error message:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

How can I fix this without having to compute gradients with respect to the whole input tensor?

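This is expected: slicing (indexing) a tensor is itself a differentiable operation, so every call to inputs[:,:,0] creates a new tensor with its own node in the computation graph.
In your case, the slice passed to torch.autograd.grad is created after the forward pass, so that particular tensor was indeed never used to create cs_hat, hence the error.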
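A minimal sketch of this behavior, assuming a toy expression (t * 3.0).sum() as a stand-in for cs_hat: slicing the same leaf twice yields two distinct tensors, and only the one that was actually used in the forward pass can be differentiated against.

```python
import torch

# Leaf tensor analogous to `inputs` in the question (shape (4, 1, 2) chosen for illustration).
inputs = torch.rand(4, 1, 2, requires_grad=True)

t = inputs[:, :, 0]      # slice used in the forward pass
out = (t * 3.0).sum()    # toy stand-in for cs_hat (assumption, not the real model)

# Indexing again creates a brand-new tensor with its own graph node.
t_again = inputs[:, :, 0]
print(t_again is t)      # False

# Differentiating w.r.t. the new slice fails: it was never used to compute `out`.
raised = False
try:
    torch.autograd.grad(out, t_again, retain_graph=True)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Differentiating w.r.t. the slice object that was actually used works.
(g,) = torch.autograd.grad(out, t, retain_graph=True)
print(g)                 # tensor of 3s with shape (4, 1)
```

retain_graph=True keeps the graph alive after the failing call so the second call can still run.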
You could create the slices explicitly, stack them back into a single tensor, and pass that tensor to the model, as seen here:

a = inputs[:,:,0]
b = inputs[:,:,1]

c = torch.stack((a, b), dim=2)

# Forward pass
outputs = model(c, k, D)
cs_hat = outputs[2]  # Assuming cs_hat is the third output

cs_dt = torch.autograd.grad(cs_hat, c, grad_outputs=torch.ones_like(cs_hat), create_graph=True)[0]
cs_dt = torch.autograd.grad(cs_hat, a, grad_outputs=torch.ones_like(cs_hat), create_graph=True)[0]

This allows you to compute the gradients w.r.t. a.
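Differentiating with respect to the whole input and slicing the resulting gradient afterwards gives the same numbers. A minimal sketch comparing both options, assuming a small hypothetical stand-in network that, like MyModel, only uses channel 0 of the input (k and D are omitted for brevity):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for MyModel: only channel 0 of `inputs` is used.
net = nn.Sequential(nn.Linear(1, 8), nn.LeakyReLU(), nn.Linear(8, 1))

def forward(inputs):
    t = inputs[:, :, 0]   # shape (batch, 1)
    return net(t)         # shape (batch, 1), plays the role of cs_hat

inputs = torch.rand(4, 1, 2, requires_grad=True)

# Option 1: differentiate w.r.t. the whole tensor, then slice the gradient.
cs_hat = forward(inputs)
(g_full,) = torch.autograd.grad(cs_hat, inputs, grad_outputs=torch.ones_like(cs_hat), create_graph=True)
cs_dt_1 = g_full[:, :, 0]

# Option 2: slice first, rebuild the input with torch.stack, differentiate w.r.t. the slice.
a = inputs[:, :, 0]
b = inputs[:, :, 1]
c = torch.stack((a, b), dim=2)
cs_hat = forward(c)
(cs_dt_2,) = torch.autograd.grad(cs_hat, a, grad_outputs=torch.ones_like(cs_hat), create_graph=True)

print(torch.allclose(cs_dt_1, cs_dt_2))     # True
print(g_full[:, :, 1].abs().sum().item())   # 0.0, channel 1 is unused
```

Option 1 needs no change to how the model is called; option 2 avoids materializing gradient entries you will discard.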


Fixed it! Thank you! In the meantime, I had implemented it by differentiating with respect to the whole input tensor and slicing out the part I needed afterwards. My loss works on a test now, but it becomes NaN during training (and pretty quickly at that). Does that imply vanishing gradients?

No, a NaN loss is not necessarily caused by vanishing gradients (vanishing gradients would just mean the parameters stop being updated, not that they turn into NaN), so you should narrow down which operation produces the NaN (or Inf) output in the forward pass.
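One way to narrow it down is to register forward hooks that stop at the first module whose output is not finite. A minimal sketch, assuming a hypothetical Reciprocal layer and a toy input containing a zero to trigger an Inf:

```python
import torch
from torch import nn

def add_nonfinite_hooks(model):
    """Register forward hooks that raise at the first module whose output contains NaN/Inf."""
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            outs = output if isinstance(output, tuple) else (output,)
            for o in outs:
                if torch.is_tensor(o) and not torch.isfinite(o).all():
                    raise RuntimeError(
                        f"non-finite output in module '{name}' ({module.__class__.__name__})"
                    )
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles

class Reciprocal(nn.Module):
    # Hypothetical layer dividing by its input, e.g. a feature that can be zero.
    def forward(self, x):
        return 1.0 / x

model = nn.Sequential(Reciprocal(), nn.Linear(2, 1))
x = torch.tensor([[1.0, 2.0], [0.0, 3.0]])  # second row contains a zero

handles = add_nonfinite_hooks(model)
culprit = None
try:
    model(x)
except RuntimeError as err:
    culprit = str(err)
    print(culprit)  # non-finite output in module '0' (Reciprocal)
finally:
    for h in handles:
        h.remove()

# Checking the data directly: indices of rows containing a zero.
bad_rows = (x == 0).any(dim=1).nonzero().flatten()
print(bad_rows)  # tensor([1])
```

If the forward pass is finite but the gradients are not, torch.autograd.set_detect_anomaly(True) can point to the backward operation that produced the NaN, at the cost of slower execution.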

Fixed it! Thank you again, haha. It was caused by a zero value in my input dataset.