Negative loss for BC learning

I am pretty new to PyTorch and I am trying to implement BC learning for images as in this paper https://arxiv.org/pdf/1711.10284.pdf

What I’m basically trying to is to mix 2 images and their one-hot encoded labels with a random ratio and then learn the mixing ratio.

My model looks as follows:

class ConvNet(nn.Module):
    def __init__(self, n_classes):
        super(ConvNet, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('conv1', ConvBNReLu(3, 32, 3, padding=1)),
            ('conv2', ConvBNReLu(32, 32, 3, padding=1)),
            ('max_pool1', nn.MaxPool2d(2, ceil_mode=True)),
            ('conv3', ConvBNReLu(32, 64, 3, padding=1)),
            ('conv4', ConvBNReLu(64, 64, 3, padding=1)),
            ('max_pool2', nn.MaxPool2d(2, ceil_mode=True)),
            ('conv5', ConvBNReLu(64, 128, 3, padding=1)),
            ('conv6', ConvBNReLu(128, 128, 3, padding=1)),
            ('conv7', ConvBNReLu(128, 128, 3, padding=1)),
            ('conv8', ConvBNReLu(128, 128, 3, padding=1)),
            ('max_pool3', nn.MaxPool2d(2, ceil_mode=True)),
            ('flatten', Flatten()),
            ('fc4', nn.Linear(in_features=128 * 4 * 4, out_features=512, bias=True)),
            ('relu5', nn.ReLU()),
            ('dropout5', nn.Dropout()),
            ('fc5', nn.Linear(512, 512)),
            ('relu6', nn.ReLU()),
            ('dropout6', nn.Dropout()),
            ('fc6', nn.Linear(512, n_classes)),
            ('softmax', nn.Softmax(dim=-1)) 
        ]))

    def forward(self, inp):
        return self.model(inp)

Then my learner like this:

criterion = nn.KLDivLoss()
learning_rate = 0.01
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, nesterov=True, weight_decay=5e-4)
scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)
loss_values = []
PATH = './BCplus_epoch10.pth'

for epoch in range(10):  # Number of epochs (loops over dataset)
  epoch_loss = 0.0
  running_loss = 0.0
  input = []
  labels = []
  for i, data in enumerate(trainloader, 0):
      # get the inputs; data is a list of [inputs, labels]
      
      if i % 3 == 2: # batch size of 3 images to ConvNet
        input = torch.stack(input) # stack tensors
        labels = torch.stack(labels)

        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(input)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics every 200 loops
        epoch_loss += loss.item() * images_cifar[0][0].size(0)
        running_loss += loss.item()
        input = [] # clear input
        labels = [] # clear labels

      images_cifar, labels_cifar = data

      images_mix, labels_mix = mix(images_cifar[0], images_cifar[1], labels_cifar[0], labels_cifar[1], training,
                                    True)
      input.append(images_mix)
      labels.append(labels_mix)



      if i % 200 == 199:
          print('[%d, %5d] loss: %.3f' %
                (epoch + 1, i + 1, running_loss / 200))
          running_loss = 0.0
  loss_values.append(epoch_loss / 25000) # 25000 images downloaded
  print(epoch_loss)
  scheduler.step()

And lastly the mixing like this:

def preprocess(image, optplus, train):
    if optplus:
        normalizer = zero_mean
        mean = np.array([4.60, 2.24, -6.84])
        std = np.array([55.9, 53.7, 56.5])
    else:
        normalizer = normalize
        mean = np.array([125.3, 123.0, 113.9])
        std = np.array([63.0, 62.1, 66.7])
    if train:
        image = normalize(image, mean, std)
        image = horizontal_flip(image)
        image = padding(image, 4)
        image = random_crop(image, 32)

    else:
        image = normalize(image, mean, std)

    return image


