This is a follow-up to my earlier post. Here is an example of the sort of thing I want to consider. In the scalar case, we have:
f(x) = w2 (z1 * (w1 x + b1)) + b2
and
z1 = exp(w1 x+b1)
A naive implementation would be:
import torch
import torch.nn as nn

class ScalarNN1(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(1, 1)

    def forward(self, x):
        x1 = self.layer1(x)        # x1 = w1 * x + b1
        z = torch.exp(x1)          # z1 = exp(w1 * x + b1)
        y = self.layer2(z * x1)    # y = w2 * (z1 * x1) + b2
        return y
When I compute gradients with respect to the parameters w1 and b1 (in layer1), the chain rule will back-propagate through z.
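For the scalar case I can make this concrete with a quick numerical check (just a rough sketch, assuming the ScalarNN1 definition above): the autograd gradient of w1 should match the hand-derived w2 * z * x * (x1 + 1), where the (x1 + 1) factor comes from differentiating through z.

import torch

torch.manual_seed(0)
model = ScalarNN1()
x = torch.tensor([[2.0]])

model(x).backward()

with torch.no_grad():
    w1, b1 = model.layer1.weight, model.layer1.bias
    w2 = model.layer2.weight
    x1 = w1 * x + b1                            # elementwise math matches nn.Linear in the 1x1 case
    z = torch.exp(x1)
    expected = (w2 * z * x * (x1 + 1)).item()   # d f / d w1 with the chain rule going through z

print(model.layer1.weight.grad.item(), expected)  # these agree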
Alternatively, suppose I construct:
class ScalarNN2(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(1, 1)

    def forward(self, x):
        x1 = self.layer1(x)
        with torch.no_grad():
            z = torch.exp(x1)      # this could also be done with detach
        y = self.layer2(z * x1)
        return y
Because of the torch.no_grad(), there will be no back-propagation through the z variable.
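Repeating the same check for ScalarNN2 (again only a sketch, reusing the analytic expressions from above), the gradient of w1 now matches w2 * z * x, i.e. z is treated as a constant and the extra (x1 + 1) factor disappears.

import torch

torch.manual_seed(0)
model = ScalarNN2()
x = torch.tensor([[2.0]])

model(x).backward()

with torch.no_grad():
    w1, b1 = model.layer1.weight, model.layer1.bias
    w2 = model.layer2.weight
    z = torch.exp(w1 * x + b1)
    expected = (w2 * z * x).item()   # d f / d w1 with z held constant

print(model.layer1.weight.grad.item(), expected)  # these match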
I bring up these two cases because I am considering training routines where different parts of the network are trained differently, and I want to be able to take gradients only through certain intermediates. While I can check this example by hand, I was hoping to find a way of checking this more generally. I had hoped to use something graphical like torchview, but that doesn't seem to do it.
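To make concrete the kind of check I mean: for this toy example I can walk the grad_fn graph by hand, along the lines of the sketch below (the walk_grad_fn helper is purely illustrative). An ExpBackward0 node shows up for ScalarNN1 but not for ScalarNN2, which confirms the behaviour here, but this does not feel like a general tool for larger networks.

import torch

def walk_grad_fn(fn, depth=0):
    # Recursively print the backward nodes reachable from a grad_fn.
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk_grad_fn(next_fn, depth + 1)

x = torch.randn(1, 1)
walk_grad_fn(ScalarNN1()(x).grad_fn)   # includes an ExpBackward0 node feeding the product
walk_grad_fn(ScalarNN2()(x).grad_fn)   # no ExpBackward0: z was computed under no_grad()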
Any suggestions would be appreciated.