Calculate loss per sample in a batch + backward pass

Hi,

I am working on my first somewhat more complex deep-learning-based image segmentation task, where the loss function needs to be calculated per sample in a batch. The per-sample loss values are then summed to get a loss per batch. The loss function is not of a form where the batch input tensor [batch, channel, height, width] can be reduced to a loss output tensor of size [batch] by “simple” tensor operations, hence this single-sample calculation approach.
On GitHub I found a TensorFlow implementation of the loss function, which I have already converted to PyTorch. The lines I don’t know how to convert are:

loss = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)  # dynamically growing array of per-sample losses
loss, _, _, _, _ = tf.while_loop(cond, body, [loss, embedding, label_map, neighbor, 0])  # calls "body" once per sample

loss = loss.stack()            # combine per-sample losses into one tensor
loss = tf.reduce_mean(loss)    # average over the batch

It is actually a while loop over the samples in the batch, calling the loss function “body” for each sample. I don’t know how to do this in PyTorch with regard to the backward pass and gradient calculation.

Thanks a lot in advance.

I’m not deeply familiar with TF, but based on loss.stack() I assume the corresponding PyTorch code snippet might be:

losses = []
for i in range(batch_size):           # iterate over the samples in the batch
    loss = calculate_loss(...)        # per-sample loss; autograd tracks these ops
    losses.append(loss)
loss = torch.stack(losses).mean()     # equivalent of loss.stack() + tf.reduce_mean
loss.backward()

PyTorch builds the autograd graph dynamically, so appending losses in a plain Python loop and stacking them is fine; gradients will flow back through every iteration.
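As a quick sanity check that gradients really flow through such a Python loop, here is a minimal self-contained sketch. `per_sample_loss` is a hypothetical stand-in for the actual loss body, and the tensor shapes are made up for illustration:

```python
import torch

def per_sample_loss(emb, lbl):
    # stand-in for the real per-sample loss computation
    return ((emb - lbl) ** 2).mean()

# dummy data: [batch, channel, height, width]
embedding = torch.randn(4, 3, 8, 8, requires_grad=True)
label_map = torch.randn(4, 3, 8, 8)

losses = []
for b in range(embedding.shape[0]):   # loop over the batch dimension
    losses.append(per_sample_loss(embedding[b], label_map[b]))

loss = torch.stack(losses).mean()     # scalar loss over the batch
loss.backward()

# the input now has a gradient of the same shape, i.e. the loop is differentiable
print(embedding.grad.shape)           # torch.Size([4, 3, 8, 8])
```

If `embedding.grad` is populated after `backward()`, the per-sample loop is part of the autograd graph as intended.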

So easy (-: Thank you. I will try that out.