Bug with `detach_`?

I’m using `detach_()` to cut off part of a retained graph:

```python
import torch
from torch import nn

idx = 0
class M(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning an nn.Parameter
        self.w = nn.Parameter(torch.tensor(2, dtype=torch.float32))

    def forward(self, h, x):
        global idx
        new_h = h + x * self.w
        def get_pr(idx_val):
            def pr(*_): print("<-- {}".format(idx_val))
            return pr
        new_h.register_hook(get_pr(idx))  # prints "<-- idx" when backward reaches new_h
        print("--> {}".format(idx))
        idx += 1
        return new_h

m = M()
z = torch.tensor([0], dtype=torch.float32)
a1 = torch.tensor([1], dtype=torch.float32)
a2 = torch.tensor([2], dtype=torch.float32)
b1 = torch.tensor([1], dtype=torch.float32)
b2 = torch.tensor([3], dtype=torch.float32)
b3 = torch.tensor([2], dtype=torch.float32)
c1 = torch.tensor([2], dtype=torch.float32)
c2 = torch.tensor([3], dtype=torch.float32)

h0 = torch.cat([z, z], dim=0)
i0 = torch.cat([a1, b1], dim=0)
h1 = m(h0, i0)
i1 = torch.cat([a2, b2], dim=0)
h2 = m(h1, i1)
h2.backward(torch.tensor([3 - h2[0].item(), 0.0]), retain_graph=True)

i2 = torch.cat([b3, c1], dim=0)
h3 = m(torch.cat([h2[[1]], z], dim=0), i2)
h3.backward(torch.tensor([6 - h3[0].item(), 0.0]), retain_graph=True)

# the detach experiments go here, e.g. h2.detach_()

i3 = torch.cat([c2], dim=0)
h4 = m(torch.cat([h3[[1]]], dim=0), i3)
h4.backward(torch.tensor([5 - h4[0].item()]), retain_graph=True)
```

This prints `-->` on each forward call and `<--` when backward reaches that step's output, so I can see what’s going on. With no detach, the last few lines of output are (correct):

```
--> 3
<-- 3
<-- 2
<-- 1
<-- 0
```

If h3 is detached, it’s (also correct):

```
--> 3
<-- 3
```

If h1 or h2 is detached, it prints the same lines as with no detach (incorrect!). The correct output (e.g. for h2) should be:

```
--> 3
<-- 3
<-- 2
```

I’m pretty sure this is a bug, but I’ve only been using PyTorch for two days and don’t know the internals. Maybe I’m doing (or expecting) something wrong?
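For anyone reproducing the trace: the `<--` lines presumably come from `Tensor.register_hook`, which calls the hook with the gradient of that tensor once backward reaches it, so along a chain the hooks fire in reverse order of the forward calls. A stripped-down sketch of the same instrumentation (the scalar `w` and the `events` list are illustrative stand-ins, not part of the original code):

```python
import torch

events = []
w = torch.tensor(2.0, requires_grad=True)

h = torch.zeros(())
for step in range(3):
    h = h + w                      # one "forward" step of the chain
    events.append("--> {}".format(step))
    # The hook receives the gradient w.r.t. this step's output when backward reaches it.
    h.register_hook(lambda grad, s=step: events.append("<-- {}".format(s)))

h.backward()
print("\n".join(events))
```

The default argument `s=step` freezes the step number at registration time, which is what the `get_pr` factory does in the original code.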


Keep in mind that `detach()` and `detach_()` won’t modify an existing graph. They only stop tracking for the new operations you perform on that Tensor afterwards.
I’m not sure I understand your code 100%, but it looks like you’re trying to modify an existing graph, no?
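A minimal illustration of that point, with scalar tensors standing in for the model (the names are illustrative): `y` is built from `h` before `h.detach_()` is called, so the graph already recorded for `y` still reaches `w` through `h`.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
h = w * 3.0   # intermediate result, recorded in the graph
y = h * 4.0   # y's graph now includes the operation that produced h

h.detach_()   # affects h itself: it stops requiring grad and loses its grad_fn
print(h.requires_grad, h.grad_fn)   # False None

y.backward()  # the graph recorded for y is unchanged
print(w.grad)                       # tensor(12.)
```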

Yeah, trying to modify the existing graph. At the detach point, I know all gradients that flow backward through h2 will be zero, so I want to prune it from the graph.

You will need to detach the h2 variable before using it in the rest of the computations.

Detaching h2 earlier (e.g. right before i2) means gradients from h3.backward won’t flow through it, but I do want them to.
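The trade-off described here, sketched with scalar tensors (illustrative names): detaching before the next step is built means that step's backward never gets past `h`, while a loss computed on `h` itself still reaches `w`, because `h`'s own graph is not touched.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
h = w * 3.0
h_cut = h.detach()           # detached copy taken before h feeds the next step

y = h_cut * 4.0 + w          # only the "+ w" term connects y to w
y.backward()
grad_via_y = w.grad.item()   # 1.0: nothing flowed back through h

w.grad = None
h.backward()                 # h's own graph is still intact
grad_via_h = w.grad.item()   # 3.0
```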

To make things less mysterious, this is a dynamic batching POC for RNNs: batch size of 2 (initially), and the three sequences are [a1, a2], [b1, b2, b3], [c1, c2], batched like this (each column is a timestep, first row is t labels):

```
t:  0   1   2   3
    a1  a2  b3  c2
    b1  b2  c1
```

Sequence a gets backpropped at t = 2 (h2.backward), but the graph can’t be pruned for t < 2 because b isn’t complete, so non-zero gradients still need to flow there. b gets backpropped at t = 3 (h3.backward), so now the graph can be pruned.
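One way to get this behaviour without editing a graph, which is a different technique from anything the replies in this thread propose, is to cut the graph into segments during the forward pass: each step starts from a detached copy of the previous state made into a leaf with `requires_grad_()`, and when a later backward should continue past a boundary, the gradient collected on that leaf is passed into the previous segment by hand. Pruning then amounts to skipping that call and dropping the reference to the old segment. A minimal sketch, assuming one boundary per step and a scalar weight; `step` is a hypothetical stand-in for `M.forward`:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

def step(h, x):
    # Hypothetical stand-in for M.forward: h + x * w
    return h + x * w

# Segment 1 (an earlier timestep)
h1 = step(torch.zeros(()), torch.tensor(1.0))

# Boundary: segment 2 starts from a leaf that shares h1's value but not its graph.
h1_leaf = h1.detach().requires_grad_()
h2 = step(h1_leaf, torch.tensor(3.0))

# Backward through segment 2 only; the gradient stops at (and is stored on) the leaf.
loss = (5.0 - h2) ** 2
loss.backward()
grad_segment2_only = w.grad.item()

# Still need gradients in the earlier steps? Chain the boundary gradient into segment 1.
# To prune instead, skip this call and `del h1` so segment 1 can be freed.
h1.backward(h1_leaf.grad)
grad_full = w.grad.item()
```

The cost is that the boundaries must be placed during the forward pass, and every boundary the gradient does cross needs one extra `backward` call.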

From what I’ve seen so far it looks like it’s not possible to modify an existing graph. Is there any way of creating a new graph, without redoing the forward operations?

I’m afraid you cannot modify the graph like that at the moment.
You have to redo the forward.
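What redoing the forward could look like in this case, as a sketch with a scalar weight: a hypothetical `step` function stands in for `M.forward`, and the inputs are illustrative. The loss on `h3` backpropagates through every step; the later loss is computed on a graph rebuilt from a detached `h2`, so its backward stops at the cut (analogous to the `<-- 3`, `<-- 2` trace asked for above). The price is recomputing the step after the cut.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

def step(h, x):
    # Hypothetical stand-in for M.forward: h + x * w
    return h + x * w

h1 = step(torch.zeros(()), torch.tensor(1.0))   # t = 0
h2 = step(h1, torch.tensor(3.0))                # t = 1
h3 = step(h2, torch.tensor(2.0))                # t = 2

# Loss on h3: gradients reach every earlier step.
h3.backward()
full_grad = w.grad.item()

# For the next loss, redo the forward from a detached h2 so the new graph
# ends at the cut; the old graph is not modified, just no longer referenced.
w.grad = None
h3_redo = step(h2.detach(), torch.tensor(2.0))  # recomputed t = 2
h4 = step(h3_redo, torch.tensor(3.0))           # t = 3
h4.backward()
pruned_grad = w.grad.item()
```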

@smth is that a feature we would like to add in the future? Not sure if it’s even feasible with the current backend.

You are right. The graph is constructed at forward time.
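One way to see that the graph is fixed when the forward runs is to inspect `grad_fn.next_functions`, which lists the backward nodes an operation was connected to when it was recorded. A small sketch with scalar tensors (illustrative names): after `h.detach_()`, `h` itself has no `grad_fn`, but the node recorded for `y` still points at the node that produced `h`.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
h = w * 3.0
y = h * 4.0

h_node = h.grad_fn    # backward node recorded when h was computed
h.detach_()           # clears h.grad_fn, but does not edit nodes recorded earlier

print(h.grad_fn)                                  # None
print(y.grad_fn.next_functions[0][0] is h_node)   # True: y still links to h's old node
```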

While I see the use, I’m a bit afraid that it would end up being dangerous. How the graph is constructed is really an implementation detail, and graph destruction might have surprising consequences that affect variables other than the one you detached.