Taking weighted average of tensors

I have a network that spits out 5 tensors of equal dimensions. Each tensor represents a segmented output of the same image.

I also want to train part of the network to compute a weighted average of these tensors. How would I do that? Would I use an nn.Conv2d after concatenating them?

The idea I have in my head is something like:
Output of network: t1, t2, t3, t4, t5
Then (x1 * t1 + x2 * t2 + x3 * t3 + x4 * t4 + x5 * t5) / (x1 + x2 + x3 + x4 + x5), where xi are the weights, and ti are the tensor outputs from the network.

Any advice appreciated.

What about something as simple as this:

import torch

tensor_shape = (3, 224, 224)  # shape of each tensor
five_tensors = torch.randn(5, *tensor_shape, requires_grad=True)
weights = torch.rand(5, requires_grad=True)
weighted_avg = (weights.view(5, 1, 1, 1) * five_tensors).sum(dim=0) / weights.sum()

# now to check that it does what we want, by looking at a random element
print((five_tensors[:, 0, 100, 100] * weights).sum() / weights.sum())
print(weighted_avg[0, 100, 100])

Output:
tensor(-0.5245, grad_fn=<DivBackward0>)
tensor(-0.5245, grad_fn=<SelectBackward>)
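If the weights are meant to be trained together with the rest of the network, one option is to wrap them in a small module with an nn.Parameter. The sketch below is a hypothetical WeightedAverage helper, not something from the post above, and it assumes each network output has shape (B, C, H, W). It swaps in a softmax over learnable logits in place of the explicit division: softmax(logits)_i = exp(logit_i) / sum_j exp(logit_j), which is exactly the formula from the question with x_i = exp(logit_i). The weights therefore stay positive and the denominator can never reach zero during training.

```python
import torch
import torch.nn as nn


class WeightedAverage(nn.Module):
    """Learnable weighted average over K same-shaped tensors.

    Hypothetical helper for illustration. Softmax keeps the effective
    weights positive and summing to 1, so no explicit division is needed.
    """

    def __init__(self, num_inputs: int = 5):
        super().__init__()
        # One learnable logit per input tensor; zeros give a uniform average at init.
        self.logits = nn.Parameter(torch.zeros(num_inputs))

    def forward(self, tensors):
        # tensors: list of K tensors, each of shape (B, C, H, W)
        stacked = torch.stack(tensors, dim=0)  # (K, B, C, H, W)
        w = torch.softmax(self.logits, dim=0)  # (K,), positive, sums to 1
        # Reshape to (K, 1, 1, 1, 1) so the weights broadcast over the other dims.
        w = w.view(-1, *([1] * (stacked.dim() - 1)))
        return (w * stacked).sum(dim=0)


torch.manual_seed(0)
outputs = [torch.randn(2, 3, 8, 8) for _ in range(5)]  # illustrative shapes
avg = WeightedAverage(num_inputs=5)
result = avg(outputs)
print(result.shape)  # torch.Size([2, 3, 8, 8])

# With uniform initial weights the result equals the plain mean.
print(torch.allclose(result, torch.stack(outputs).mean(dim=0), atol=1e-6))  # True

# Gradients flow back into the learnable logits.
result.sum().backward()
print(avg.logits.grad.shape)  # torch.Size([5])
```

Because the module registers its logits as a parameter, they are picked up by model.parameters() and updated by the optimizer like any other weight.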
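On the nn.Conv2d question: for single-channel outputs, concatenating along the channel dimension and applying a bias-free 1x1 convolution computes sum_i w_i * t_i at every pixel, i.e. the numerator of the weighted average. The sketch below assumes five outputs of shape (B, 1, H, W) (shapes chosen only for illustration) and fixes the conv weights to normalized values to check the equivalence. Note that nothing keeps those weights positive or normalized once training updates them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Five single-channel outputs of shape (B, 1, H, W); shapes are illustrative.
ts = [torch.randn(2, 1, 8, 8) for _ in range(5)]
x = torch.cat(ts, dim=1)  # (B, 5, H, W)

# 1x1 conv: each output pixel is a weighted sum of the 5 input channels.
conv = nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, bias=False)

raw = torch.rand(5)
weights = raw / raw.sum()  # normalized weights, for this comparison only
with torch.no_grad():
    # Conv2d weight has shape (out_channels, in_channels, kH, kW) = (1, 5, 1, 1).
    conv.weight.copy_(weights.view(1, 5, 1, 1))

conv_out = conv(x)
manual = sum(w * t for w, t in zip(weights, ts))
print(torch.allclose(conv_out, manual, atol=1e-5))  # True
```

With multi-channel outputs such as the (3, 224, 224) tensors above, concatenation gives 15 channels, and an nn.Conv2d(15, 3, kernel_size=1) would learn an arbitrary mixing across all channels. That is strictly more general than a weighted average with one scalar per tensor.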