Zeroing one input's gradient for Matrix Multiply

Hi! The @ operator, or matrix multiply, is stateless and accepts two input tensors. During backprop, my understanding is that it calculates two gradients, one w.r.t. each input tensor, and each of those gradients flows back via the chain rule along the path that produced that input, updating any variables on the way.

In my experiment, I want to zero out the gradient for only one input tensor and keep the other as-is. My guess is that I need to use autograd, but I’m fairly new to this and the toy example (PyTorch: Defining New autograd Functions — PyTorch Tutorials 1.7.0 documentation) doesn’t seem to suffice. Please help!

I figured I could use the tutorial to add a custom ‘Identity’ layer that passes activations through unchanged in the forward pass but zeros out the gradient in the backward pass. But I’m really curious how to do it properly with bmm (@). Thanks.

It turns out that I also need to modify the gradient of the other operand. Is there any example or tutorial that shows how to use Autograd function with an operation that produces multiple gradients? TIA.

import torch

class FwdOnlyIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # pass the input feature map through unmodified, as Identity
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # set the gradient w.r.t. the input to all zeros
        zero_grad_input = torch.zeros_like(grad_output)
        return zero_grad_input

Above is what I implemented for what I originally asked about.

Then, instead of directly doing a @ b, I added self.fwd_identity = FwdOnlyIdentity.apply, and in the forward function the computation becomes self.fwd_identity(a) @ b. Does this make sense?
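For concreteness, here is a minimal sketch of that usage; the module name ZeroLhsGradMatmul and the tensor shapes are made up for illustration, and it assumes the FwdOnlyIdentity Function defined above:

import torch
import torch.nn as nn

class ZeroLhsGradMatmul(nn.Module):
    # hypothetical wrapper: forward computes a @ b, but backward sends zeros toward a
    def __init__(self):
        super().__init__()
        self.fwd_identity = FwdOnlyIdentity.apply

    def forward(self, a, b):
        return self.fwd_identity(a) @ b

a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, 5, requires_grad=True)
out = ZeroLhsGradMatmul()(a, b)
out.sum().backward()
print(a.grad)  # all zeros, because FwdOnlyIdentity.backward returns zeros
print(b.grad)  # the usual matmul gradient w.r.t. b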

In the simplest case, you can just do something like x.detach() @ weights


The matrix multiply here doesn’t have any trainable weights. It takes two activation tensors as input.

Those are arbitrary variable names; detach() cuts off the whole computation chain that produced “x”. If you return a gradient of zeros, zeros will propagate further back, so it is the same as blocking with detach().

I see. What makes it complicated is that I cannot simply detach the producer of ‘x’, as it’s expected to receive gradient updates from another downstream path.

The other path won’t be affected; you’ll have separate graph nodes for each use of the tensor, i.e.

x_ng = x.detach()
y1 = x_ng @ m1
y2 = x @ m2

Two paths start from x; the original x tensor is not affected by detach(), and y1 is trainable only through m1.
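A quick check of that behavior (the shapes below are arbitrary):

import torch

x = torch.randn(2, 3, requires_grad=True)
m1 = torch.randn(3, 4, requires_grad=True)
m2 = torch.randn(3, 4, requires_grad=True)

x_ng = x.detach()   # new node; gradients stop here on this branch
y1 = x_ng @ m1      # m1 still receives a gradient
y2 = x @ m2         # x receives a gradient only through this branch

(y1.sum() + y2.sum()).backward()
print(torch.allclose(x.grad, torch.ones_like(y2) @ m2.t()))  # True: x.grad reflects only the y2 path
print(m1.grad is not None)                                   # True: m1 is still trainable through y1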


Hey, I figured out how to do it with autograd.Function after reading through the pages. Let me know if it makes sense.

class TwoOpGradientController(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        ctx.args_0_shape = args[0].shape
        return args[0] @ args[1]

    @staticmethod
    def backward(ctx, grad_output):
        # gradient w.r.t. args[0] is forced to all zeros
        # (build zeros from the saved shape; zeros_like expects a tensor, not a shape)
        lhs_zero_grad_input = grad_output.new_zeros(ctx.args_0_shape)
        rhs_grad_output = ...  # do weird stuff to rhs gradients
        return lhs_zero_grad_input, rhs_grad_output

It can be correct, but it doesn’t make much sense to me that you’re tying it to matmul. You have to at least start from the original gradient formula (IIRC, rhs_grad_output = grad_output @ args[0].t()) if you have a reason to have such a Function.

Yes, that’s a good point. The toy example is just to show how the syntax works, as the existing examples typically show simpler functions, such as ReLU and Exp.

BTW, I think the original formula is args[0].t() @ grad_output, such that the resulting shape works out (k by m times m by n → k by n).
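For what it’s worth, a quick sanity check of that formula against what autograd computes for a plain matmul (the shapes below are arbitrary):

import torch

a = torch.randn(4, 3, requires_grad=True)  # m x k, plays the role of args[0]
b = torch.randn(3, 5, requires_grad=True)  # k x n, plays the role of args[1]
out = a @ b                                # m x n

grad_output = torch.ones_like(out)
out.backward(grad_output)

rhs_grad = a.t() @ grad_output             # (k x m) @ (m x n) -> k x n
print(torch.allclose(rhs_grad, b.grad))    # True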