Backpropagate across corresponding locations

I am trying to process an image in non-overlapping tiles (partitions), and then compute the loss on the entire image. What this entails is a forward pass of a batch of partitions, the recomposition of the image (i.e. put them together like a jigsaw puzzle), and then computing the loss. Now when I backpropagate, I imagine I need to “wire” the pixels from the partitions to their corresponding location in the full image. How can I do that? The operation is essentially just a copy plus pixel coordinate shifts. Does the dynamic graph handle this automatically?

This seems resolvable in a similar way to the distributed data parallel module.

If you only use PyTorch functions to construct your forward pass, `loss.backward()` should do the backprop for you automatically.
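For instance, recomposing an image from tiles with `torch.cat` is itself recorded by autograd, so each pixel's gradient flows back to the tile it was copied from. A minimal sketch (tile sizes and the `.sum()` loss are made up for illustration):

```python
import torch

# Two hypothetical "tile" outputs, as if produced by the network.
left = torch.ones(2, 2, requires_grad=True)
right = torch.ones(2, 2, requires_grad=True)

# Recompose the full image by copying the tiles to their locations.
# torch.cat is tracked by autograd, so the copy is differentiable.
full = torch.cat((left, right), dim=1)  # shape (2, 4)

# A toy loss over the whole recomposed image.
loss = full.sum()
loss.backward()

# Each pixel's gradient lands back on the tile it came from.
print(left.grad)   # a (2, 2) tensor of ones
```

Slice assignment into a preallocated output tensor (e.g. `full[:, 0:2] = left`) is tracked the same way, so either style of "jigsaw" copy works.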

My concern is the copy. Let's say my images are as below, where each letter represents a pixel:

{|abc|, |def|, |ghi|, |jkl|} -> NET -> IMAGE_JIGSAW_CONCAT -->


How do I perform what I am calling IMAGE_JIGSAW_CONCAT to shuffle and concatenate the outputs so that pixel a in the output backpropagates to pixel a in the input, and so on?

Will just using torch.cat do the trick?

I’ve tried a simple example and it seems to work.

import torch
import torch.nn as nn

# create 4 tensors to simulate 4 images of size 2*2
a, b, c, d = torch.arange(16, dtype=torch.float).reshape(-1, 2).chunk(4)

# set them to require grads
for t in (a, b, c, d):
    t.requires_grad_()

# concat the 4 imgs into 1 img of size 4*4
e = torch.cat((torch.cat((a, b), 1), torch.cat((c, d), 1)))

# simulate the forward pass with a linear transformation
f = e.mm(torch.rand(4, 2))

# calculate loss using cross entropy
loss = nn.CrossEntropyLoss()(f, torch.tensor([0, 1, 0, 0]))

# backward the loss
loss.backward()

# check grads of a w.r.t. the loss
print(a.grad)
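For real image tiles, the same idea extends to `torch.nn.functional.unfold`/`fold`, which handle the partition/recompose bookkeeping for non-overlapping tiles. A sketch with made-up sizes, where a dummy elementwise op stands in for the actual network:

```python
import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 8, 8, requires_grad=True)  # N, C, H, W

# Split into non-overlapping 4x4 tiles: (1, 3*4*4, 4 tiles).
tiles = F.unfold(img, kernel_size=4, stride=4)

# Rearrange into a batch of tiles and run them through "the net".
batch = tiles.transpose(1, 2).reshape(-1, 3, 4, 4)  # (4, 3, 4, 4)
out = batch * 2.0  # stand-in for net(batch)

# Put the jigsaw back together; fold inverts unfold when stride == kernel.
out_tiles = out.reshape(1, 4, -1).transpose(1, 2)
full = F.fold(out_tiles, output_size=(8, 8), kernel_size=4, stride=4)

loss = full.sum()
loss.backward()
print(img.grad.shape)  # torch.Size([1, 3, 8, 8])
```

Since the whole pipeline (unfold, reshape, fold) is built from PyTorch ops, `loss.backward()` routes every pixel's gradient back to its original location in `img`.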