def mix(image1, image2, label1, label2, optplus, train):
    image1 = tensor_to_numpy(image1)
    image2 = tensor_to_numpy(image2)
    image1 = preprocess(image1, optplus, train)
    image2 = preprocess(image2, optplus, train)
    image1 = torch.from_numpy(image1).float().to(device)
    image2 = torch.from_numpy(image2).float().to(device)
    label1 = label1.to(device)
    label2 = label1.to(device)

    # Mix two images
    r = torch.rand(1).to(device)
    if optplus:
        g1 = torch.std(image1).to(device)
        g2 = torch.std(image2).to(device)
        p = (1.0 / (1 + g1 / g2 * (1 - r) / r)).to(device)
        image = ((image1 * p + image2 * (1 - p)) / torch.sqrt(p ** 2 + (1 - p) ** 2)).to(device)
    else:
        image = (image1 * r + image2 * (1 - r)).to(device)

    # Mix two labels
    eye = torch.eye(nClasses).to(device)
    label = (eye[label1] * r + eye[label2] * (1 - r)).to(device)

    return image, label

I am trying to stack 3 images together and then pass them through the model. However when doing this I get a negative loss using KL divergence. I am using this loss function since it is what they use in the paper. Any idea what I am doing incorrectly here?

Based on the docs your output and target might be wrong:

As with NLLLoss , the input given is expected to contain log-probabilities and is not restricted to a 2D Tensor. The targets are given as probabilities (i.e. without taking the logarithm).

You should probably change the last nn.Softmax activation to nn.LogSoftmax and make sure the target values are a probability distribution.

Thank you for your answer.
I tried setting the Softmax to a LogSoftmax and the KL divergence error is indeed positive now. The target values look something like this

tensor([[0.000 0.000 0.149 0.000 0.000 0.000 0.851 0.000 0.000 0.000] ... (and then 2 more)]])

Which should be correct since they add up to 1, right? After chancing the Softmax to LogSoftmax however, the output values look something like this

