Zeroing one input's gradient for matrix multiply

Hi! The @ operator (matrix multiply) is stateless and accepts two input tensors. During backprop, my understanding is that it calculates two gradients, one w.r.t. each input tensor, and each gradient then updates variables via the chain rule along the path that produced that input.

In my experiment, I want to zero out the gradient for only one input tensor and keep the other as-is. My guess is that I need to use autograd, but I’m fairly new to this and the toy example (PyTorch: Defining New autograd Functions — PyTorch Tutorials 1.7.0 documentation) doesn’t seem to suffice. Please help!

I figured I could use the tutorial to add a custom ‘Identity’ layer that passes activations through unmodified in the forward pass but zeros out the gradient in the backward pass. But I’m really curious how to do it properly with matmul (@) itself. Thanks.

It turns out that I also need to modify the gradient of the other operand. Is there any example or tutorial that shows how to write an autograd.Function for an operation that produces multiple gradients? TIA.

class FwdOnlyIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # pass the input feature map through unmodified, as Identity
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # zero out the gradient on the way back
        return torch.zeros_like(grad_output)
Then, instead of directly computing a @ b, I added self.fwd_identity = FwdOnlyIdentity.apply, and in the forward function the computation becomes self.fwd_identity(a) @ b. Does this make sense?
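For a quick end-to-end check of this idea, here is a self-contained sketch (the Function is repeated so the snippet runs on its own; tensor names are illustrative):

```python
import torch

class FwdOnlyIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # identity in the forward pass
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # block the gradient in the backward pass
        return torch.zeros_like(grad_output)

a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, 2, requires_grad=True)

fwd_identity = FwdOnlyIdentity.apply
out = fwd_identity(a) @ b
out.sum().backward()

assert torch.equal(a.grad, torch.zeros_like(a))  # gradient to a is zeroed
assert b.grad is not None                        # gradient to b still flows
```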

in the simplest case, you can just do something like x.detach() @ weights
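A minimal runnable sketch of this (tensor names are illustrative):

```python
import torch

x = torch.randn(3, 4, requires_grad=True)
weights = torch.randn(4, 2, requires_grad=True)

# detach() returns a tensor cut out of the autograd graph, so no
# gradient reaches x, while weights still get their usual gradient
out = x.detach() @ weights
out.sum().backward()

assert x.grad is None
assert weights.grad.shape == (4, 2)
```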


The matrix multiply here doesn’t have any trainable weights. It takes two activation tensors as input.

Those are arbitrary variable names; detach() cuts off the whole computation chain that produced x. If you return a gradient of zeros instead, zeros will propagate further back, so it is the same as blocking with detach().

I see. What makes it complicated is that I cannot simply detach the producer of ‘x’, as it’s expected to receive gradient updates from another downstream path.

The other path won’t be affected; you’ll have separate graph nodes for each use of the tensor, i.e.

x_ng = x.detach()
y1 = x_ng @ m1
y2 = x @ m2

Two paths start from x; the original x tensor is not affected by detach(), and y1 is trainable only through m1.
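The two-path behavior above can be verified with a short runnable sketch (shapes are illustrative):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
m1 = torch.randn(3, 3, requires_grad=True)
m2 = torch.randn(3, 3, requires_grad=True)

x_ng = x.detach()
y1 = x_ng @ m1   # no gradient flows back to x through this path
y2 = x @ m2      # gradient flows back to x through this path

(y1.sum() + y2.sum()).backward()

# x only received gradient via the y2 path: d(sum(y2))/dx = ones @ m2^T
assert torch.allclose(x.grad, torch.ones(2, 3) @ m2.t())
# both weight matrices still receive gradients
assert m1.grad is not None and m2.grad is not None
```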


Hey, I figured out how to do it with autograd.Function after reading through the pages. Let me know if it makes sense.

class ZeroFirstGradMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a)
        return a @ b

    @staticmethod
    def backward(ctx, grad_output):
        (a,) = ctx.saved_tensors
        # one gradient per forward input: zeros for a, the usual
        # chain-rule term a^T @ grad_output for b
        return torch.zeros_like(a), a.transpose(-2, -1) @ grad_output