I’m using `detach_` to cut off part of a retained autograd graph:
import torch
from torch import nn
# Global step counter shared by every forward call of M; used to pair up the
# "-->" (forward) and "<--" (backward) log lines.
idx = 0


class M(nn.Module):
    """Toy recurrent cell computing new_h = h + x * w.

    Each forward call logs "--> step" and registers a backward hook on the
    output that logs "<-- step", so the order of gradient flow through the
    retained graph can be observed.
    """

    def __init__(self):
        super().__init__()
        # Scalar learnable weight, initialised to 2.0.
        self.w = nn.Parameter(torch.tensor(2, dtype=torch.float32))

    def forward(self, h, x):
        global idx
        new_h = h + x * self.w
        # Bind the current step number via a default argument so the hook
        # remembers the step at which it was registered (avoids the
        # late-binding-closure pitfall).
        new_h.register_hook(lambda *_, step=idx: print("<-- {}".format(step)))
        print("--> {}".format(idx))
        idx += 1
        return new_h
m = M()
# Scalar 1-element inputs used to assemble each step's 2-element input vector.
z = torch.tensor([0], dtype = torch.float32)
a1 = torch.tensor([1], dtype = torch.float32)
a2 = torch.tensor([2], dtype = torch.float32)
b1 = torch.tensor([1], dtype = torch.float32)
b2 = torch.tensor([3], dtype = torch.float32)
b3 = torch.tensor([2], dtype = torch.float32)
c1 = torch.tensor([2], dtype = torch.float32)
c2 = torch.tensor([3], dtype = torch.float32)
# Step 0 and 1: unroll the cell twice from a zero initial state.
h0 = torch.cat([z, z], dim = 0)
i0 = torch.cat([a1, b1], dim = 0)
h1 = m(h0, i0)
i1 = torch.cat([a2, b2], dim = 0)
h2 = m(h1, i1)
# Backprop from h2; retain_graph so the same graph can be backed through again
# by the later .backward() calls below.
h2.backward(torch.tensor([3-h2[0],0]), retain_graph = True)
# Step 2: reuse h2's second element as part of the next state, so h3's graph
# still reaches back through h2 (and hence h1, h0).
i2 = torch.cat([b3, c1], dim = 0)
h3 = m(torch.cat([h2[[1]], z], dim = 0), i2)
h3.backward(torch.tensor([6-h3[0],0]), retain_graph = True)
# Uncommenting is intended to cut h2 out of the graph so later backward calls
# stop there. NOTE(review): detach_() clears h2's grad_fn on the Python-visible
# tensor, but graph nodes already recorded by downstream ops (h3's graph)
# appear to keep their own reference to the old function — presumably why the
# hooks for steps 1 and 0 still fire; confirm against autograd internals.
#h2.detach_()
# Step 3: h4 depends on h3, whose graph in turn reaches h2/h1/h0 unless cut.
i3 = torch.cat([c2], dim = 0)
h4 = m(torch.cat([h3[[1]]], dim = 0), i3)
h4.backward(torch.tensor([5-h4[0]]), retain_graph = True)
This prints `-->` for each forward pass and `<--` for each backward pass, so I can trace what’s going on. With no detach, the last few lines of output are (correct):
--> 3
<-- 3
<-- 2
<-- 1
<-- 0
If `h3` is detached, the output is (also correct):
--> 3
<-- 3
If `h1` or `h2` is detached, it prints the same lines as with no detach (incorrect!). The correct output (e.g. for `h2` detached) should be:
--> 3
<-- 3
<-- 2
I’m pretty sure this is a bug. But I’ve only been using PyTorch for two days and don’t know the internals. Maybe I’m doing (or expecting) something wrong?