tensor([[-2.129 -2.201 -2.1317 ....

Which is not a normal probability distribution anymore. Should anything be changed to the target values as well?

Another problem I had is in the network architecture, in the paper they use

('conv1', ConvBNReLu(3, 64, 3, padding=1)),('conv2', ConvBNReLu(64, 64, 3, padding=1)) ...

which gave me the error:

RuntimeError: Given groups=1, weight of size 32 32 3 3, expected input[3, 64, 32, 32] to have 32 channels, but got 64 channels instead

Which was solved by halving all out_channel values, however the authors of the paper do use a higher number of filters, looking at Table 7 in the paper. What am I doing wrong there?

These values are representing now log probabilities, so it should be fine.

Could you post the definition of ConvBNReLU, please?

I think the criterion and model seems to work now using a standard method now (so without using the mixing class), so thanks a lot for that! However when I do use the mixing function, the network doesn’t seem to learn anything at all.
I’m trying to follow the paper more closely so I took a batch size of 128, and mix them together to 64 images. This does not decrease the loss during training at all however.
The train function looks like this:

criterion = nn.KLDivLoss()
learning_rate = 0.1
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, nesterov=True, weight_decay=5e-4)
scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)
loss_values = []
PATH = '/content/drive/My Drive/CS4240 - Deep Learning/Reproducibility project/Results/BCplus_epoch20.pth'

for epoch in range(20):  # Number of epochs (loops over dataset)
  epoch_loss = 0.0
  running_loss = 0.0
  inputs = []
  labels = []
  for i, data in enumerate(trainloader, 0):      

      images_cifar, labels_cifar = data
      inputs = [] # clear input
      labels = [] # clear labels
      for j in range(0,len(data),2):

        images_mix, labels_mix = mix(images_cifar[j], images_cifar[j+1], labels_cifar[j], labels_cifar[j+1], False,
                                    True)
        inputs.append(images_mix)
        labels.append(labels_mix)

      inputs = torch.stack(inputs) # stack tensors
      labels = torch.stack(labels)

      optimizer.zero_grad()

      # forward + backward + optimize
      outputs = net(inputs)
      loss = criterion(outputs, labels) # calculate loss
      optimizer.step() # update weights

      # print statistics every 200 loops
      running_loss += loss.item()

      if i % 200 == 199:
          print('[%d, %5d] loss: %.3f' %
                (epoch + 1, i + 1, running_loss / 200))
          running_loss = 0.0

  scheduler.step()

With mixing function (again)

def preprocess(image, optplus, train):
    if optplus:
        normalizer = zero_mean
        mean = np.array([4.60, 2.24, -6.84])
        std = np.array([55.9, 53.7, 56.5])
    else:
        normalizer = normalize
        mean = np.array([125.3, 123.0, 113.9])
        std = np.array([63.0, 62.1, 66.7])
    if train:
        image = normalize(image, mean, std)
        image = horizontal_flip(image)
        image = padding(image, 4)
        image = random_crop(image, 32)

    else:
        image = normalize(image, mean, std)

    return image


def mix(image1, image2, label1, label2, optplus, train):
    image1 = tensor_to_numpy(image1)
    image2 = tensor_to_numpy(image2)
    image1 = preprocess(image1, optplus, train)
    image2 = preprocess(image2, optplus, train)
    image1 = torch.from_numpy(image1).float()
    image2 = torch.from_numpy(image2).float()
    label1 = label1.to(device)
    label2 = label2.to(device)
    image1 = image1.to(device)
    image2 = image2.to(device)
    # Mix two images
    r = torch.rand(1).to(device)
    if optplus:
        g1 = torch.std(image1).to(device)
        g2 = torch.std(image2).to(device)
        p = (1.0 / (1 + g1 / g2 * (1 - r) / r)).to(device)
        image = ((image1 * p + image2 * (1 - p)) / torch.sqrt(p ** 2 + (1 - p) ** 2)).to(device)
    else:
        image = (image1 * r + image2 * (1 - r)).to(device)

    # Mix two labels
    eye = torch.eye(nClasses).to(device)
    label = (eye[label1] * r + eye[label2] * (1 - r)).to(device)

    return image, label

When training the network like this it does not reduce the loss at all, however when removing the mixing function from the trainer it does learn something, so something must be going wrong when mixing the images and labels. Is there something I am doing wrong in the mixing?

I’m not familiar with the paper, but I’m not sure how your preprocess method works exactly.

Assuming optplus uses the “mixing” function, the mean seems to differ a lot.
Also in the if train block, normalize is used, while normalizer doesn’t seem to be used at all.

Hmm you could be right that the preprocess method is not working correctly, but even when I remove the preprocessing and set r = 1 (mixing ratio), so that mixing and preprocessing does not occur. When I test images_cifar == images_mix it always return true, but then still my loss remains the same over each epoch. The mixing class now looks like this (with optplus=False):

def mix(image1, image2, label1, label2, optplus, train):
    image1 = image1.numpy()
    image2 = image2.numpy()
    label1 = label1.numpy()
    label2 = label2.numpy()
    # image1 = preprocess(image1, optplus, train)
    # image2 = preprocess(image2, optplus, train)
    # Mix two images
    #r = np.array(random.random())
    r = np.array([1])
    if optplus:
        g1 = np.std(image1)
        g2 = np.std(image2)
        p = 1.0 / (1 + g1 / g2 * (1 - r) / r)
        image = ((image1 * p + image2 * (1 - p)) / np.sqrt(p ** 2 + (1 - p) ** 2)).astype(np.float32)
    else:
        image = (image1 * r + image2 * (1 - r)).astype(np.float32)

        # Mix two labels
    eye = np.eye(nClasses)
    label = (eye[label1] * r + eye[label2] * (1 - r))

    image = torch.from_numpy(image).float().to(device)
    label = torch.from_numpy(label).float().to(device)

    return image, label

However when I do remove the whole mixing function the loss does seem to decrease, while to me it seems that I am still doing the same then (except that I’m skipping every second image)

EDIT: seems like loss.backward() was not in my loop anymore and len(data) should have been len(images_cifar). The learner does seem to learn something now without using the preprocess function. I will be trying to fix that now

My test error is still much higher than in the paper, I get 14% where they get 5%. I think there might be something wrong in my model. I am trying to implement this model:


As such:

class ConvBNReLu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super(ConvBNReLu, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      bias=bias),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, input):
        return self.main(input)

