# Strange behavior using register_hook method

## Problem with tensor hooks

Hello everyone,

I am trying to compute the variance of the gradient exploiting hooks in a ResNet-like architecture (just a toy network). Basically, I have a model that performs

```
out = b1(x) + (alpha * b2(x))
```

The variance (V) and mean (M) of the gradient (measured in an experiment) are

```
[OUT] V = 4.25329e-10, M = 6.00018e-08
[B2]  V = 4.25329e-10, M = 6.00018e-08
[B1]  V = 4.25329e-10, M = 6.00018e-08
[IN]  V = 1.70131e-09, M = 1.20004e-07
```

where B1 and B2 are the outputs of b1(x) and b2(x), x is sampled from a standard normal distribution, and alpha is equal to one.

How is it possible that the OUT variance and mean are equal to those of B2 and B1? Is this an issue with my code, or with the register_hook method? (The same thing happens if you add a convolution after the ReLU inside the Sequential modules.)

Below is the code to reproduce the experiment:

```
import torch
import torch.nn as nn

alpha = 1
channels = 64

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # conv = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.b1 = nn.Sequential(torch.nn.ReLU())

        # conv = torch.nn.Conv2d(channels, channels, kernel_size=1, padding=0, bias=False)
        self.b2 = nn.Sequential(torch.nn.ReLU())

        # flattened feature size: channels * 32 * 32
        self.l = nn.Linear(channels * 32 * 32, 10)

    def forward(self, x):
        x.register_hook(lambda g: print(f"[IN] V = {g.var():.5e}, M = {g.mean():.5e}"))

        b1 = self.b1(x)
        b1.register_hook(lambda g: print(f"[B1] V = {g.var():.5e}, M = {g.mean():.5e}"))

        b2 = self.b2(x)
        b2.register_hook(lambda g: print(f"[B2] V = {g.var():.5e}, M = {g.mean():.5e}"))

        out = torch.add(b1, b2, alpha=alpha)  # out = b1 + (alpha * b2)
        out.register_hook(lambda g: print(f"[OUT] V = {g.var():.5e}, M = {g.mean():.5e}"))

        out = nn.Flatten()(out)
        return self.l(out)

y = torch.randint(low=0, high=10, size=[128])
x = torch.normal(mean=0, std=1, size=[128, channels, 32, 32], requires_grad=True)

net = Net()
out = net(x)
loss = torch.nn.CrossEntropyLoss()(out, y)
loss.backward()
```

This is expected: all `add` does during the backward pass is propagate the same gradient, as-is, to both of its inputs:
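As a minimal sketch of this (standalone, independent of the network above, with the same `alpha=1` as in the question), you can verify that after the backward pass both inputs of an elementwise add hold an identical gradient:

```python
import torch

# Two leaf tensors playing the role of b1 and b2
a = torch.randn(4, 4, requires_grad=True)
b = torch.randn(4, 4, requires_grad=True)

# out = a + (alpha * b), with alpha = 1 as in the question
out = torch.add(a, b, alpha=1)

# Backpropagate a known upstream gradient of all ones
out.backward(torch.ones_like(out))

# The add's backward routes the upstream gradient unchanged to both inputs,
# so any hook on a, b, or out would report identical statistics.
assert torch.equal(a.grad, b.grad)
assert torch.equal(a.grad, torch.ones(4, 4))
```

This also explains why the [IN] statistics differ: x feeds both branches, so its gradient is the sum of two (here, identical) gradients, doubling the mean and quadrupling the variance.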

```
import torch