How does loss know what to do with batches? (does it broadcast?)

As an example, say I have a model which maps a single x of shape (100,) to a single y of some size:

y = model(x)

The loss is

loss = loss_fn(y, target)

Now, I feed the model a batch which is composed of a few vertically stacked x vectors

x_batch = torch.stack((x_1, x_2, x_3), 0)
# stacking three (100,) vectors gives a variable of shape (3, 100)
y_batch = model(x_batch)

When I calculate the loss,

loss = loss_fn(y_batch, target_batch)

How is it that loss_fn knows to calculate the loss independently for each of the three vectors within the batch, instead of over the higher-dimensional variable as a whole?

How does the loss function know when it’s dealing with a batch of outputs and targets, instead of one output and one target?

Struggling to articulate the question… I hope it makes sense!

I have read http://pytorch.org/docs/0.3.0/_modules/torch/nn/modules/loss.html a bit. Maybe what I'm really asking is: is there a major difference between reduce=True and reduce=False within a loss fn?


Hi,

All the loss functions in PyTorch assume that the first dimension is the batch size.
For a few of them, like MSELoss, this is not even relevant, as they just compare every entry in the output and the target one by one and then add them up.
Remember that the loss formula when you do SGD is loss = sum_i loss(out_i, target_i), where i indexes the different samples.
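
As a minimal sketch of this (not from your post; nn.Linear and the random data are just hypothetical stand-ins for your model and targets), you can check the per-sample independence numerically. Note that with the default settings the combined value is a mean rather than a sum, but the idea is the same:

import torch
import torch.nn as nn

model = nn.Linear(100, 10)      # stand-in model: maps (100,) -> (10,), (3, 100) -> (3, 10)
loss_fn = nn.MSELoss()          # default reduction averages over all elements

x_batch = torch.randn(3, 100)   # three stacked input vectors
target_batch = torch.randn(3, 10)

# Loss computed on the whole batch at once...
batch_loss = loss_fn(model(x_batch), target_batch)

# ...equals the mean of the per-sample losses computed one at a time.
per_sample = torch.stack([
    loss_fn(model(x_batch[i]), target_batch[i]) for i in range(3)
])
print(torch.allclose(batch_loss, per_sample.mean()))  # True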

The reduce argument controls whether that sum is done inside the loss function: if it is, you get a single scalar value as output; if it is not, you get a Tensor containing the unreduced loss entries for your batch.
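
For illustration, here is the same idea with the reduction turned off. In the 0.3.x docs you linked this is the reduce=False (and size_average) argument; in current PyTorch versions the equivalent is reduction='none'. The tensors below are again just made-up stand-ins:

import torch
import torch.nn as nn

y_batch = torch.randn(3, 10)
target_batch = torch.randn(3, 10)

# reduction='none' returns the unreduced losses: one entry per element.
unreduced = nn.MSELoss(reduction='none')(y_batch, target_batch)
print(unreduced.shape)              # torch.Size([3, 10])

# You can then reduce them yourself, e.g. one loss value per sample in the batch.
per_sample = unreduced.mean(dim=1)
print(per_sample.shape)             # torch.Size([3])

# The default reduction just does this combination for you, giving a single scalar.
reduced = nn.MSELoss()(y_batch, target_batch)
print(torch.allclose(reduced, unreduced.mean()))  # True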
