[Help Needed] Correct gradient passing with view() or reshape()?

import torch
from torch.nn.modules.module import Module
import numpy as np
import torch.nn.functional as F

class fCGFunction(torch.autograd.Function):
    def __init__(self):
        super(fCGFunction, self).__init__()

    def forward(self, activations):
        self.save_for_backward(activations)
        return torch.zeros(activations.shape, dtype=torch.float).cuda()

    def backward(self, grad_output):
        print("grad_output", grad_output)
        activations = self.saved_tensors[0]
        grad_input = torch.zeros(activations.shape,dtype=torch.float).cuda()
        return grad_input

class fCGModule(Module):
    def __init__(self):
        super(fCGModule, self).__init__()

    def forward(self, activations):
        batch_size = activations[0].shape[0]
        new_activations = torch.cat([v.view(batch_size,-1, 2).clone() for v in activations], dim=1).float()
        print(new_activations.requires_grad)
        output = fCGFunction()(new_activations)
        return [output[:,(3*i):(3+3*i),:].view(batch_size,1,3,2) for i in range(2)]

BATCH_SIZE=1
v1 = [torch.tensor(np.random.rand(BATCH_SIZE, 1,3,2),
                    dtype=torch.float,requires_grad=True).cuda() for i in range(2)]
w1 = fCGModule()(v1)
print(w1)

tot = 0
for lst in w1:
    tot += torch.norm(lst)
loss = F.mse_loss(tot, torch.zeros([1]).cuda())
loss.backward()
print([i.grad for i in v1])

The above code gives me a list of None when I print gradients in the end. Bascially, I need to convert a list into a big tensor, pass it to some kernel, and make it back to a list of tensors. However, the gradient does not seem to be correctly passed down. How should I do this?

Hi,

The problem is that the tensor you check the gradients of is not the one you require gradients for. The .cuda() call returns a different Tensor. You can do the following:

device = torch.device('cuda')
BATCH_SIZE=1
v1 = [torch.tensor(np.random.rand(BATCH_SIZE, 1,3,2),
                    dtype=torch.float, device=device, requires_grad=True) for i in range(2)]
1 Like

Thank you so much!!! This step is very critical in my codes!