How to preserve autograd of a tensor after .detach() and processing it?

Hello!

In the work that I’m doing, the output of the first conv2d() layer is converted to a numpy array with .detach() to do some processing. During this processing the output grows to 3 times its original size, and it is then converted back to a tensor to be used as the input for the next conv2d() layer.

Is there any way of getting the gradient back to the new tensor?

Note: The new tensor’s values are just copied from the old tensor.

Thank you guys very much!!!

Hi,

Could you explain in more detail what you are doing?
What I’ve understood is that you have a tensor T and you convert it to numpy (T_numpy), but the network keeps using T.
If that is the case you don’t need to struggle; you can keep using T as if nothing happened.

import torch

a = torch.rand(10).requires_grad_()
b = a.sqrt().mean()
c = b.detach()
b.backward()
print(b.grad_fn)  # <MeanBackward0 object at 0x7fba8eefdcc0>
print(c.grad_fn)  # None

In case you want to modify T according to what you have done in numpy, the easiest way is to reimplement that processing in pytorch.
Otherwise, you can write a custom function (a torch.autograd.Function, optionally wrapped in an nn.Module) which implements a proper backward.
The last thing you can do is to modify the values of T in-place:

import torch

a = torch.rand(10).requires_grad_()
b = a.sqrt().mean()
print(b)
b.data = torch.tensor(5).float()  # overwrite the value behind autograd's back
print(b)
c = b.detach()
b.backward()
print(b.grad_fn)
print(c.grad_fn)

Note that this last option goes against the theory: autograd will still differentiate the original computation, so the gradients will no longer match the values you actually use.
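
A quick sketch of why (same toy example as above): the gradients that flow back are those of the original computation, not of the value you injected.

import torch

a = torch.rand(10).requires_grad_()
b = a.sqrt().mean()
b.data = torch.tensor(5).float()  # overwrite the forward value
b.backward()
# a.grad is still 1 / (10 * 2 * sqrt(a)), i.e. the derivative of the
# original mean(sqrt(a)); the injected value 5 plays no role in it.
print(a.grad)
print(1 / (20 * a.sqrt().detach()))  # matches a.grad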

Hi, thank you for replying!
In my work, I want to change the output of a conv2d layer before putting it through another conv2d layer. But I had to use a nested for loop to process this output and it takes too much time. So I have to convert it to numpy and run the loop with numba to reduce the processing time. Then I can convert T_numpy back to a tensor and use this new tensor in the next conv2d layer.
I saw here that Mr. @albanD stated that:

People not very familiar with requires_grad and cpu/gpu Tensors might go back and forth with numpy. For example doing pytorch → numpy → pytorch and backward on the last Tensor. This will backward without issue but not all the way to the first part of the code and won’t raise any error.

So is it possible to just leave it that way and continue training? Is it still possible when I do that many times?

I would suggest asking for help to optimize the nested loop rather than doing that. If you really can’t optimize the loop, you can write a custom C++ extension (with a properly implemented backward). You can “run” it the way you describe, but keep in mind that pytorch won’t be aware of those operations and will therefore compute wrong gradients…

Also, I know there is a library that generates optimized CUDA kernels for a given piece of code, speeding it up, but I forgot the name :confused:

Thank you very much for the information!
I think I will open a new topic about optimizing the nested loop in its pytorch version.

If you can recall the name of the mentioned library, please let me know! I’m desperately short on time for my college graduation project. Any help would definitely mean a huge favor to me! Thank you!!!

I think you misunderstood what I meant there: the code will run without crashing, but the gradients won’t flow back all the way to the original pytorch code!
Using .detach() will prevent the gradients from flowing back.
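
Here is a minimal sketch of that pattern (the variable names are just for illustration):

import torch

a = torch.rand(10, requires_grad=True)
w = torch.rand(10, requires_grad=True)

arr = (a * 2).detach().numpy()  # leaves autograd; numpy/numba work would happen here
c = torch.from_numpy(arr)       # new tensor, no link back to a
loss = (c * w).sum()

loss.backward()                 # runs fine...
print(w.grad)                   # ...w gets a gradient
print(a.grad)                   # None: the gradient never reached a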

Hi @albanD,
So if I just run it, then the training will not be accurate, right?
If so, is there any way that I can attach the gradient of the original tensor to the new tensor converted from the numpy array?
Thank you!

So if I just run it, then the training will not be accurate, right?

Well the gradients won’t be correct.

If so, is there any way that I can attach the gradient of the original tensor to the new tensor converted from the numpy array?

No. For the gradients to be properly computed you need to use only pytorch’s functions.
Note that you can add a new elementary function where you define both the forward and the backward if you don’t know how to write the forward using pytorch’s functions (see the tutorial here). But that will be more complex, as you need to implement the backward yourself.
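
For reference, a minimal sketch of such a custom function; the forward here is only a placeholder (y = 2 * x routed through numpy) so that the hand-written backward has a simple, known form:

import torch

class MyNumpyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        # Whatever happens here is invisible to autograd, so the backward
        # below has to describe the derivative explicitly.
        out_np = inp.detach().cpu().numpy() * 2  # placeholder for numpy/numba work
        return torch.from_numpy(out_np).to(inp.device)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of the loss w.r.t. `inp`; for y = 2 * x it is 2 * grad_output.
        return 2 * grad_output

x = torch.rand(4, requires_grad=True)
MyNumpyOp.apply(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])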

Hi @albanD,
I’m trying to keep the gradient of the tensor in the way shown below, but the outcome is not as expected. Please kindly check whether there is any problem that can be solved. Thank you!!!

import argparse
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numba import jit

class S2ConvNet_original(nn.Module):

    def __init__(self):
        super(S2ConvNet_original, self).__init__()

        f1 = 21
        f2 = 40

        self.conv1 = nn.Conv2d(1, f1, 3, stride=3, padding=0)
        self.upsample1 = nn.UpsamplingNearest2d(scale_factor=3)
        self.conv2 = nn.Conv2d(f1, f2, 3, stride=3, padding=0)
        self.upsample2 = nn.UpsamplingNearest2d(scale_factor=3)
        self.fc1 = nn.Linear(in_features=40*60*60, out_features=10)


    def forward(self, x, coor):
        x1 = self.upsample1(x)
        x = preprocess60x60(x, x1, coor)
        x = self.conv1(x)
        x = F.relu(x)
        x2 = self.upsample2(x)
        x = preprocess60x60_21chan(x, x2, coor)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.fc1(x.reshape(x.shape[0], -1))

        return x
"""
This part is for the processing functions
"""
@jit(nopython=False)
def change_patches(patches, image, coor):
    
    a = 0
    for i in range(60): #2160
        for j in range(60): #3840
            coordinate = [i,j]
            indices = np.matrix(coor[(coordinate[0],coordinate[1])])
            patches[a,0,2,1] = image[0,0,indices[1,0],indices[0,0]]
            patches[a,0,2,2] = image[0,0,indices[1,1],indices[0,1]]
            patches[a,0,1,2] = image[0,0,indices[1,2],indices[0,2]]
            patches[a,0,0,2] = image[0,0,indices[1,3],indices[0,3]]
            patches[a,0,0,1] = image[0,0,indices[1,4],indices[0,4]]
            patches[a,0,0,0] = image[0,0,indices[1,5],indices[0,5]]
            patches[a,0,1,0] = image[0,0,indices[1,6],indices[0,6]]
            patches[a,0,2,0] = image[0,0,indices[1,7],indices[0,7]]
            a += 1
    return(patches)

def preprocess60x60(images,newbatch, coor):
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    for batch in range(images.shape[0]):
        image = torch.zeros(1,1,60,60)
        image[0,:,:,:] = images[batch,:,:,:]
        
        kc, kh, kw = 1, 3, 3  # kernel size
        dc, dh, dw = 1, 1, 1  # stride
    
        pad = (1,1,1,1)
        paddedimage = F.pad(image, pad).to(DEVICE)        
    
        patches = paddedimage.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw).to(DEVICE)
        unfold_shape = patches.size()
        
        patches = patches.contiguous().view(-1, kc, kh, kw).to(DEVICE)
        numpypatches = patches.cpu().numpy()
        numpyimage = image.cpu().numpy()
        patches = change_patches(numpypatches, numpyimage, coor)
    
        # Reshape back
        patches_orig = torch.from_numpy(patches).view(unfold_shape)
        output_c = unfold_shape[1] * unfold_shape[4]
        output_h = unfold_shape[2] * unfold_shape[5]
        output_w = unfold_shape[3] * unfold_shape[6]
        patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
        patches_orig = patches_orig.view(1, output_c, output_h, output_w)
        newbatch[batch,:,:,:] = patches_orig
    return newbatch

def preprocess60x60_21chan(images, newbatch, coor):
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    newimage = torch.zeros(1,images.shape[1],180,180).to(DEVICE)

    for batch in range(images.shape[0]):
        image = torch.zeros(1,images.shape[1],60,60)
        image[0,:,:,:] = images[batch,:,:,:]
        for channel in range(21):
            
            imagei = torch.zeros(1,1,60,60)
            imagei[0,0,:,:] = image[0,channel,:,:]

            kc, kh, kw = 1, 3, 3  # kernel size
            dc, dh, dw = 1, 1, 1  # stride
        
            pad = (1,1,1,1)
            paddedimage = F.pad(imagei, pad).to(DEVICE)        
        
            patches = paddedimage.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw).to(DEVICE)
            unfold_shape = patches.size()
            
            patches = patches.contiguous().view(-1, kc, kh, kw)
            
            numpypatches = patches.detach().cpu().numpy()
            numpyimage = imagei.detach().cpu().numpy()
            patches = change_patches(numpypatches, numpyimage, coor)
    
            # Reshape back
            patches_orig = torch.from_numpy(patches).view(unfold_shape)
            output_c = unfold_shape[1] * unfold_shape[4]
            output_h = unfold_shape[2] * unfold_shape[5]
            output_w = unfold_shape[3] * unfold_shape[6]
            patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
            patches_orig = patches_orig.view(1, output_c, output_h, output_w)
            newimage[0,channel,:,:] = patches_orig
        
        newbatch[batch,:,:,:] = newimage
    return newbatch

"""
The main part for common training and testing
"""
def main(network):
    train_loader, test_loader, train_dataset, _ = load_data( MNIST_PATH, BATCH_SIZE)
    with open('dict60x60.pkl', 'rb') as fp:
        coor = pickle.load(fp)
    if network == 'original':
        classifier = S2ConvNet_original()
    elif network == 'deep':
        classifier = S2ConvNet_deep()
    else:
        raise ValueError('Unknown network architecture')
    classifier.to(DEVICE)

    print("#params", sum(x.numel() for x in classifier.parameters()))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(DEVICE)

    optimizer = torch.optim.Adam(
        classifier.parameters(),
        lr=LEARNING_RATE)


    for epoch in range(NUM_EPOCHS):
        for i, (images, labels) in enumerate(train_loader):
            classifier.train()

            images = images.to(DEVICE)
            #print(images.size())
            
            labels = labels.to(DEVICE)
        
            optimizer.zero_grad()
            outputs = classifier(images, coor)
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()

            print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format(
                epoch+1, NUM_EPOCHS, i+1, len(train_dataset)//BATCH_SIZE,
                loss.item()), end="")
        print("")
        correct = 0
        total = 0
        for images, labels in test_loader:

            classifier.eval()

            with torch.no_grad():
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)

                outputs = classifier(images, coor)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).long().sum().item()

        print('Test Accuracy: {0}'.format(100 * correct / total))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--network",
                        help="network architecture to use",
                        default='original',
                        choices=['original', 'deep'])
    args = parser.parse_args()

    main(args.network)

You do convert a lot of things to numpy inside your forward function (in the preprocessXXX functions in particular). These conversions prevent the gradients from flowing back.

So even if I have tensors x1 and x2 with gradients and replace their values with the new values from preprocessXXX, it is still impossible for the gradients to flow back?

Do you have any suggestion for this case? I had to turn the tensor into numpy and put the loop into numba to save time. If there is any way to speed up the nested loop but preserve the gradient, please let me know! I’m kinda desperate right now :cry: :cry: :cry:

If you hide from pytorch how the values are computed, it cannot compute the gradient.

I’m not sure what your change_patches function is doing but it looks like you should be able to implement it with regular indexing functions.
Could you describe in detail, with a small example, what you want from that function?
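
As a small illustration with toy shapes (not your actual coor layout), gathering pixels with index tensors instead of a Python loop keeps everything inside autograd:

import torch

image = torch.rand(1, 1, 6, 6, requires_grad=True)
rows = torch.tensor([0, 1, 2, 3])
cols = torch.tensor([5, 4, 3, 2])

picked = image[0, 0, rows, cols]  # advanced indexing, still tracked by autograd
picked.sum().backward()

print(picked.grad_fn)             # not None -> part of the graph
print(image.grad[0, 0, 0, 5])     # tensor(1.) -> the gradient reached that pixel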

So the input is a tensor of size (1,1,60,60). I pad the image by 1 on every side, divide it into 3x3 patches using unfold() with stride 1, process them, and then stitch them back together into a new image using view(). The new image will then be 3 times bigger than the input image.

For the processing, I have a dictionary containing a set of coordinates for each position of the center pixel of a 3x3 patch. Then, for each patch, I replace its 8 outer pixels with the pixels of the input image at the coordinates stored in the dictionary.

import torch
import numpy as np
from torch.nn import functional as F

image = torch.randn(1, 1, 60, 60)

coor = {(0,0) : np.matrix([[59, 7, 14, 22, 29, 36, 44, 51], [1, 1, 1, 1, 1, 1, 1, 1]])
        ...
       }
kc, kh, kw = 1, 3, 3  # kernel size
dc, dh, dw = 1, 1, 1  # stride
pad = (1,1,1,1)
paddedimage = F.pad(image, pad)
patches = paddedimage.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
unfold_shape = patches.size()
print(unfold_shape)
patches = patches.contiguous().view(-1, kc, kh, kw)
print(patches.shape)

a = 0
for i in range(60):
    for j in range(60):
        coordinate = [i,j]
        indices = np.matrix(coor[(coordinate[0],coordinate[1])])
        patches[a,0,2,1] = image[0,0,indices[1,0],indices[0,0]]
        patches[a,0,2,2] = image[0,0,indices[1,1],indices[0,1]]
        patches[a,0,1,2] = image[0,0,indices[1,2],indices[0,2]]
        patches[a,0,0,2] = image[0,0,indices[1,3],indices[0,3]]
        patches[a,0,0,1] = image[0,0,indices[1,4],indices[0,4]]
        patches[a,0,0,0] = image[0,0,indices[1,5],indices[0,5]]
        patches[a,0,1,0] = image[0,0,indices[1,6],indices[0,6]]
        patches[a,0,2,0] = image[0,0,indices[1,7],indices[0,7]]
        a += 1
# Reshape back
patches_orig = patches.view(unfold_shape)
output_c = unfold_shape[1] * unfold_shape[4]
output_h = unfold_shape[2] * unfold_shape[5]
output_w = unfold_shape[3] * unfold_shape[6]
patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
patches_orig = patches_orig.view(1, output_c, output_h, output_w)

After this process, the output can become the input for the next convolutional layer…
What I want is to apply this process after every conv2d() layer, which means the 1-channel example will extend to inputs with many channels.
Is there any way to optimize this process? Thank you, Mr. @albanD!!!

Hi,

I’m confused about what a is doing here.
The unfold shape is [1, 1, 60, 60, 1, 3, 3], right?
So if you increment a every time, it is going to go out of bounds very quickly, no?

I’m very sorry for the messed-up code, I made a mistake :sweat_smile:
I edited the code. Please check it again!
I would be very grateful if you could help me with this!

Right.
That should be significantly faster already.

Using fancier functions and playing with views, you might be able to narrow it down to a single call. But that shouldn’t be necessary :slight_smile:

import torch
from torch.nn import functional as F
import numpy as np

image = torch.randn(1, 1, 60, 60)
coor = torch.LongTensor(60, 60, 2, 8)
coor.fill_(0) # Just to get valid indices

kc, kh, kw = 1, 3, 3  # kernel size
dc, dh, dw = 1, 1, 1  # stride
pad = (1,1,1,1)

def original(image):

    paddedimage = F.pad(image, pad)
    patches = paddedimage.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
    unfold_shape = patches.size()
    patches = patches.contiguous().view(-1, kc, kh, kw)

    a = 0
    for i in range(60):
        for j in range(60):
            coordinate = [i,j]
            indices = np.matrix(coor[(coordinate[0],coordinate[1])])
            patches[a,0,2,1] = image[0,0,indices[1,0],indices[0,0]]
            patches[a,0,2,2] = image[0,0,indices[1,1],indices[0,1]]
            patches[a,0,1,2] = image[0,0,indices[1,2],indices[0,2]]
            patches[a,0,0,2] = image[0,0,indices[1,3],indices[0,3]]
            patches[a,0,0,1] = image[0,0,indices[1,4],indices[0,4]]
            patches[a,0,0,0] = image[0,0,indices[1,5],indices[0,5]]
            patches[a,0,1,0] = image[0,0,indices[1,6],indices[0,6]]
            patches[a,0,2,0] = image[0,0,indices[1,7],indices[0,7]]
            a += 1
    # Reshape back
    patches_orig = patches.view(unfold_shape)
    output_c = unfold_shape[1] * unfold_shape[4]
    output_h = unfold_shape[2] * unfold_shape[5]
    output_w = unfold_shape[3] * unfold_shape[6]
    patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    patches_orig = patches_orig.view(1, output_c, output_h, output_w)

    return patches_orig

def new(image):
    paddedimage = F.pad(image, pad)
    patches = paddedimage.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
    unfold_shape = patches.size()
    patches = patches.contiguous()

    coor_in = coor.unsqueeze(0).unsqueeze(0)
    patches[0, 0, :, :,0,2,1] = image[0,0,coor_in[0, 0, :, :, 1,0],coor_in[0, 0, :,:,0,0]]
    patches[0, 0, :, :,0,2,2] = image[0,0,coor_in[0, 0, :, :, 1,1],coor_in[0, 0, :,:,0,1]]
    patches[0, 0, :, :,0,1,2] = image[0,0,coor_in[0, 0, :, :, 1,2],coor_in[0, 0, :,:,0,2]]
    patches[0, 0, :, :,0,0,2] = image[0,0,coor_in[0, 0, :, :, 1,3],coor_in[0, 0, :,:,0,3]]
    patches[0, 0, :, :,0,0,1] = image[0,0,coor_in[0, 0, :, :, 1,4],coor_in[0, 0, :,:,0,4]]
    patches[0, 0, :, :,0,0,0] = image[0,0,coor_in[0, 0, :, :, 1,5],coor_in[0, 0, :,:,0,5]]
    patches[0, 0, :, :,0,1,0] = image[0,0,coor_in[0, 0, :, :, 1,6],coor_in[0, 0, :,:,0,6]]
    patches[0, 0, :, :,0,2,0] = image[0,0,coor_in[0, 0, :, :, 1,7],coor_in[0, 0, :,:,0,7]]

    # Reshape back
    patches_orig = patches.view(unfold_shape)
    output_c = unfold_shape[1] * unfold_shape[4]
    output_h = unfold_shape[2] * unfold_shape[5]
    output_w = unfold_shape[3] * unfold_shape[6]
    patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    patches_orig = patches_orig.view(1, output_c, output_h, output_w)

    return patches_orig


print((original(image.clone()) - new(image.clone())).abs().max())

Thank you so much, Mr. @albanD!
The code works well and runs significantly faster!
I have applied it to my training code, but the accuracy of the model still seems as low as before. I think the reason may be that replacing pixels one by one messes up the gradient of the input.
Suppose I apply this function to the output of a conv2d() layer: is there a way for the output to have a proper gradient that can be backpropagated all the way to the beginning?

Hi,

No, the gradients are properly computed.
You can check this by running:


from torch.autograd import gradcheck
gradcheck(lambda x: new(x).sum(), image.clone().detach().double().requires_grad_())

It checks that the autograd gradients match the ones computed with finite differences (it returns True when they do and raises an error describing the mismatch otherwise).


Dear Mr. @albanD,
The solution that you showed me truly saved me from wasting a lot of time on my capstone project. I can’t express how grateful I am for your help and support!
I wish you all the best for your career and hope that, if I encounter any problem in the future, I can come back and receive valuable information and suggestions from the pytorch community!


I need to use detach() during the conversion from tensor to numpy for a kernel density estimation function, but then it won’t backpropagate the gradient. Is there any workaround?