Function which outputs tensor by reference

I’m struggling to write a function ConvToToeplitzTr() that takes a conv2d layer’s weight tensor W as input, uses its entries to build a new matrix M by reference (the Toeplitz matrix representation of the given conv2d, transposed), and returns M. Both W and M are then used to calculate the output of the layer.

My goal is for autograd to treat the entries of the original conv2d weight W and the corresponding entries of M (the output of ConvToToeplitzTr()) as the same variables, so that the gradient of each weight accumulates contributions from both W and M.

The following minimal example already fails to return a tensor by reference:

import torch

def foo(x):
    return torch.stack([x,x])

y = torch.tensor(4)
z = foo(y)
print(z) # [4, 4]
y += 1
print(z) # [4, 4]. Want: [5,5]

Hi,

I think the notion of “reference” that you want here is what we call a view: another Tensor that looks at the same memory.
If you want to do that, you will need to make sure that you only use functions that return views. You can find more details and the list of these functions here: Tensor Views — PyTorch 1.7.0 documentation
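A minimal sketch of the difference, reusing y from the example above (copied and viewed are illustrative names): torch.stack allocates new memory and copies its inputs, while Tensor.expand returns a view (with stride 0) over y’s storage, so an in-place update to y shows up through it. Writing in place into an expanded view is generally not safe, because all of its elements alias the same memory; here it is only read.

```python
import torch

y = torch.tensor(4)

copied = torch.stack([y, y])  # stack allocates new memory: a copy of y's value
viewed = y.expand(2)          # expand returns a view: both elements alias y's storage

y += 1  # in-place update of the original tensor

print(copied)  # tensor([4, 4])
print(viewed)  # tensor([5, 5])
```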


Hello albanD,

thank you very much for your reply.

I get the idea, but my specific case seems to be somewhat more complicated.

Here’s my actual code; note especially the assignment to m inside the loops before the return statement:

import torch
import torch.nn as nn

def convmatrix2d(kernel, image_shape):
    # kernel: (out_channels, in_channels, kernel_height, kernel_width, ...)
    # image: (in_channels, image_height, image_width, ...)
    assert image_shape[0] == kernel.shape[1]
    assert len(image_shape[1:]) == len(kernel.shape[2:])
    # Output size of a stride-1, unpadded ("valid") convolution, as Python ints
    result_dims = [i - k + 1 for i, k in zip(image_shape[1:], kernel.shape[2:])]
    m = torch.zeros((
        kernel.shape[0],
        *result_dims,
        *image_shape
    ))
    for i in range(m.shape[1]):
        for j in range(m.shape[2]):
            m[:,i,j,:,i:i+kernel.shape[2],j:j+kernel.shape[3]] = kernel.view(kernel.shape) # Previously: ... = kernel
    return m.flatten(0, len(kernel.shape[2:])).flatten(1)

# In forward() of my network:
self.convtest = nn.Conv2d(1, 1, kernel_size=3, padding=1)
output_toeplitz = convmatrix2d(self.convtest.weight, [1, 5, 5])

Even with the addition of .view(), the output_toeplitz variable does not change if we now manually change some values in self.convtest.weight.
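One way to sanity-check the construction is to compare the matrix-vector product against F.conv2d. The sketch below re-declares convmatrix2d in a 2D-only form so it is self-contained; the sizes (2 output channels, 3 input channels, a 5x5 image) are illustrative assumptions. From the index arithmetic, the matrix models a stride-1 convolution without padding and without bias, so the padding=1 of self.convtest is not reflected in it.

```python
import torch
import torch.nn.functional as F

def convmatrix2d(kernel, image_shape):
    # Same construction as in the post above, specialised to 2D with integer sizes
    out_ch, in_ch, kh, kw = kernel.shape
    c, h, w = image_shape
    assert c == in_ch
    oh, ow = h - kh + 1, w - kw + 1          # "valid" output size: stride 1, no padding
    m = torch.zeros(out_ch, oh, ow, c, h, w, dtype=kernel.dtype)
    for i in range(oh):
        for j in range(ow):
            # Slice assignment copies the kernel in; autograd records the copy
            m[:, i, j, :, i:i + kh, j:j + kw] = kernel
    return m.flatten(0, 2).flatten(1)        # (out_ch*oh*ow, c*h*w)

torch.manual_seed(0)
kernel = torch.randn(2, 3, 3, 3, requires_grad=True)
image = torch.randn(3, 5, 5)

T = convmatrix2d(kernel, image.shape)
via_matrix = T @ image.flatten()
via_conv = F.conv2d(image.unsqueeze(0), kernel).flatten()   # no padding, no bias

# Gradients w.r.t. the kernel agree as well
g_matrix, = torch.autograd.grad(via_matrix.sum(), kernel)
g_conv, = torch.autograd.grad(via_conv.sum(), kernel)

print(T.shape)                                          # torch.Size([18, 75])
print(torch.allclose(via_matrix, via_conv, atol=1e-5))  # True
print(torch.allclose(g_matrix, g_conv, atol=1e-5))      # True
```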

If you want output_toeplitz to change when you update self.convtest.weight, then you have to make sure that you only use ops that return views to get from one to the other.
Here you create a Tensor full of zeros that has brand new memory. So this one can’t be a view of kernel.

You might want to use the Unfold class we provide to get a view properly: Unfold — PyTorch 1.7.0 documentation
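A minimal sketch of how Unfold is typically used for a convolution (the "im2col" trick): Unfold acts on the input image, extracting every patch as a column, and the kernel only needs to be reshaped into an (out_channels, in_channels*kh*kw) matrix. Sizes and padding=1 are illustrative assumptions. This computes the convolution itself, i.e. the untransposed product, and gradients reach the weight as usual.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(2, 3, 3, 3, requires_grad=True)  # (out_ch, in_ch, kh, kw)
x = torch.randn(1, 3, 5, 5)                           # (N, C, H, W)

# Unfold extracts every 3x3 patch of the *input* as a column: (N, C*kh*kw, L)
cols = F.unfold(x, kernel_size=3, padding=1)          # L = 5*5 output positions
# The kernel is only reshaped into an (out_ch, C*kh*kw) matrix
out = weight.view(weight.size(0), -1) @ cols          # (N, out_ch, L) via broadcasting
out = out.reshape(1, 2, 5, 5)                         # (N, out_ch, H_out, W_out)

ref = F.conv2d(x, weight, padding=1)
print(torch.allclose(out, ref, atol=1e-5))            # True

out.sum().backward()                                  # gradients reach the kernel
print(weight.grad.shape)                              # torch.Size([2, 3, 3, 3])
```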


My overarching goal, born from a mathematical algorithm, is to calculate a matrix-vector product of the transposed Toeplitz matrix (associated to a convolution kernel) with a vector. It is important that gradients w.r.t. the kernel can be calculated afterwards.

The approach I’m trying (maybe there’s a better way in PyTorch?) is to somehow build the Toeplitz matrix of a conv2d kernel in a way that still works with autograd, then transpose it, then perform the matrix-vector multiplication with it. With the code I’ve posted above I’m able to set up the Toeplitz matrix; however, it indeed duplicates the tensors, which breaks autograd.

I’ve taken a look at Fold/Unfold, thank you. However, I’m not sure how they can help me, since I’m trying to manipulate a convolution kernel. I’m not trying to manipulate an input image (in a block-based manner, which Unfold/Fold facilitate). Maybe I’m not seeing how to use Unfold to achieve my goal.

What would be a good way to do that in PyTorch?

I am not sure this is true.
You don’t use detach() or no_grad, so the gradients will be properly tracked through your convmatrix2d function. Even though the output matrix does not share its memory with the input, if you backprop, the gradients will flow back all the way to the input.
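A tiny sketch of both facts at once, with illustrative values: the matrix m is freshly allocated, so later in-place changes to w do not show up in it, yet the slice assignments are recorded by autograd, and each place where w was copied contributes to w.grad.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
m = torch.zeros(2, 3)     # brand-new memory, so m cannot be a view of w
m[0, 0:2] = w             # copies w's values in; autograd records the copy
m[1, 1:3] = w             # the same entries of w are used a second time

m.sum().backward()
print(w.grad)             # tensor([2., 2.]): both copies contribute to the gradient

with torch.no_grad():
    w += 10               # change w in place ...
print(m.detach()[0, :2])  # tensor([1., 2.]): ... m still holds the old values
```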


Looks like you’re right! For convmatrix2d() defined as above:

import torch
import torch.nn as nn

# convmatrix2d as defined above
convtest = nn.Conv2d(1, 1, kernel_size=3, padding=1)
output_toeplitz = convmatrix2d(convtest.weight, [1, 5, 5])

for _ in range(3):
    t = torch.matmul(output_toeplitz,torch.randn(25,1))
    print("grad after matmul with output_toeplitz:")
    # None in the 1st iteration, zeros of shape (1, 1, 3, 3) afterwards
    # (None again on PyTorch >= 2.0, where zero_grad() sets grads to None by default)
    print(convtest.weight.grad)
    t.sum().backward(retain_graph = True)
    print("grad after .backward():")
    print(convtest.weight.grad) # Nonzero updated gradient
    convtest.zero_grad()

A quick side question: why doesn’t this code work without retain_graph = True in backward()? I get RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed.

This happens because you call convmatrix2d only once for all 3 iterations of your loop, which means that this part of the autograd graph is shared by all 3 iterations.
But unless you pass retain_graph=True, the graph (including its saved intermediate results) is freed during backward. Since part of your graph is shared between iterations, you don’t want it to be freed.
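The usual way to avoid retain_graph here is to rebuild the matrix inside the loop, so every iteration gets its own graph. In this sketch, build_matrix is a hypothetical toy stand-in for convmatrix2d (any differentiable construction from the weight would do), and the SGD optimizer and its learning rate are illustrative assumptions.

```python
import torch

def build_matrix(w):
    # Hypothetical stand-in for convmatrix2d: any differentiable construction from w
    return torch.stack([w, 2 * w])

w = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([w], lr=0.1)
x = torch.ones(3)
losses = []

for step in range(3):
    opt.zero_grad()
    M = build_matrix(w)      # rebuilt each iteration from the current w: fresh graph
    loss = (M @ x).sum()
    loss.backward()          # no retain_graph: nothing is shared across iterations
    opt.step()               # w changes, so the next build_matrix call sees new values
    losses.append(loss.item())

print(losses)
print(w)
```

Rebuilding also means the matrix always reflects the latest weights, at the cost of reconstructing it once per step.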

Also, if you update the value of convtest.weight (via an optimizer, for example), you need to call convmatrix2d again so that the matrix is rebuilt from the new content of convtest.weight!
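For the overarching goal of multiplying the transposed Toeplitz matrix by a vector, one option that sidesteps both retain_graph and rebuilding is to never materialise the matrix: for a stride-1, unpadded conv2d, F.conv_transpose2d with the same weight computes exactly T^T v, because the transposed convolution is the adjoint of the convolution. This is a swapped-in technique rather than something proposed in the thread; the sizes are illustrative, and the explicit matrix T (built here as the Jacobian of the linear conv map) is only used to check the result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
kernel = torch.randn(2, 3, 3, 3, requires_grad=True)  # (out_ch, in_ch, kh, kw)
c, h, w = 3, 5, 5
oh, ow = h - 3 + 1, w - 3 + 1                          # valid conv, stride 1

v = torch.randn(2 * oh * ow)                           # vector in the conv's output space

# T^T v without materialising T: conv_transpose2d is the adjoint of conv2d
tt_v = F.conv_transpose2d(v.view(1, 2, oh, ow), kernel).flatten()  # length c*h*w

# Check only: T is the Jacobian of the (linear) conv w.r.t. the flattened image
T = torch.autograd.functional.jacobian(
    lambda img: F.conv2d(img.view(1, c, h, w), kernel.detach()).flatten(),
    torch.zeros(c * h * w),
)
print(T.shape)                                         # torch.Size([18, 75])
print(torch.allclose(tt_v, T.t() @ v, atol=1e-5))      # True

tt_v.sum().backward()          # gradients w.r.t. the kernel are available as usual
print(kernel.grad.shape)       # torch.Size([2, 3, 3, 3])
```

Because conv_transpose2d reads the current weight on every call, nothing needs to be rebuilt after an optimizer step.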


Thank you, very insightful!