Matmul on a slice behaves inconsistently for tensors of different sizes, breaking the computation graph

When I execute the following code, there is no error:

```python
import torch

A = torch.eye(3, 3, requires_grad=True)
y = torch.randn(5, 2, 3, 3)

y[:, 1, :, :] = y[:, 0, :, :] @ A
y.sum().backward()
```

But if I make one small change, reducing the size of the first dimension from 5 to 1, an error occurs:

```python
import torch

A = torch.eye(3, 3, requires_grad=True)
y = torch.randn(1, 2, 3, 3)

y[:, 1, :, :] = y[:, 0, :, :] @ A
y.mean().backward()
```

The error is:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 3]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I am on version 2.0.0+cu118, and I also tested 1.13.0+cu117, where the error still occurs. I haven't tested any other versions.

It seems that when the first dimension is 1, the slice assignment is automatically performed in place, which breaks the computation graph.
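To check that the failure really depends only on the size of the first dimension, here is a small sweep. It is a sketch, not code from the post: the helper `raises_on_backward` is a hypothetical name, and it counts only the in-place modification error as a failure.

```python
import torch

def raises_on_backward(batch: int) -> bool:
    """Hypothetical helper: rerun the failing snippet with a given first-dim size.

    Returns True if backward() raises the in-place modification error.
    """
    A = torch.eye(3, 3, requires_grad=True)
    y = torch.randn(batch, 2, 3, 3)
    y[:, 1, :, :] = y[:, 0, :, :] @ A
    try:
        y.mean().backward()
    except RuntimeError as e:
        if "modified by an inplace operation" in str(e):
            return True
        raise  # some other error: do not hide it
    return False

for batch in (1, 2, 5):
    print(batch, raises_on_backward(batch))
```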

A simple workaround is to call clone() manually:

```python
import torch

A = torch.eye(3, 3, requires_grad=True)
y = torch.randn(1, 2, 3, 3)

y[:, 1, :, :] = y[:, 0, :, :].clone() @ A
y.mean().backward()
```
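As a quick sanity check of the workaround (again a sketch with a hypothetical helper, `clone_workaround`), the snippet below runs it for sizes 1 and 5 and compares `A.grad` with the gradient worked out by hand. Since `y[:, 1] = y[:, 0] @ A` and the loss is `y.mean()`, each entry `A.grad[k, j]` should equal the sum of `y[:, 0, :, k]` divided by `y.numel()`, independent of `j`.

```python
import torch

def clone_workaround(batch: int):
    """Hypothetical helper: apply the clone() fix and return (A.grad, expected)."""
    A = torch.eye(3, 3, requires_grad=True)
    y = torch.randn(batch, 2, 3, 3)
    # Untouched copy of the source slice, used to compute the expected gradient.
    y0 = y[:, 0, :, :].clone()
    # clone() gives matmul an input that does not share memory with y,
    # so the in-place write into y[:, 1] cannot invalidate what autograd saved.
    y[:, 1, :, :] = y[:, 0, :, :].clone() @ A
    y.mean().backward()
    # dL/dA[k, j] = sum over batch and rows of y0[..., k], divided by y.numel()
    expected = (y0.sum(dim=(0, 1)) / y.numel()).unsqueeze(1).expand(3, 3)
    return A.grad, expected

for batch in (1, 5):
    grad, expected = clone_workaround(batch)
    print(batch, torch.allclose(grad, expected, atol=1e-6))
```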

I'm wondering whether the documentation says anything more about slice assignment.
Is the behavior I ran into an intentional performance optimization, or is it a bug?

This is expected behavior. The original tensor needs to be saved for backward, so if you slice-assign into it, the gradient later computed from that saved tensor would be wrong. The error is there to prevent that.
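To illustrate the rule in isolation, without any slicing: in the sketch below, `@` on 2-D tensors has to save its first argument to compute the gradient of `A`, so changing that argument in place before `backward()` triggers the same error.

```python
import torch

A = torch.eye(3, requires_grad=True)
x = torch.randn(3, 3)   # does not require grad itself
out = x @ A             # autograd saves x, because d(out)/dA depends on it
x.add_(1.0)             # in-place update bumps x's version counter

error = None
try:
    out.sum().backward()
except RuntimeError as e:
    error = str(e)
print(error)
```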

Hi Soulitzer!

But why does it work when the first dimension of the y tensor has size 5, but
not when it has size 1? I don’t have any good theories as to why the two cases
behave differently.

Best.

K. Frank

First of all, I know that the computation graph needs to retain the original tensor for backward. But in y[:, 1, :, :] = y[:, 0, :, :] @ A, the two slices index different positions along the second dimension (1 on the left, 0 on the right), so the assignment should not overwrite the data that is read.
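The non-overlap argument is correct about the data, but autograd's check does not look at which elements were written. A sketch using the private `_version` attribute (internal and undocumented, so it may change between releases): views of `y` share a single version counter, so writing into `y[:, 1]` bumps the version seen through `y[:, 0]` even though its values stay the same.

```python
import torch

y = torch.randn(1, 2, 3, 3)
view0 = y[:, 0, :, :]            # a view of y; shares y's version counter
snapshot = view0.clone()         # copy of its values, for comparison
before = view0._version

y[:, 1, :, :] = torch.zeros(1, 3, 3)  # writes only to index 1 along dim 1

after = view0._version
print(before, after, torch.equal(view0, snapshot))
```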

Second, this doesn't explain why no error occurs when the first dimension has a size other than 1.

My guess is that when the first dimension is not 1, the slice view introduces an extra clone, or that, for performance, the operation is automatically done in place when there are dimensions that can be squeezed. But I didn't find any related documentation.
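One way to test the "extra clone" guess directly is to check whether reshaping the slice returns a view or a copy. This sketch assumes that matmul flattens the `(batch, 3, 3)` slice into a `(batch * 3, 3)` matrix with `reshape`; that is an internal detail, not documented behavior. Because `y[:, 0]` starts at offset 0 in `y`, a view of it has the same `data_ptr()` as `y`, while a copy gets new memory. The helper name `reshape_is_view` is hypothetical.

```python
import torch

def reshape_is_view(batch: int) -> bool:
    """Hypothetical helper: does folding y[:, 0] into a matrix avoid a copy?"""
    y = torch.randn(batch, 2, 3, 3)    # contiguous, strides (18, 9, 3, 1)
    s = y[:, 0, :, :]                  # shape (batch, 3, 3), strides (18, 3, 1)
    folded = s.reshape(-1, 3)          # the assumed fold to (batch * 3, 3)
    # s starts at offset 0 in y, so a view of it points at y's first element.
    return folded.data_ptr() == y.data_ptr()

for batch in (1, 5):
    print(batch, reshape_is_view(batch))
```

With size 1 the leading dimension can simply be dropped, so no copy is needed; with size 5 the rows of consecutive batch entries are 18 elements apart instead of 9, so they cannot be merged into one strided dimension and `reshape` has to copy.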

PS: Is there any documentation on the specifics of slicing? I searched for "pytorch slice" and "slice assign" but couldn't find anything relevant.

@KFrank @betacat2048

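Whoops, I didn't read the post fully the first time. The difference is probably due to the extra reshaping matmul does to preprocess its inputs: depending on the shape and strides of the tensor, "reshape" may or may not call clone underneath. When the first dimension is 1, the reshape can return a view of y, so the tensor saved for backward is affected by the later slice assignment; when it is not 1, the reshape makes a copy, which the assignment doesn't touch.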
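To see what matmul actually saves without relying on its internals, `torch.autograd.graph.saved_tensors_hooks` can intercept every tensor autograd stores for backward. In this sketch (hypothetical helper `saves_view_of_y`), the pack hook only records whether the saved tensor lives in `y`'s memory and returns it unchanged. Comparing `data_ptr()` with `y` works here only because `y[:, 0]` starts at offset 0.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

def saves_view_of_y(batch: int) -> bool:
    """Hypothetical helper: does matmul save a tensor that aliases y?"""
    A = torch.eye(3, 3, requires_grad=True)
    y = torch.randn(batch, 2, 3, 3)
    aliases = []

    def pack(t):
        # Called for each tensor autograd saves; note whether it aliases y.
        aliases.append(t.data_ptr() == y.data_ptr())
        return t

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        y[:, 0, :, :] @ A
    return any(aliases)

for batch in (1, 5):
    print(batch, saves_view_of_y(batch))
```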

I think your guess is right: the problem seems to be related to matmul, because if I change @ to +, no error is raised.
It seems that matmul behaves differently for different slices, but this doesn't seem to be mentioned in the documentation (torch.matmul). Is there any other documentation that explains the error?

Doing + doesn't raise because computing its derivative does not require saving the inputs. What is saved as input/output is considered an implementation detail, so unfortunately it is not documented.
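Following that explanation, the outcome should depend on what each operation saves rather than on matmul as such. The sketch below (hypothetical helper `backward_raises`) swaps the operator: `torch.add` saves no input tensors, while elementwise `torch.mul` saves `y[:, 0]` to compute the gradient of `A`, and since broadcasting needs no reshape, what it saves is a view of `y` for any batch size. On the versions discussed here one would expect `add` never to raise, `mul` to raise for both sizes, and `matmul` only for size 1.

```python
import torch

def backward_raises(op, batch: int) -> bool:
    """Hypothetical helper: run the post's pattern with a different operator."""
    A = torch.eye(3, 3, requires_grad=True)
    y = torch.randn(batch, 2, 3, 3)
    y[:, 1, :, :] = op(y[:, 0, :, :], A)
    try:
        y.mean().backward()
    except RuntimeError as e:
        if "modified by an inplace operation" in str(e):
            return True
        raise  # unrelated error: surface it
    return False

for name, op in [("add", torch.add), ("mul", torch.mul), ("matmul", torch.matmul)]:
    print(name, [backward_raises(op, b) for b in (1, 5)])
```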