Hi @ptrblck
Please set aside the `Out.backward()` command for a moment and let me explain my actual need.
Let me try again in two parts.
Assume I am using a batch of images from the MNIST test loader with a batch size of 64.
I feed this batch image tensor (size 64x1x28x28) to my CNN model and get an output of size 64x10.
I then apply some operation to this output tensor and get a final output tensor of shape 64x1.
That means for every input image tensor of shape 1x1x28x28 I get a scalar output.
Is it clear up to this point? If yes, I'll continue with the second part.
Now, what I need to implement is a backward derivative operation on my final output tensor. Specifically, I want the derivative of the first output scalar in the final output tensor with respect to the first input image in the input batch tensor, and likewise for all 64 output scalars and 64 input images.
The piece of code I provide below is just a demonstration, so please don't get stuck on it. The backward operation there should be applied to every element of the final output tensor one by one, because, for example, the first output is calculated based on the first input image and x.
All the final derivatives should be stored in a tensor of shape 64x1x28x28.
(It is composed of something like: d Out_1/d x_1, d Out_2/d x_2, …, d Out_64/d x_64)
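To make the structure concrete, here is a toy stand-in (not my actual model): each scalar out_i depends only on x_i, so the batch Jacobian is block-diagonal, and I believe summing the outputs before calling backward() once would fill x.grad with exactly the per-sample derivatives I described:

```python
import torch

# Batch of 4 small "images" standing in for the 64x1x28x28 MNIST batch.
x = torch.randn(4, 1, 3, 3, requires_grad=True)

# Stand-in for model + some_func: one scalar per sample, shape 4x1.
out = (x ** 2).sum(dim=(1, 2, 3)).unsqueeze(1)

# Because out_i depends only on x_i, one backward() on the sum gives
# x.grad[i] = d out_i / d x_i for every sample in the batch.
out.sum().backward()

# For this toy function the analytic derivative is 2 * x_i:
print(torch.allclose(x.grad, 2 * x))  # True
```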
If the batch size were 1 I could do it without a problem, but when the batch size is 64, for example, I can't. Can you please help me implement this in PyTorch?
Is my problem clear for you now?
Maybe I should use 64 different `x` tensors of shape 1x28x28, because the backward operation will be based on each one of these tensors… I don't know.
```python
def backward_batch(model, image):
    # image is of shape 64x1x28x28
    x = torch.zeros_like(image, requires_grad=True)  # x is of shape 64x1x28x28
    output = model(image + x)
    output = softmax(output)
    Out = some_func(output)  # final output is of shape 64x1
    Out.backward()  # How to handle here? Out is not a scalar, so this errors
    return x.grad  # should be 64x1x28x28
```
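Here is a sketch of what I think might work, assuming each Out_i really depends only on the i-th image (the tiny model and `some_func` below are just hypothetical stand-ins for my real ones) — but I'm not sure this is correct:

```python
import torch
import torch.nn as nn

def backward_batch(model, image, some_func):
    # image: 64x1x28x28; x is the zero perturbation we differentiate w.r.t.
    x = torch.zeros_like(image, requires_grad=True)
    output = torch.softmax(model(image + x), dim=1)  # 64x10
    Out = some_func(output)                          # 64x1
    # Since Out[i] depends only on image[i] (and x[i]), summing and calling
    # backward() once gives x.grad[i] = d Out[i] / d x[i] for every sample.
    Out.sum().backward()
    return x.grad                                    # 64x1x28x28

# Hypothetical stand-ins for demonstration only:
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
some_func = lambda p: p.max(dim=1, keepdim=True).values  # 64x10 -> 64x1

grads = backward_batch(model, torch.randn(64, 1, 28, 28), some_func)
print(grads.shape)  # torch.Size([64, 1, 28, 28])
```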