Combining the @torch.no_grad() decorator with a torch.no_grad() context manager causes gradients to be enabled

Hello, so I was wondering if this is the intended behaviour. Basically, when exiting a function decorated with @torch.no_grad(), instead of returning to the previous gradient-enabled state, it simply turns gradients back on. This causes my library to accumulate gradients during the validation phase and run out of memory (OOM), instead of simply computing the results.

Here is a minimal example, with 4 experiments. The first two have gradients disabled in the main loop (using a with-block); the last two have gradients enabled all the way. The first case is what I consider buggy, or at least surprising. Is this intended?

import torch as tr

@tr.no_grad()
def f(a, b):
	res = a @ b
	print(" Inside (decorator):", res.requires_grad, end="")
	return res

def g(a, b):
	with tr.no_grad():
		res = a @ b
		print(" Inside (with):", res.requires_grad, end="")
	return res

def experiment1():
	a = tr.randn(10, 20).requires_grad_()
	b = tr.randn(20, 30).requires_grad_()
	print("Before all:", (a@b).requires_grad)

	
	with tr.no_grad():
		for i in range(5):
			print(i, "Before:", (a@b).requires_grad, end="")
			_ = f(a, b)
			print(" After:", (a@b).requires_grad)

def experiment2():
	a = tr.randn(10, 20).requires_grad_()
	b = tr.randn(20, 30).requires_grad_()
	print("Before all:", (a@b).requires_grad)
	
	with tr.no_grad():
		for i in range(5):
			print(i, "Before:", (a@b).requires_grad, end="")
			_ = g(a, b)
			print(" After:", (a@b).requires_grad)

def experiment3():
	a = tr.randn(10, 20).requires_grad_()
	b = tr.randn(20, 30).requires_grad_()
	print("Before all:", (a@b).requires_grad)
	
	for i in range(5):
		print(i, "Before:", (a@b).requires_grad, end="")
		_ = f(a, b)
		print(" After:", (a@b).requires_grad)

def experiment4():
	a = tr.randn(10, 20).requires_grad_()
	b = tr.randn(20, 30).requires_grad_()
	print("Before all:", (a@b).requires_grad)
	
	for i in range(5):
		print(i, "Before:", (a@b).requires_grad, end="")
		_ = g(a, b)
		print(" After:", (a@b).requires_grad)

experiment1()
experiment2()
experiment3()
experiment4()

With outputs:

Before all: True
0 Before: False Inside (decorator): False After: True
1 Before: True Inside (decorator): False After: True
2 Before: True Inside (decorator): False After: True
3 Before: True Inside (decorator): False After: True
4 Before: True Inside (decorator): False After: True
Before all: True
0 Before: False Inside (with): False After: False
1 Before: False Inside (with): False After: False
2 Before: False Inside (with): False After: False
3 Before: False Inside (with): False After: False
4 Before: False Inside (with): False After: False
Before all: True
0 Before: True Inside (decorator): False After: True
1 Before: True Inside (decorator): False After: True
2 Before: True Inside (decorator): False After: True
3 Before: True Inside (decorator): False After: True
4 Before: True Inside (decorator): False After: True
Before all: True
0 Before: True Inside (with): False After: True
1 Before: True Inside (with): False After: True
2 Before: True Inside (with): False After: True
3 Before: True Inside (with): False After: True
4 Before: True Inside (with): False After: True

As you can see, in the first experiment, in the first iteration inside the with-block, gradients are disabled; then, after leaving the decorated function for the first time, they become enabled (instead of returning to the previous state, which was also disabled). They stay enabled in all further iterations. However, in the second experiment (with-block inside the function as well), they remain disabled across all iterations.
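On versions affected by this, one nesting-safe alternative to the decorator is to save and restore the caller's grad state explicitly. This is just a sketch using the public torch.is_grad_enabled() / torch.set_grad_enabled() API, not a recommendation from the PyTorch team:

```python
import torch

def f(a, b):
	prev = torch.is_grad_enabled()    # remember the caller's grad state
	torch.set_grad_enabled(False)     # disable grads for this computation
	try:
		res = a @ b
	finally:
		torch.set_grad_enabled(prev)  # restore the caller's state, not unconditionally True
	return res

a = torch.randn(10, 20).requires_grad_()
b = torch.randn(20, 30).requires_grad_()

with torch.no_grad():
	_ = f(a, b)
	# the outer no_grad() block is still in effect after f() returns
	print((a @ b).requires_grad)  # → False
```

This is essentially what the fixed decorator does internally: restore whatever mode was active on entry rather than force gradients on.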

Hi, is there any update on this?

I ran it again on my end with PyTorch 1.7.1 and above, and I no longer get the bug:

Before all: True
0 Before: False Inside (decorator): False After: False
1 Before: False Inside (decorator): False After: False
2 Before: False Inside (decorator): False After: False
3 Before: False Inside (decorator): False After: False
4 Before: False Inside (decorator): False After: False
Before all: True
0 Before: False Inside (with): False After: False
1 Before: False Inside (with): False After: False
2 Before: False Inside (with): False After: False
3 Before: False Inside (with): False After: False
4 Before: False Inside (with): False After: False
Before all: True
0 Before: True Inside (decorator): False After: True
1 Before: True Inside (decorator): False After: True
2 Before: True Inside (decorator): False After: True
3 Before: True Inside (decorator): False After: True
4 Before: True Inside (decorator): False After: True
Before all: True
0 Before: True Inside (with): False After: True
1 Before: True Inside (with): False After: True
2 Before: True Inside (with): False After: True
3 Before: True Inside (with): False After: True
4 Before: True Inside (with): False After: True

This had been open for quite some time, but it was eventually recorded as Issue 44531 and fixed shortly afterwards, in September 2020 (PyTorch 1.7).
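On PyTorch 1.7 and later, the fix can be checked directly with torch.is_grad_enabled(): the decorator now restores whatever grad mode it found on entry. A minimal check, assuming a recent PyTorch:

```python
import torch

@torch.no_grad()
def f(a, b):
	return a @ b

a = torch.randn(10, 20).requires_grad_()
b = torch.randn(20, 30).requires_grad_()

with torch.no_grad():
	_ = f(a, b)
	# on PyTorch >= 1.7 the decorator restores the disabled state it found
	# on entry, so the outer no_grad() block is still in effect here
	print(torch.is_grad_enabled())  # → False

print(torch.is_grad_enabled())      # → True, back outside the with-block
```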

Best regards

Thomas