class Flatten(nn.Module):
    def forward(self, x):
        x = x.view(x.size()[0], -1)
        return x

class ConvNet(nn.Module):
    def __init__(self, n_classes):
        super(ConvNet, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('conv1', ConvBNReLu(3, 64, 3, padding=1)),
            ('conv2', ConvBNReLu(64, 64, 3, padding=1)),
            ('max_pool1', nn.MaxPool2d(2, stride=2, ceil_mode=True)),
            ('conv3', ConvBNReLu(64, 128, 3, padding=1)),
            ('conv4', ConvBNReLu(128, 128, 3, padding=1)),
            ('max_pool2', nn.MaxPool2d(2, stride=2, ceil_mode=True)),
            ('conv5', ConvBNReLu(128, 256, 3, padding=1)),
            ('conv6', ConvBNReLu(256, 256, 3, padding=1)),
            ('conv7', ConvBNReLu(256, 256, 3, padding=1)),
            ('conv8', ConvBNReLu(256, 256, 3, padding=1)),
            ('max_pool3', nn.MaxPool2d(2, stride=2, ceil_mode=True)),
            ('flatten', Flatten()),
            ('fc4', nn.Linear(in_features=256 * 4 * 4, out_features=1024, bias=True)),
            ('relu5', nn.ReLU()),
            ('dropout5', nn.Dropout()),
            ('fc5', nn.Linear(1024, 1024)),
            ('relu6', nn.ReLU()),
            ('dropout6', nn.Dropout()),
            ('fc6', nn.Linear(1024, n_classes)),
            ('softmax', nn.LogSoftmax(dim=-1)) # Changed to LogSoftmax from Softmax, negative KL divergence otherwise
        ]))

    def forward(self, inp):
        return self.model(inp)

I am quite new to making model however this works, but not as good as I had hoped. Did I implement the model correctly?

I’m testing the model each epoch like this (no mixing):

criterion = nn.KLDivLoss()
learning_rate = 0.1 # Initial learning rate
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, nesterov=True, weight_decay=5e-4)
scheduler = MultiStepLR(optimizer, milestones=[100,150,200], gamma=0.1) # Decrease learning rate at milestones

test_error_standard = []
log_test_error = True
nEpoch = 250
for epoch in range(nEpoch):  # Number of epochs (loops over dataset)
  epoch_loss = 0.0
  running_loss = 0.0
  for i, data in enumerate(trainloader):
      # get the inputs; data is a list of [inputs, labels]
      
      label = []
      images_cifar, labels_cifar = data

      for j in range(len(labels_cifar)):
        label.append((torch.eye(10)[int(labels_cifar[j])]))
      labels = torch.stack(label).to(device)
      optimizer.zero_grad()
      # forward + backward + optimize
      outputs = net(images_cifar.to(device))
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      # print statistics every 200 loops

      running_loss += loss.item()

  
  print('[%d, %5d] loss: %.3f' %
        (epoch + 1, i, running_loss / i))
  running_loss = 0.0

  scheduler.step()
  if log_test_error: # Test network at each epoch and save test error
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    test_error = 100 - 100 * correct / total
    test_error_standard.append(test_error)
    print('Test error: %.2f %%' % (
        test_error))


print('Finished Training')

torch.save(net.state_dict(), PATH)

EDIT: The error was not having nn.ReLU() at the end of ConvBNReLu(nn.Module), getting similar validation errors now!