Strange behavior with SGD momentum training

I’m porting a Caffe network to PyTorch. However, when I train the network with exactly the same protocol, the training loss behaves like this:

The loss increases within each epoch and drops when a new epoch starts, which produces a sawtooth-shaped loss curve.

Two problems:

  1. The increase of the loss within each epoch seems to be related to momentum. When I set the momentum to 0 (originally 0.9 in the Caffe protocol), the sawtooth shape goes away. Is there any difference between the Caffe and PyTorch momentum settings? (See the sketch after this list for the two update rules as I understand them.)
  2. Even if we assume the problem is momentum, the loss should not drop at the start of each epoch; if the momentum were too large, the loss would simply keep increasing. Is there some hidden cleanup operation at the start of each epoch?
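For reference, here is my understanding of the two momentum conventions, just a simplified sketch of the update rules as I read the docs, not the actual solver code of either framework:

# Caffe-style momentum: the learning rate is folded into the velocity.
#   v = momentum * v - lr * grad;  w = w + v
# PyTorch-style momentum (torch.optim.SGD): the velocity accumulates raw
# gradients and the learning rate is applied at the parameter update.
#   v = momentum * v + grad;       w = w - lr * v
# With a constant learning rate the two are equivalent (the velocities only
# differ by a factor of -lr); they diverge once the learning rate changes.

def caffe_momentum_step(w, v, grad, lr, momentum):
    v = momentum * v - lr * grad
    return w + v, v

def pytorch_momentum_step(w, v, grad, lr, momentum):
    v = momentum * v + grad
    return w - lr * v, v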

Here’s my code:

net.train()
for epoch in range(1, args.epochs + 1):
    net.train_step(epoch, log_interval=10)

and in class Net:

def train_step(self, epoch, log_interval=100):
    for batch_idx, (data, target) in enumerate(self.train_loader):
        if self.is_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        self.optimizer.zero_grad()
        loss = self.forward(data, y=target)
        loss.backward()
        self.optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\t'.format(epoch, batch_idx * len(data),
                                                             len(self.train_loader.dataset),
                                                             100. * batch_idx / len(self.train_loader))
                + '\t'.join(
                    ['{}: {:.6f}'.format(key, self.loss.loss_value[key].data[0]) for key in
                     self.loss.loss_value.keys()]))

Thanks!

I don’t think this is because of momentum. It is probably because of the way samples are selected from the dataset.

PyTorch selects samples from the dataset without replacement.
This means that at the beginning of a new epoch you may well see a sample that you already saw near the end of the previous epoch, but within a single epoch you will never see a repeated sample.

Caffe probably samples with replacement, which means it is equally likely to see the same sample at any part of the epoch.

To verify this theory, you can write a with-replacement sampler and see if that removes the sawtooth shape from the loss:

import torch
from torch.utils.data.sampler import Sampler


class WithReplacementRandomSampler(Sampler):
    """Samples elements randomly, with replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # draw `len(data_source)` indices uniformly at random from `0` to `len(data_source) - 1`, with replacement
        samples = torch.LongTensor(len(self.data_source))
        samples.random_(0, len(self.data_source))
        return iter(samples)

    def __len__(self):
        return len(self.data_source)

# then change the constructor of train_loader this way
self.train_loader = torch.utils.data.DataLoader(dataset, ..., sampler=WithReplacementRandomSampler(dataset), shuffle=False)
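As an aside, if your PyTorch version already has the replacement option on the built-in RandomSampler, you should be able to get the same effect without a custom class. A sketch (check the docs of your version; the batch size here is just a placeholder):

from torch.utils.data import DataLoader, RandomSampler

# RandomSampler with replacement=True draws indices with replacement,
# which should behave like the custom sampler above.
sampler = RandomSampler(dataset, replacement=True)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler, shuffle=False)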

I have observed the same loss pattern on an automatic speech recognition task using the default PyTorch dataloader (sampling without replacement). Applying instead the with-replacement random sampler suggested by @smth removes the pattern, as would be expected.

I am aware of results in the literature indicating that sampling without replacement may yield faster convergence on some problems (https://arxiv.org/abs/1603.00570, https://arxiv.org/abs/1202.4184v1). However, I feel the question remains: why should we be okay with the loss pattern induced by sampling without replacement?

Drops at the boundary between two epochs could be expected simply because we see samples that have already been fitted again for the first (or second or third or …) time; this would indicate that we are indeed learning to classify those specific samples better. However, the reason for the increasing trend in the loss during a single epoch evades me.

One line of thought is that fitting the examples in the first batch of an epoch worsens the model’s performance on the other examples in the training set, and so on for every batch of the epoch. However, wouldn’t this also be expected to show up as an increasing validation loss? Aren’t we in a sense overfitting to each training example during the epoch?
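One way I could imagine testing this (a rough sketch written against the current PyTorch API, not something I have run; net, criterion, optimizer, train_dataset and train_loader are assumed to exist already): evaluate the loss on a small fixed slice of the training set at several points within an epoch. If fitting the early batches really does hurt the model on the rest of the training data, the loss on that fixed slice should drift upward over the epoch.

import torch
from torch.utils.data import DataLoader, Subset

# Fixed probe subset of the training data, evaluated repeatedly during the epoch.
probe_set = Subset(train_dataset, list(range(256)))
probe_loader = DataLoader(probe_set, batch_size=64)

def probe_loss(net, criterion):
    net.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for data, target in probe_loader:
            total += criterion(net(data), target).item() * len(data)
            n += len(data)
    net.train()
    return total / n

probe_history = []
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    loss = criterion(net(data), target)
    loss.backward()
    optimizer.step()
    if batch_idx % 100 == 0:
        # if within-epoch overfitting is real, these values should trend upward
        probe_history.append(probe_loss(net, criterion))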

Thanks in advance.