I roughly have the following:
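```python
import torch

def make_autograd_function():
    cache = {}  # closure variable, shared by every call of TestFunc

    class TestFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, q):
            # allocate the buffer once, on the first call, using q's device/dtype
            if 'foo' not in cache:
                cache['foo'] = torch.zeros((5, 5), device=q.device, dtype=q.dtype)
            ...  # rest of forward (and backward) omitted

    return TestFunc
```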
I’m caching a tensor inside the `cache` variable. Is this safe? (I know there’s `ctx`, but I want to avoid re-initializing the tensor every time.)
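One thing worth noting about the pattern above: because the cache is keyed only by the string `'foo'`, the first call decides the buffer's device and dtype, and a later call with, say, a `float64` or CUDA input would silently get the original tensor back. A minimal sketch of a safer variant, assuming the buffer only needs to match the input's device and dtype, keys the cache by `(q.device, q.dtype)` and only *reads* the cached tensor, so no in-place write can interfere with autograd. The helper `get_buffer`, the `q + buf` computation, and returning `cache` alongside the class are all hypothetical, added only for illustration and inspection:

```python
import torch

def make_autograd_function():
    # Closure-level cache shared by every call of TestFunc. Keying by
    # (device, dtype) means a buffer built for one device/dtype is never
    # handed to an input on another device or with another dtype.
    cache = {}

    def get_buffer(q):
        # Hypothetical helper: allocate once per (device, dtype), then reuse.
        key = (q.device, q.dtype)
        if key not in cache:
            cache[key] = torch.zeros((5, 5), device=q.device, dtype=q.dtype)
        return cache[key]

    class TestFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, q):
            buf = get_buffer(q)
            # Hypothetical computation: q + buf. The cached buffer is only
            # read here, never modified in place, so every call sees the
            # same unchanged zeros.
            return q + buf

        @staticmethod
        def backward(ctx, grad_out):
            # d(q + buf)/dq is the identity, so the gradient passes through.
            return grad_out

    # cache is returned only so the example can inspect it.
    return TestFunc, cache

TestFunc, cache = make_autograd_function()
q = torch.randn(5, 5, requires_grad=True)
TestFunc.apply(q).sum().backward()
TestFunc.apply(torch.randn(5, 5, dtype=torch.float64))
print(len(cache))  # one entry for float32, one for float64
```

If the real forward needs to write into the cached tensor (for example, to fill it with per-call data), it is probably safer to work on a copy or allocate per call, since mutating a tensor that an earlier graph saved for backward can trigger autograd's in-place modification error.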