Backpropagate across corresponding locations

I am trying to process an image in non-overlapping tiles (partitions) and then compute the loss on the entire image. This entails a forward pass over a batch of partitions, recomposing the image (i.e. putting the tiles back together like a jigsaw puzzle), and then computing the loss. When I backpropagate, I imagine I need to “wire” the pixels of each partition to their corresponding locations in the full image. How can I do that? The operation is essentially just a copy plus a shift in pixel coordinates. Does the dynamic graph handle this automatically?

This seems solvable in a similar way to how the data-parallel modules (e.g. nn.DataParallel) split a batch and then gather the outputs.

If you construct your forward pass using only PyTorch functions, I think loss.backward() should do the backprop for you automatically, since autograd records every differentiable operation applied to the tensors.
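To make the "copy plus coordinate shift" case concrete, here is a minimal sketch, assuming two made-up 2x2 tiles (`left`, `right`) that stand in for network outputs. Each tile is copied into its shifted location in a larger canvas by slice assignment, and a per-pixel weighted loss shows which gradient lands on which tile.

```python
import torch

# Two hypothetical 2x2 tiles standing in for network outputs.
left = torch.ones(2, 2, requires_grad=True)
right = torch.ones(2, 2, requires_grad=True)

# Copy each tile into its shifted location in a 2x4 canvas.
# Autograd records these slice assignments like any other operation.
canvas = torch.zeros(2, 4)
canvas[:, 0:2] = left
canvas[:, 2:4] = right

# A loss that weights every pixel of the full image differently,
# so each pixel's gradient is identifiable.
weights = torch.arange(8, dtype=torch.float32).reshape(2, 4)
loss = (canvas * weights).sum()
loss.backward()

print(left.grad)   # equals weights[:, 0:2]
print(right.grad)  # equals weights[:, 2:4]
```

Each tile receives exactly the weights of the canvas region it was copied into, so no manual wiring is needed for this kind of copy.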

My concern is the copy. Let’s say my images are as below, where each letter represents a pixel:

{|abc|, |def|, |ghi|, |jkl|} -> NET -> IMAGE_JIGSAW_CONCAT -->

|abc|def|
|ghi|jkl|

How do I perform what I am calling IMAGE_JIGSAW_CONCAT to shuffle and concatenate the outputs so that pixel a in the output backpropagates to pixel a in the input, and so on?

Will just using torch.cat do the trick?
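One way to check this on the layout from the diagram is sketched below. The helper name `jigsaw_concat` and its `cols` parameter are made up for illustration, not a PyTorch API. It lays four 1x3 tiles out on a 2x2 grid with `torch.cat`, then backpropagates from a single pixel of the assembled image to see where the gradient ends up.

```python
import torch

# Four 1x3 tiles, matching |abc|, |def|, |ghi|, |jkl| in the diagram.
tiles = [torch.zeros(1, 3, requires_grad=True) for _ in range(4)]

def jigsaw_concat(tiles, cols):
    """Hypothetical helper: lay tiles out row-major on a grid with `cols` columns."""
    rows = [torch.cat(tiles[i:i + cols], dim=1) for i in range(0, len(tiles), cols)]
    return torch.cat(rows, dim=0)

image = jigsaw_concat(tiles, cols=2)  # shape (2, 6): |abc|def| over |ghi|jkl|

# Backpropagate from one pixel of the full image: row 1, column 4,
# which is pixel "k" (index 1 of the fourth tile, |jkl|).
image[1, 4].backward()

for name, t in zip(["abc", "def", "ghi", "jkl"], tiles):
    print(name, t.grad)
```

Only the middle entry of the |jkl| tile gets a nonzero gradient, which suggests that `torch.cat` alone already routes each output pixel back to its source pixel.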

I’ve tried a simple example and it seems to me it should work.

import torch
import torch.nn as nn

# create 4 float tensors to simulate 4 images of size 2x2
# (only floating-point tensors can require gradients)
a, b, c, d = torch.arange(16, dtype=torch.float32).reshape(-1, 2).chunk(4)

# set them to require grads
a.requires_grad = True
b.requires_grad = True
c.requires_grad = True
d.requires_grad = True

# concat the 4 images into 1 image of size 4x4
e = torch.cat((torch.cat((a, b), 1), torch.cat((c, d), 1)))

# simulate the forward pass with a linear transformation
f = e.mm(torch.arange(8, dtype=torch.float32).reshape(4, 2))

# calculate the loss using cross entropy
loss = nn.CrossEntropyLoss()(f, torch.tensor([0, 1, 0, 0]))

# backpropagate the loss
loss.backward()

# check the gradient of the loss w.r.t. a
print(a.grad)
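For a whole batch of tiles coming out of a network, the same recomposition can be done without nested `torch.cat` calls. The sketch below assumes the tiles arrive as one `(N, C, h, w)` tensor in row-major grid order; the helper `tiles_to_image` is a made-up name for illustration. It compares the result against the `torch.cat` version and checks the gradient of one tile.

```python
import torch

def tiles_to_image(tiles, grid_rows, grid_cols):
    """Hypothetical helper: (N, C, h, w) row-major tiles -> (C, grid_rows*h, grid_cols*w)."""
    n, c, h, w = tiles.shape
    assert n == grid_rows * grid_cols
    x = tiles.reshape(grid_rows, grid_cols, c, h, w)
    x = x.permute(2, 0, 3, 1, 4)  # (C, grid_rows, h, grid_cols, w)
    return x.reshape(c, grid_rows * h, grid_cols * w)

# Four random single-channel 2x2 tiles standing in for network outputs.
tiles = torch.randn(4, 1, 2, 2, requires_grad=True)
image = tiles_to_image(tiles, 2, 2)  # shape (1, 4, 4)

# Reference layout built with torch.cat, as in the example above.
top = torch.cat((tiles[0], tiles[1]), dim=2)
bottom = torch.cat((tiles[2], tiles[3]), dim=2)
reference = torch.cat((top, bottom), dim=1)
print(torch.equal(image, reference))

# Weight every full-image pixel differently and backpropagate.
weights = torch.arange(16, dtype=torch.float32).reshape(1, 4, 4)
(image * weights).sum().backward()

# The top-right tile should receive the weights of the top-right quadrant.
print(tiles.grad[1, 0])
print(weights[0, 0:2, 2:4])
```

Both `reshape` and `permute` are tracked by autograd, so the gradient of each full-image pixel is routed back to the matching pixel of the matching tile.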