.requires_grad_(True) doesn't work

Hi all,

What could be the reason for the following:

with torch.no_grad():
    outputs = self.module(input)
  
assert isinstance(outputs, (torch.Tensor, tuple)), \
    f"Output must be a tensor or a tuple of tensors. Got instead: {type(outputs)}"

print(outputs.type()) # prints torch.cuda.FloatTensor
outputs.requires_grad_(True)
print(outputs.requires_grad) # prints False

I'm trying to modify a forward pass so that the network forward runs in no_grad mode and the loss gradient is computed only with respect to the network's output. But calling .requires_grad_(True) seems to have no effect, as the print still shows False.
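For context, here is a minimal eager-mode sketch of the pattern I'm after (the backbone, shapes, and loss below are just placeholders, not my actual module):

import torch
import torch.nn as nn

# Placeholder standing in for self.module.
backbone = nn.Linear(4, 4)
x = torch.randn(2, 4)

# Run the backbone without tracking gradients.
with torch.no_grad():
    outputs = backbone(x)

# The output is a leaf tensor here, so this should turn gradient tracking on.
outputs.requires_grad_(True)

loss = outputs.pow(2).mean()
loss.backward()

print(outputs.requires_grad)         # True in eager mode
print(outputs.grad is not None)      # True: gradient w.r.t. the output only
print(backbone.weight.grad is None)  # True: nothing flows back into the backbone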

Hi Juuso!

This should work – outputs.requires_grad should print out as True.

Could you post a fully-self-contained, runnable script that reproduces your
issue, together with the output you get when you run it? Please also let us
know what version of pytorch you are using.

Best.

K. Frank

Hi @KFrank, this snippet is part of a larger program. I noticed that once I disabled compilation of the training_step() function, changing requires_grad (during forward()) works. I'm not sure why it behaves this way, though.

Here’s code that reproduces the problem:

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x
    

def task(model_1, model_2, input, target_output):
    # Forward pass
    with torch.no_grad():
        intermediate_output = model_1(input)
    intermediate_output.requires_grad_()
    print(intermediate_output.requires_grad) # prints False when compiled
    output = model_2(intermediate_output)
    loss_fn = nn.MSELoss()
    loss = loss_fn(output, target_output)
    return loss


# Subsequent modules test
input_size = 1
hidden_size = 32
output_size = 1

model_1 = SimpleNN(input_size, hidden_size, output_size)
model_2 = SimpleNN(input_size, hidden_size, output_size)

# Creating random input and target tensors
torch.manual_seed(42)  # For reproducibility
N = 100  # Batch size
input_data = torch.randn(N, input_size)
target_data = torch.randn(N, output_size)

task = torch.compile(task)

for i in range(100):
    loss = task(model_1, model_2, input_data, target_data)
    loss.backward()

If I just comment out the task = torch.compile(task) line, changing requires_grad works fine.
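To make the difference concrete, replacing the task = torch.compile(task) line above with an explicit side-by-side comparison gives something like this (the names are just for illustration; the behaviour is what I see on my setup):

compiled_task = torch.compile(task)   # compile a separate copy, leave task uncompiled

task(model_1, model_2, input_data, target_data)           # print inside task shows True
compiled_task(model_1, model_2, input_data, target_data)  # print inside task shows False for me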

Hi Juuso!

I can reproduce* your issue with the code you posted. I’m not that
knowledgeable about torch.compile, but this looks like a bug to me.

If you think that this GitHub issue:

is the same as yours, you might add a comment to it, or if your issue looks different, you might want to file a new GitHub issue.

*) Apparently torch.compile doesn't work on Windows or with Python 3.12+,
so I ended up reproducing your issue on Linux with PyTorch version 2.1.2, the
latest version I happened to have installed with Python 3.11.

Best.

K. Frank

Late thank you, @KFrank. I ended up filing a bug: .requires_grad_(True) doesn't work with torch.compile · Issue #123713 · pytorch/pytorch · GitHub