# How to remove reduction in a register_hook

Hi All,

I have a bit of a technical question regarding the register_hook function. How can I remove the reduction that’s done within the register_hook function? I’m not sure if this is the correct terminology, so please correct me if I’m wrong.

For example, the code below takes an input, IN, of size 1 with a batch of size BATCH. It passes it through a weight matrix to create an output OUT. The Squared Error (not MSE) is then calculated with Tensor Y and the gradient is calculated for all inputs.

import torch

BATCH=10
IN=1
NHID=5
OUT=2

X = torch.randn(BATCH, IN)
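W = torch.randn(IN, NHID, requires_grad=True)  # weight matrix; requires_grad so a hook can be registered
W.register_hook(lambda x: print("W: ",x))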
Y = torch.randn(BATCH, NHID)

OUT=X.mm(W)
loss = (OUT-Y)**2
loss.register_hook(lambda x: print("loss: ",x))
loss.backward(torch.ones_like(loss))

This prints out the following

loss:  tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
W:  tensor([[  4.9739,   0.1722,   0.9640, -12.8224,   5.5129]])

From my understanding, the register_hook function calculates the gradient of the loss w.r.t. the Tensor to which it is registered. However, the gradient that is shown has the same dimensionality as W and does not have the batch dimension. So, it must be summed over or reduced somehow. Is it possible to disable this feature somehow so the hook returns a Tensor of size (BATCH, IN, NHID)?

Hi,

This hook just gives you the gradient that was computed for W. And this gradient has the same size as W.
Note that the hook is not doing any computation, it just shows you something that was computed before.

I am not sure I understand why you expect to see a batch dimension on the gradient of W here. The torch.mm() operation takes two inputs and computes the gradient for each of them. In particular, for W it is X.t().mm(grad_out), where grad_out is the gradient w.r.t. the output. X.t() has size (IN, BATCH) and grad_out has size (BATCH, NHID), so the matrix product sums over the batch dimension and it “naturally” disappears when you compute the gradient.
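
As a rough check of the formula above, here is a minimal sketch that reuses the shapes from the question. The names `grad_out`, `manual` and `per_sample` are only illustrative, and `retain_grad()` is used here just to inspect the gradient of the intermediate output. It compares autograd's gradient for W with the matrix product, and shows that the per-sample pieces are outer products that sum to it.

```python
import torch

torch.manual_seed(0)
BATCH, IN, NHID = 10, 1, 5

X = torch.randn(BATCH, IN)
W = torch.randn(IN, NHID, requires_grad=True)
Y = torch.randn(BATCH, NHID)

out = X.mm(W)
out.retain_grad()  # keep d(loss)/d(out) on this non-leaf tensor for inspection
loss = (out - Y) ** 2
loss.backward(torch.ones_like(loss))

grad_out = out.grad              # shape (BATCH, NHID), still has the batch dim
manual = X.t().mm(grad_out)      # shape (IN, NHID), batch dim summed away

# Per-sample contributions: outer product of each input row with its grad_out row
per_sample = X.unsqueeze(2) * grad_out.unsqueeze(1)   # shape (BATCH, IN, NHID)

matches_autograd = torch.allclose(W.grad, manual, atol=1e-6)
sums_to_total = torch.allclose(per_sample.sum(dim=0), W.grad, atol=1e-5)
print(grad_out.shape, W.grad.shape, per_sample.shape)
```

So for this particular matrix multiply, the (BATCH, IN, NHID) tensor can be rebuilt from X and the gradient of the output, without changing what the hook on W reports.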

Hi!

Thanks for the quick response! The reason I’m asking is that I want to get access to the gradient of the loss w.r.t. W for every input in a batch. So, in effect, I want the gradient d(Loss)/d(W) for each input in the batch, which would have a dimensionality along the lines of [hid_features, in_features, batch].

From my understanding, register_hook gives d(Loss)/d(Z), where Z is the output of a module (in my case nn.Linear), and has the dimensionality [hid_features, in_features]. On the other hand, register_backward_hook gives d(Z)/d(W) as grad_output, which has the dimensionality [batch, hid_features], corresponding to the gradient for each input against the weight. My understanding mainly comes from this post: Exact meaning of grad_input and grad_output (I assume the explanation within this post is correct?). The register_backward_hook has the batch dimension within it, and I was wondering if it is possible to get the gradient of the loss with respect to W for all values in the batch in a similar manner to what register_backward_hook does, except for d(Loss)/d(W) rather than d(Z)/d(W).

Is this something that is possible with PyTorch?

Thank you for the help and clarification!

Hi,

The gradient w.r.t. the output of the module matches the output size and so has a batch dimension. But the weights don’t have a batch dimension, so their gradient doesn’t have one either.

You can do two things here:

• Do as many backward passes as there are samples in the batch, saving the gradient each time, and stack them at the end to get your (batch, in, nhid) gradients.
• Expand the weights to `(batch, in, nhid)` so that you have as many copies of the weights as you have samples. In the forward pass, make sure each copy is used to compute the output for exactly one sample in your batch. Now when you backprop, you will again get a gradient of the size of the weights, and it contains all the values you want.
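
Both options can be sketched roughly as follows, using the shapes from the question. This is only one way to write them, not necessarily what was intended above: the loop uses `torch.autograd.grad` instead of `backward()` so that gradients do not accumulate in `W.grad` between samples, and the expanded version uses `torch.bmm` so that each weight copy only sees its own sample. The names `W_exp`, `per_sample_loop` and `per_sample_expanded` are made up for the example.

```python
import torch

torch.manual_seed(0)
BATCH, IN, NHID = 10, 1, 5

X = torch.randn(BATCH, IN)
W = torch.randn(IN, NHID, requires_grad=True)
Y = torch.randn(BATCH, NHID)

# Option 1: one backward pass per sample, then stack the results
grads = []
for i in range(BATCH):
    out_i = X[i:i + 1].mm(W)                       # (1, NHID)
    loss_i = ((out_i - Y[i:i + 1]) ** 2).sum()     # scalar loss for sample i
    (g,) = torch.autograd.grad(loss_i, W)          # (IN, NHID)
    grads.append(g)
per_sample_loop = torch.stack(grads)               # (BATCH, IN, NHID)

# Option 2: give every sample its own copy of the weights
W_exp = W.detach().unsqueeze(0).expand(BATCH, IN, NHID).clone().requires_grad_()
out = torch.bmm(X.unsqueeze(1), W_exp).squeeze(1)  # (BATCH, NHID)
loss = (out - Y) ** 2
loss.backward(torch.ones_like(loss))
per_sample_expanded = W_exp.grad                   # (BATCH, IN, NHID)

same_result = torch.allclose(per_sample_loop, per_sample_expanded, atol=1e-6)
print(per_sample_loop.shape, per_sample_expanded.shape, same_result)
```

The loop is simpler but needs BATCH backward passes; the expanded version needs a single pass at the cost of storing BATCH copies of the weights. Summing either result over the first dimension should give back the usual (IN, NHID) gradient of W.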

Thank you for your help! I shall give it a go!