Implementation of Caffe code in PyTorch - suboptimal convergence

Dear All,

I’m trying to re-implement this goturn code in PyTorch.
I have tried to implement the PyTorch code just as in the Caffe code, with only minor changes, but I couldn’t figure out why I see poor convergence in PyTorch.

I would really appreciate any help; please share what might be going wrong or what I might be missing.

I apologize for the long description.

To make clear what I have done compared to the Caffe implementation, I will describe it in detail.

Network architecture and weights

I have defined the CaffeNet architecture and its weights just as in network.prototxt:

import numpy as np
import torch
import torch.nn as nn


class CaffeNetArch(nn.Module):

    """Caffe-style AlexNet (CaffeNet) convolutional feature extractor."""

    def __init__(self, num_classes=1000):
        """This defines the Caffe version of AlexNet"""
        super(CaffeNetArch, self).__init__()

        self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
                                      nn.ReLU(inplace=True),
                                      nn.MaxPool2d(kernel_size=3, stride=2),
                                      nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75),
                                      # conv 2
                                      nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),
                                      nn.ReLU(inplace=True),
                                      nn.MaxPool2d(kernel_size=3, stride=2),
                                      nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75),
                                      # conv 3
                                      nn.Conv2d(256, 384, kernel_size=3, padding=1),
                                      nn.ReLU(inplace=True),
                                      # conv 4
                                      nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),
                                      nn.ReLU(inplace=True),
                                      # conv 5
                                      nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),
                                      nn.ReLU(inplace=True),
                                      nn.MaxPool2d(kernel_size=3, stride=2))


def transfer_weights(model, pretrained_model_path, dbg=False):
    weights_bias = np.load(pretrained_model_path, allow_pickle=True,
                           encoding='latin1').item()
    layer_num = 0
    with torch.no_grad():
        for layer in model.modules():
            if type(layer) == torch.nn.modules.conv.Conv2d:
                layer_num = layer_num + 1
                key = 'conv{}'.format(layer_num)
                w, b = weights_bias[key][0], weights_bias[key][1]
                layer.weight.copy_(torch.from_numpy(w).float())
                layer.bias.copy_(torch.from_numpy(b).float())

    if dbg:
        layer_num = 0
        for layer in model.modules():
            if type(layer) == torch.nn.modules.conv.Conv2d:
                layer_num = layer_num + 1
                key = 'conv{}'.format(layer_num)
                w, b = weights_bias[key][0], weights_bias[key][1]
                assert (layer.weight.detach().numpy() == w).all()
                assert (layer.bias.detach().numpy() == b).all()


def CaffeNet(pretrained_model_path=None):
    """Alexenet pretrained model
    @pretrained_model_path: pretrained model path for initialization
    """

    model = CaffeNetArch().features
    if pretrained_model_path:
        transfer_weights(model, pretrained_model_path)

    return model
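
For reference, the feature extractor is then built as in the snippet below. This is just a sketch: the .npy filename is a placeholder for the conv weights/biases dumped from the Caffe model, and passing dbg=True to transfer_weights re-checks that every conv layer matches the dumped arrays.

conv_weights = 'caffenet_conv_weights.npy'  # placeholder path to the dumped Caffe conv weights
conv_features = CaffeNet(pretrained_model_path=conv_weights)
# optional sanity check that the copied weights match the Caffe arrays exactly
transfer_weights(conv_features, conv_weights, dbg=True)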

I use the same conv weights as in Caffe and transfer them, and the FC layers are initialized with exactly the same weights as in Caffe. (I dumped the initialized weights from Caffe and assign them here.)

    def __init__(self, pretrained_model=None,
                 init_fc='/home/nthere/2020/pytorch-goturn/src/scripts/fc_init.npy', num_output=4):
        """ """
        super(GoturnNetwork, self).__init__()

        # self._net = AlexNet(pretrained_model_path=pretrained_model)
        self._net = CaffeNet(pretrained_model_path=pretrained_model)
        self._classifier = nn.Sequential(nn.Linear(256 * 6 * 6 * 2, 4096),
                                         nn.ReLU(inplace=True),
                                         nn.Dropout(0.5),
                                         nn.Linear(4096, 4096),
                                         nn.ReLU(inplace=True),
                                         nn.Dropout(0.5),
                                         nn.Linear(4096, 4096),
                                         nn.ReLU(inplace=True),
                                         nn.Dropout(0.5),
                                         nn.Linear(4096, num_output))

        self._num_output = num_output
        if init_fc:
            self._init_fc = init_fc
            self._caffe_fc_init()
        else:
            self.__init_weights()

    def _caffe_fc_init(self):
        """Initialize the FC layers from the weights dumped from Caffe's
        gaussian initialization."""
        wb = np.load(self._init_fc, allow_pickle=True).item()

        layer_num = 0
        with torch.no_grad():
            for layer in self._classifier.modules():
                if isinstance(layer, nn.Linear):
                    layer_num = layer_num + 1
                    key_w = 'fc{}_w'.format(layer_num)
                    key_b = 'fc{}_b'.format(layer_num)
                    w, b = wb[key_w], wb[key_b]
                    w = np.reshape(w, (w.shape[1], w.shape[0]))
                    b = np.squeeze(np.reshape(b, (b.shape[1],
                                                  b.shape[0])))
                    layer.weight.copy_(torch.from_numpy(w).float())
                    layer.bias.copy_(torch.from_numpy(b).float())

Network initialization is therefore exactly as in Caffe. The learning rate for each layer and the weight decay for the frozen layers are also set the same as in Caffe, which is ensured in the code below:

    def __set_lr(self):
        '''Set per-layer learning rates: conv layers are frozen (lr=0),
        FC weights get 10x and FC biases 20x the base learning rate.'''
        param_dict = []

        conv_layers = self._model._net
        for layer in conv_layers.modules():
            if type(layer) == torch.nn.modules.conv.Conv2d:
                param_dict.append({'params': layer.weight,
                                   'lr': 0,
                                   'weight_decay': self.hparams.wd})
                param_dict.append({'params': layer.bias,
                                   'lr': 0,
                                   'weight_decay': 0})

        regression_layer = self._model._classifier
        for layer in regression_layer.modules():
            if type(layer) == torch.nn.modules.linear.Linear:
                param_dict.append({'params': layer.weight,
                                   'lr': 10 * self.hparams.lr,
                                   'weight_decay': self.hparams.wd})
                param_dict.append({'params': layer.bias,
                                   'lr': 20 * self.hparams.lr,
                                   'weight_decay': 0})

In terms of the loss function, Caffe does not average over the batch size, so I used size_average=False: loss = torch.nn.L1Loss(size_average=False)(pred_bb, gt_bb.float())
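
(Note: size_average is deprecated in newer PyTorch releases; as far as I know the equivalent is reduction='sum'.)

import torch

criterion = torch.nn.L1Loss(reduction='sum')   # same behaviour as size_average=False
# loss = criterion(pred_bb, gt_bb.float())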

I even modified the PyTorch SGD to behave the same as Caffe’s SGD, as mentioned here,

and the optimizers are configured as below: lr = 1e-06, momentum = 0.9, gamma = 0.1, and step_size = 1 (whereas in Caffe the step size is given in iterations and is set to 100000):

            optimizer = CaffeSGD(params,
                                 lr=self.hparams.lr,
                                 momentum=self.hparams.momentum,
                                 weight_decay=self.hparams.wd)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                        step_size=self.hparams.lr_step,
                                                        gamma=self.hparams.gamma)

Debugging

To compare the implementations, I turned off dropout (setting p=0) and verified two iterations of the forward and backward pass, including the weight update. Both Caffe and PyTorch with the above code gave the same loss and the same weight updates.
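
For the weight-update comparison, a sketch of the kind of per-layer check I mean is below (the dump path and the fc{n}_w / fc{n}_b key names are placeholders for the arrays dumped from Caffe):

import numpy as np
import torch.nn as nn


def compare_fc_params(classifier, caffe_dump_path, atol=1e-5):
    """Compare PyTorch FC weights/biases against arrays dumped from Caffe
    after the same number of iterations (key names are placeholders)."""
    caffe_wb = np.load(caffe_dump_path, allow_pickle=True).item()
    layer_num = 0
    for layer in classifier.modules():
        if isinstance(layer, nn.Linear):
            layer_num += 1
            w_ref = caffe_wb['fc{}_w'.format(layer_num)]
            b_ref = caffe_wb['fc{}_b'.format(layer_num)]
            w_cur = layer.weight.detach().cpu().numpy()
            b_cur = layer.bias.detach().cpu().numpy()
            print('fc{}: max |dW| = {:.3e}, max |db| = {:.3e}'.format(
                layer_num,
                np.abs(w_cur - w_ref.reshape(w_cur.shape)).max(),
                np.abs(b_cur - b_ref.reshape(b_cur.shape)).max()))
            assert np.allclose(w_cur, w_ref.reshape(w_cur.shape), atol=atol)
            assert np.allclose(b_cur, b_ref.reshape(b_cur.shape), atol=atol)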

Following are the minor changes:

  • Caffe uses a batch size of 50, whereas, due to the way I have implemented the data loader, I use a batch size of 44. I hope this shouldn’t matter.
  • Random augmentations are done for each frame to move the bounding box, using a uniform random distribution. (I have used the torch random number generator with num_workers=6.) I have plotted the generated random numbers and found them to be uniform and non-repetitive within a single batch.

Following are the differences between PyTorch and Caffe:

Random number generation is different, and the dropout of activations is different. To my understanding, these shouldn’t be the issue.

Do you think I’m missing something here? Let me know if you need more input to analyze this issue.

I look forward to the discussion and help.

This might be a minor issue, but why can’t you use the same batch size in your PyTorch code?

Also, could you explain, how you’ve implemented the transformation?
I assume you’ve transformed the input image first and then you’ve used the same transformation on the bounding box?

Hi @ptrblck, thank you for the reply.

Here is why I use a different batch size:

  • In Caffe, the loader goes through each example pair. For each image, it generates augmentations of the current frame based on the motion model.

In Caffe, each image pair with its augmentations yields 11 images. These are appended to a list until it reaches 50 images; the rest are kept for the next batch.

Augmented images (generated from the PyTorch code; ignore the green boxes):

My intention was to take advantage of num_workers in the PyTorch DataLoader: I feed 4 example pairs and then augment them in collate_fn, so in total I get 4 × 11 = 44 images per iteration.

The PyTorch code is below:

import sys
from pathlib import Path

import numpy as np
import torch
from loguru import logger
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

try:
    from goturn.dataloaders.alov import AlovDataset
    from goturn.dataloaders.imagenet import ImageNetDataset
    from goturn.dataloaders.sampler import sample_generator
    from goturn.helper.image_io import resize
except ImportError:
    logger.error('Please run $source settings.sh from root directory')
    sys.exit(1)


class GoturnDataloader(Dataset):
    """Docstring for goturnDataloader. """

    def __init__(self, imagenet_path, alov_path, mean_file=None, isTrain=True,
                 val_ratio=0.005, width=227, height=227,
                 dbg=False):
        """Dataloader initialization for goturn tracker """

        # ImageNet
        img_dir = Path(imagenet_path).joinpath('images')
        ann_dir = Path(imagenet_path).joinpath('gt')
        self._imagenetD = ImageNetDataset(str(img_dir), str(ann_dir),
                                          isTrain=isTrain,
                                          val_ratio=val_ratio)

        # ALOV
        img_dir = Path(alov_path).joinpath('images')
        ann_dir = Path(alov_path).joinpath('gt')
        self._alovD = AlovDataset(str(img_dir), str(ann_dir),
                                  isTrain=isTrain, val_ratio=val_ratio)

        # sample generator
        self._sample_gen = sample_generator(5, 15, -0.4, 0.4, dbg=dbg)
        self._kGenExPerImage = 10

        self._images = []
        self._targets = []
        self._bboxes = []

        self._width = width
        self._height = height
        if mean_file:
            self._mean = np.load(mean_file)
        else:
            self._mean = np.array([104, 117, 123])

        self._minDataLen = min(len(self._imagenetD), len(self._alovD))
        
        self._maxDataLen = max(len(self._imagenetD), len(self._alovD))

    def __len__(self):
        ''' length of the total dataset, is max of one of the dataset '''
        return self._maxDataLen

    def __getitem__(self, idx):
        """Get the current idx data
        @idx: Current index for the data
        """
        if self._minDataLen == len(self._imagenetD):
            imagenet_pair = self._imagenetD[idx % self._minDataLen]
            alov_pair = self._alovD[idx]
        else:
            imagenet_pair = self._imagenetD[idx]
            alov_pair = self._alovD[idx % self._minDataLen]

        return imagenet_pair, alov_pair

    def collate(self, batch):
        ''' Custom data collation for alov and imagenet
        @batch: batch of data
        '''
        self._images = []
        self._targets = []
        self._bboxes = []

        for batch_i in batch:
            for (img_prev, bbox_prev, img_cur, bbox_cur) in batch_i:
                self._sample_gen.reset(bbox_cur, bbox_prev, img_cur,
                                       img_prev)
                self.__make_training_samples()

        for i, (im, tar, bbox) in enumerate(zip(self._images,
                                                self._targets,
                                                self._bboxes)):

            im = resize(im, (self._width, self._height)) - self._mean
            # im = im / 255.
            self._images[i] = np.transpose(im, axes=(2, 0, 1))
            tar = resize(tar, (self._width, self._height)) - self._mean
            # tar = tar / 255.
            self._targets[i] = np.transpose(tar, axes=(2, 0, 1))
            self._bboxes[i] = np.array([bbox.x1, bbox.y1, bbox.x2,
                                        bbox.y2])

        images = torch.from_numpy(np.stack(self._images))
        targets = torch.from_numpy(np.stack(self._targets))
        bboxes = torch.from_numpy(np.stack(self._bboxes))

        return images, targets, bboxes

    def __make_training_samples(self):
        """
        1. First decide the current search region, which is
        kContextFactor(=2) * current bounding box.
        2. Crop the valid search region and copy to the new padded image
        3. Recenter the actual bounding box of the object to the new
        padded image
        4. Scale the bounding box for regression
        """

        sample_gen = self._sample_gen
        images = self._images
        targets = self._targets
        bboxes = self._bboxes

        image, target, bbox_gt_scaled = sample_gen.make_true_sample()

        images.append(image)
        targets.append(target)
        bboxes.append(bbox_gt_scaled)

        # Generate additional augmented examples
        images, targets, bbox_gt_scaled = sample_gen.make_training_samples(self._kGenExPerImage, images, targets, bboxes)


if __name__ == "__main__":
    imagenet_path = '/media/nthere/datasets/ISLVRC2014_Det/'
    alov_path = '/media/nthere/datasets/ALOV/'
    objGoturn = GoturnDataloader(imagenet_path, alov_path, dbg=False)

    dataloader = DataLoader(objGoturn, batch_size=2, shuffle=True,
                            num_workers=0, collate_fn=objGoturn.collate)

    for i, *data in tqdm(enumerate(dataloader), desc='Loading imagenet/alov', total=len(objGoturn), unit='files'):
        pass

I will be happy to share more information.

Hi @ptrblck, in order to verify the forward and backward passes, I kept the input the same for 100 iterations for both the Caffe and PyTorch implementations.

I see the Caffe loss become quite different from the PyTorch loss after a few iterations (observe the values towards the end):

PyTorch           | Caffe
-------------------------
757.2705688476562   757.271
942.8200073242188   942.82
799.31884765625     799.319
812.5123291015625   812.512
667.9791259765625   667.979
413.325439453125    413.325
562.3018188476562   562.302
350.0400390625      350.04
297.4466552734375   297.447
343.02001953125     343.02
394.5355224609375   394.536
228.69955444335938  228.7
251.35780334472656  251.358
307.5043640136719   307.504
223.70187377929688  223.702
233.28732299804688  233.288
315.7559814453125   315.756
226.81146240234375  226.811
181.53457641601562  181.533
206.6217041015625   206.625
210.15628051757812  210.157
212.7760009765625   212.775
209.96849060058594  209.974
214.49813842773438  214.493
183.6826629638672   183.672
207.22445678710938  207.245
168.16375732421875  168.184
201.4786376953125   201.458
141.69427490234375  141.673
255.06842041015625  255.063
277.982421875       278.047
228.63650512695312  228.61
265.3466796875      265.355
245.4405059814453   245.387
217.6405792236328   217.715
243.95590209960938  244.007
225.05352783203125  225.017
202.35430908203125  202.491
242.1893768310547   242.393
196.67547607421875  197.126
221.29019165039062  221.316
229.14529418945312  229.14
208.05859375        207.983
201.7965087890625   201.89
204.59408569335938  204.142
222.6988525390625   222.059
212.80035400390625  213.449
262.5123291015625   262.53
153.87637329101562  155.095
238.25704956054688  239.609
284.07769775390625  284.95
194.6011962890625   197.989
171.880126953125    171.607
264.2562255859375   267.292
202.17262268066406  201.321
149.88726806640625  146.92
213.33978271484375  219.864
294.94012451171875  298.808
206.38232421875     201.994
184.72096252441406  194.253
255.96823120117188  237.341
211.44276428222656  194.038
151.47923278808594  149.074
172.3814697265625   154.429
179.71731567382812  183.083
151.15765380859375  164.78
157.54025268554688  161.271
191.8701629638672   195.957
166.78912353515625  162.826
166.08004760742188  164.285
177.39981079101562  184.853
170.86624145507812  160.088
165.30255126953125  169.323
145.57489013671875  156.784
194.71859741210938  151.346
238.37454223632812  264.286
175.21041870117188  192.347
136.05355834960938  162.889
172.6317138671875   155.631
208.10299682617188  185.586
221.0286865234375   217.085
173.4173583984375   155.914
253.78131103515625  244.531
248.7365264892578   249.555
194.50607299804688  183.464
212.595947265625    227.499
252.08670043945312  224.977
243.98223876953125  214.467
175.81005859375     139.689
200.88902282714844  170.642
168.1161346435547   189.527
210.65390014648438  243.979
246.09844970703125  232.952
207.2693634033203   178.198
124.55677795410156  182.764
248.9339599609375   264.27
230.8294677734375   216.812
159.265625          130.281
178.865966796875    229.354
157.16854858398438  207.201
224.73690795898438  204.391

Do you have an idea or suggestion on why this could be?

@smth Any suggestions or ideas?

That’s some great debugging.
Based on the output it seems the losses are quite shaky around ~200.
I guess the difference might be due to accumulated rounding errors in both approaches.

You’ve mentioned that the final result is worse in PyTorch, right?
Is this reproducible using different seeds for each framework, i.e. does Caffe always converge to a better accuracy?

Thank you for the reply @ptrblck.

Just to be fully transparent about the complete setup, and to check whether I have overlooked any detail, I’m sharing more info to narrow down the problem:

  • The network is trained with 5 conv layers (weights frozen) and 4 FC layers (trained). The network architecture is shown below. Here I have set the dropout ratio to 0.0 in both Caffe and PyTorch, to check whether the outputs and updates are the same. I also set weight decay to 0 in both.
        self._net = CaffeNet(pretrained_model_path=pretrained_model)
        dropout_ratio = 0.0
        self._classifier = nn.Sequential(nn.Linear(256 * 6 * 6 * 2, 4096),
                                         nn.ReLU(inplace=True),
                                         nn.Dropout(dropout_ratio),
                                         nn.Linear(4096, 4096),
                                         nn.ReLU(inplace=True),
                                         nn.Dropout(dropout_ratio),
                                         nn.Linear(4096, 4096),
                                         nn.ReLU(inplace=True),
                                         nn.Dropout(dropout_ratio),
                                         nn.Linear(4096, num_output))
  • At the 100th iteration, I observed that the output of the conv-5 layer is the same in both Caffe and PyTorch. This confirms that my inputs are identical and that no errors were made there.

  • The L1 loss in Caffe is implemented as below:
    The Power layer computes -1 * gt, the Eltwise layer does the element-wise sum (pred + (-1 * gt)), and the Reduction layer sums the absolute values into a scalar loss.

layer {
  name: "neg"
  bottom: "bbox"
  top: "bbox_neg"
  type: "Power"
  power_param {
    power: 1
    scale: -1
    shift: 0
  }
}
layer {
  name: "flatten"
  type: "Flatten"
  bottom: "bbox_neg"
  top: "bbox_neg_flat"
}

layer {
  name: "subtract"
  type: "Eltwise"
  bottom: "fc8"
  bottom: "bbox_neg_flat"
  top: "out_diff"
}
layer {
  name: "abssum"
  type: "Reduction"
  bottom: "out_diff"
  top: "loss"
  loss_weight: 1
  reduction_param {
    operation: 2
  }
}

and for PyTorch I use:

loss = torch.nn.L1Loss(size_average=False)(pred_bb, gt_bb.float())
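
A quick way to convince oneself that the two formulations match is to replicate the Caffe pipeline (Power with scale=-1, Eltwise SUM, Reduction with operation ASUM) directly in PyTorch on random tensors. A small sketch:

import torch

pred_bb = torch.randn(50, 4)
gt_bb = torch.randn(50, 4)

# Caffe pipeline: Power(scale=-1) -> Eltwise SUM -> Reduction ASUM
bbox_neg = -1.0 * gt_bb                  # Power layer: power=1, scale=-1, shift=0
out_diff = pred_bb + bbox_neg            # Eltwise SUM: pred + (-gt)
caffe_style_loss = out_diff.abs().sum()  # Reduction layer with ASUM (operation: 2)

# sum-reduced L1 loss in PyTorch
torch_loss = torch.nn.L1Loss(reduction='sum')(pred_bb, gt_bb)

print(torch.allclose(caffe_style_loss, torch_loss))  # expected: True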

Conclusion and Observations:

  • As a whole, all of the learning (forward/backward/update) happens in the FC layers. These updates are different in Caffe and PyTorch. To be debugged.

Training with a different seed

  • I haven’t tried a different seed in Caffe yet. It will take a day to retrain with a different seed; I will do that and share the result.

Hi @ptrblck,

  • With a different seed, Caffe still converges better and gives a much better model than PyTorch.
  • I also modified the data loader to behave similarly to Caffe’s (not worrying much about the discrepancy in the updates). Still, the performance of the final model is not any better.

Questions

  • Do you think I should worry about how the PyTorch and Caffe updates are happening?
    So far my thinking was: if the Caffe and PyTorch updates are very similar over a good number of iterations, then I can say the problem lies somewhere else!

Do you have any suggestions for further debugging? :slight_smile:

So the loss is the same for the first couple of epochs, but the gradients differ?
Are you seeing the difference in gradients from the beginning (which is strange, as the loss stays equal for some iterations) or after a while?

@ptrblck thank you for the reply.

Here are the gradients for the first FC layer in Caffe and PyTorch (top: Caffe gradient, bottom: PyTorch gradient). As far as I can see, the gradients appear to be rounded off.

Yes, I also couldn’t make sense of why the loss stays the same for quite a while and then changes oddly in some iterations. One of my guesses is that, since only the FC layers are updated, loss = sum(w.x + b), which might lead to that error. But it is still strange behaviour.

The output is most likely just rounded due to the default print options.
If you want to have more precision, you could use torch.set_printoptions.
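
For example:

import torch

torch.set_printoptions(precision=10)  # print tensors with 10 decimal digits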

Anyway, I would expect that different frameworks might diverge after a while just because of the limited floating point precision and a possibly different operation order.
However, it’s still concerning that your Caffe model converges to a better final state than your PyTorch model.

Thank you @ptrblck. My bad

Here is the updated one:

Yes, I agree with you as well. So I suspected that my sampling of data for the network might be the issue. To rule that out, I made the data loader fetch data sampled just as in Caffe, and I still observed the same kind of poor convergence as before, when the data loader was unmodified.


Hi @ptrblck, one question.

Since different frameworks might converge differently, I was mostly looking into how the data augmentation happens in Caffe and PyTorch. In the Caffe C++ code, random numbers (uniform distribution) are generated for the data augmentation, and the data fetching happens in a single process, whereas in my PyTorch code I am using num_workers > 0.

I use the PyTorch API to generate the random numbers as below; can I assume it is thread-safe?

import math

import torch

RAND_MAX = 2147483647


def sample_rand_uniform():
    """Sample a uniform random number in (0, 1), mimicking Caffe's
    (rand() + 1) / (RAND_MAX + 2) formulation.
    """
    # return ((random.randint(0, RAND_MAX) + 1) * 1.0) / (RAND_MAX + 2)
    rand_num = torch.randint(RAND_MAX, (1, 1)).item()
    return ((rand_num + 1) / (RAND_MAX + 2))
    # return torch.rand(1).item()


def sample_exp_two_sides(lambda_):
    """Sample from a two-sided (Laplace-like) exponential distribution
    with parameter lambda_, as used for the GOTURN motion model.
    """

    # pos_or_neg = random.randint(0, RAND_MAX)
    pos_or_neg = torch.randint(RAND_MAX, (1, 1)).item()
    if (pos_or_neg % 2) == 0:
        pos_or_neg = 1
    else:
        pos_or_neg = -1

    rand_uniform = sample_rand_uniform()
    return math.log(rand_uniform) / (lambda_ * pos_or_neg)

Each worker process will be seeded with self.base_seed + worker_id as seen here.

You might additionally pass a worker_init_fn to seed the workers manually.
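
A minimal sketch of such a worker_init_fn (reusing the GoturnDataloader object objGoturn from your script above), which seeds e.g. numpy differently but reproducibly in each worker:

import numpy as np
import torch
from torch.utils.data import DataLoader


def worker_init_fn(worker_id):
    # each worker's torch RNG is already seeded with base_seed + worker_id;
    # reuse that seed for other libraries such as numpy
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)


loader = DataLoader(objGoturn, batch_size=2, shuffle=True, num_workers=6,
                    collate_fn=objGoturn.collate,
                    worker_init_fn=worker_init_fn)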


Thank you @ptrblck for the code reference.

With the following setting, say:

Total number of samples: 100
Number of workers: 6
Each worker: initialized with a different seed

If I try to generate random numbers from a uniform distribution between 1 and 100 using 6 workers with different seeds, is it guaranteed that I will get numbers between 1 and 100 with almost equal probability in one epoch?

I shall also try writing sample code (a rough sketch of what I have in mind is below), but I’m eager to know.
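
A minimal check along those lines could look like the sketch below: a dummy dataset where each sample is one uniform random integer in [1, 100], counted over several epochs with 6 workers.

import collections

import torch
from torch.utils.data import DataLoader, Dataset


class RandDataset(Dataset):
    """Dummy dataset: each sample is one uniform random integer in [1, 100]."""

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.randint(1, 101, (1,)).item()


if __name__ == "__main__":
    loader = DataLoader(RandDataset(), batch_size=10, num_workers=6)

    counts = collections.Counter()
    for _ in range(100):          # accumulate counts over several epochs
        for batch in loader:
            counts.update(batch.tolist())

    # roughly equal counts indicate the workers together still sample uniformly
    print(min(counts.values()), max(counts.values()))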

Yes, that should be the case. Please let us know if you see any other behavior.
