import torch

@torch.compile(backend="eager", fullgraph=True)
def accum(k, new_k) -> torch.Tensor:
    # Sliding-window append along dim 2: keep at most the last 51 entries.
    if k.shape[2] > 50:
        k = k[:, :, 1:, :]
    return torch.cat([k, new_k], dim=2)

k = torch.zeros((3, 3, 0, 3))
for i in range(1000):
    x = torch.zeros((3, 3, 1, 3))
    x[:] = i
    k = accum(k, x)
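For reference, a quick sanity check of what the loop produces (the assertions are my own addition, not part of the repro): the cache grows along dim 2 until it reaches 51 entries and then behaves like a sliding window, with the newest value at the end.

# Shape settles at 51 along dim 2; the last slot holds the value from the final iteration.
assert k.shape == (3, 3, 51, 3)
assert k[0, 0, -1, 0].item() == 999.0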
This functional-style code works: torch.compile determines that dimension 2 of k is changing across calls and treats it as dynamic. But this module-based version:
class Accum(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Same growing cache, but held as a non-persistent buffer on the module.
        self.register_buffer("k", torch.zeros(3, 3, 0, 3), persistent=False)
        torch._dynamo.mark_dynamic(self.k, 2)

    def forward(self, x):
        self.k = accum(self.k, x)
        return self.k
a = Accum()
a = torch.compile(a, backend="eager", fullgraph=True)
for i in range(1000):
    x = torch.zeros((3, 3, 1, 3))
    x[:] = i
    k = a(x)
will recompile every pass through the loop until it fails with "too many recompiles", because Dynamo is guarding on (specializing to) the exact size of the buffer k:
    tensor 'L['self']._buffers['k']' size mismatch at index 2. expected 0, actual 8
Is there a way I can make it behave like the functional style?
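The only workaround I can think of is to re-apply the dynamic-dimension hint from outside the compiled call, since the tensor assigned to self.k in forward is a brand-new one each iteration and presumably drops the hint set in __init__. This is only a sketch (I have not verified that re-marking the buffer this way actually stops the recompiles), and it still feels like fighting the framework:

a = Accum()
opt = torch.compile(a, backend="eager", fullgraph=True)
for i in range(1000):
    x = torch.zeros((3, 3, 1, 3))
    x[:] = i
    k = opt(x)
    # forward reassigned self.k, so re-mark the fresh buffer before the next call
    # (sketch only; unsure whether this hint is honored on a module buffer).
    torch._dynamo.mark_dynamic(a.k, 2)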