Cannot figure out why training is not going well for hand pose estimation in PyTorch

I am implementing a Region Ensemble Network in PyTorch.
It takes a resized depth image as input and outputs the pixel coordinates of the 21 joints.
I cannot figure out why the training loss keeps increasing. Even when it does decrease after I change the hyperparameters, the outputs are all very similar to each other. I am sure the ground truth and data have no problems: for example, when I draw the ground truth onto the 96x96 image, the joints land in the right places (see the attached image).

I have tried decreasing/increasing the learning rate (from 0.000005 to 0.005), training with and without input normalization, data augmentation, and overfitting on a small non-augmented subset of the dataset (the MSRA hand gesture dataset), but nothing works. You can see the full implementation on GitHub.

For reference, this is the ground truth for the 21 (x, y) joint coordinates:

tensor([ 49.0274,  65.6594,  53.2944,  46.7441,  55.6024,  37.6823,
         57.4913,  30.6067,  59.4096,  23.9927,  46.4607,  45.0358,
         45.9472,  33.7956,  45.6864,  25.9997,  45.4752,  18.4846,
         40.3416,  45.7006,  37.4622,  36.1925,  35.6829,  30.8469,
         34.2946,  27.1403,  33.4312,  49.3788,  30.2387,  44.8089,
         28.7085,  42.4615,  26.9225,  39.7115,  55.4961,  62.4478,
         61.8277,  57.6757,  66.9980,  53.5504,  73.2470,  50.6870], dtype=torch.float64)

But I get the output below after training for 100 epochs, with the learning rate decayed by a factor of 0.1 every 10 epochs.

[[-0.08046436 -0.0457227 ]
 [-0.08309992 -0.05746359]
 [-0.08415958 -0.05046146]
 [-0.0773853  -0.05147431]
 [-0.06778512 -0.06056155]
 [-0.0672973  -0.05785097]
 [-0.07014311 -0.0637675 ]
 [-0.08236345 -0.05145769]
 [-0.06171172 -0.05182333]
 [-0.07300673 -0.05585903]
 [-0.07764702 -0.0533776 ]
 [-0.07672743 -0.06045451]
 [-0.07716335 -0.05677503]
 [-0.08042111 -0.04958814]
 [-0.07111529 -0.06132633]
 [-0.07012014 -0.05466149]
 [-0.07885809 -0.05328601]
 [-0.07775773 -0.05257958]
 [-0.08589675 -0.05545426]
 [-0.0788433  -0.04867259]
 [-0.06929056 -0.05824893]]

I am using a custom Smooth L1 loss as described in the paper.

import torch


class Modified_SmoothL1Loss(torch.nn.Module):

    def __init__(self):
        super(Modified_SmoothL1Loss, self).__init__()

    def forward(self, x, y):
        total_loss = 0
        z = x - y
        # loop over every element of the (batch, joints*2) residual
        for i in range(z.shape[0]):
            for j in range(z.shape[1]):
                total_loss += self._smooth_l1(z[i][j])

        # sum over joints, average over the batch only
        return total_loss / z.shape[0]

    def _smooth_l1(self, z):
        # quadratic inside |z| < 0.01, linear outside (thresholds from the paper)
        if torch.abs(z) < 0.01:
            loss = self._calculate_MSE(z)
        else:
            loss = self._calculate_L1(z)

        return loss

    def _calculate_MSE(self, z):
        return 0.5 * (torch.pow(z, 2))

    def _calculate_L1(self, z):
        return 0.01 * (torch.abs(z) - 0.005)
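
As a side note, the elementwise Python loop is very slow. This is just an equivalent vectorized sketch (assuming the same 0.01 / 0.005 constants), not the code I trained with:

import torch


class VectorizedSmoothL1Loss(torch.nn.Module):
    """Vectorized sketch, elementwise-equivalent to the loop version above."""

    def forward(self, x, y):
        z = x - y
        abs_z = torch.abs(z)
        # quadratic inside |z| < 0.01, linear outside
        elementwise = torch.where(abs_z < 0.01, 0.5 * z ** 2, 0.01 * (abs_z - 0.005))
        # sum over joints, average over the batch only
        return elementwise.sum() / z.shape[0]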

This is my network:


import numpy as np
import torch
import torch.nn as nn


class RegionEnsemble(nn.Module):

    def __init__(self, feat_size=12):
        assert (feat_size / 4).is_integer()
        super(RegionEnsemble, self).__init__()
        self.feat_size = feat_size
        self.grids = nn.ModuleList()
        # one fully-connected branch per region (4 corners, 4 overlapping centers, 1 middle)
        for i in range(9):
            self.grids.append(self.make_block(self.feat_size))

    def make_block(self, feat_size):
        # each region is a feat_size/2 x feat_size/2 crop with 64 channels
        size = int(feat_size / 2)
        return nn.Sequential(nn.Linear(64 * size * size, 2048), nn.ReLU(), nn.Dropout(),
                             nn.Linear(2048, 2048), nn.ReLU(), nn.Dropout())

    def forward(self, x):

        midpoint = int(self.feat_size/2)
        quarterpoint1 = int(midpoint/2)
        quarterpoint2 = int(quarterpoint1 + midpoint)
        regions = []
        ensemble = []

        #4 corners
        regions += [x[:, :, :midpoint, :midpoint], x[:, :, :midpoint, midpoint:], x[:, :, midpoint:, :midpoint], x[:, :, midpoint:, midpoint:]]
        #4 overlapping centers

        regions += [x[:, :, quarterpoint1:quarterpoint2, :midpoint], x[:, :, quarterpoint1:quarterpoint2, midpoint:], x[:, :, :midpoint, quarterpoint1:quarterpoint2], x[:, :, midpoint:, quarterpoint1:quarterpoint2]]
        # middle center
        regions += [x[:, :, quarterpoint1:quarterpoint2, quarterpoint1:quarterpoint2]]

        for i in range(0,9):
            out = regions[i]
            # print(out.shape)
            out = out.contiguous()
            out = out.view(out.size(0),-1)
            out = self.grids[i](out)
            ensemble.append(out)

        out = torch.cat(ensemble,1)

        return out
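
A quick way to sanity-check the region slicing is to push a dummy feature map through it. Assuming the backbone below produces a 64-channel 12x12 feature map for a 96x96 input (my numbers, not taken from the repo), the concatenated ensemble output should have 9 * 2048 features:

dummy = torch.randn(2, 64, 12, 12)   # (batch, channels, H, W)
ens = RegionEnsemble(feat_size=12)
print(ens(dummy).shape)              # expected: torch.Size([2, 18432]) = 9 * 2048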



class Residual(nn.Module):

    def __init__(self, planes):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels= planes, out_channels=planes, kernel_size = 3,  padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels= planes, out_channels=planes, kernel_size = 3,  padding=1)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)

        out += residual
        return out

class REN(nn.Module):

    def __init__(self, args):
        super(REN, self).__init__()
        # spatial size of the feature map after each of the three 2x2/stride-2 max pools
        # (96 -> 48 -> 24 -> 12 for a 96x96 input)
        feat = np.floor(((args.input_size - 1 - 1) / 2) + 1)
        feat = np.floor(((feat - 1 - 1) / 2) + 1)
        feat = np.floor(((feat - 1 - 1) / 2) + 1)
        #nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size = 3, padding=1)
        self.relu0 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size = 3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.relu1 = nn.ReLU()
        self.conv2_dim_inc = nn.Conv2d(in_channels=16, out_channels=32, kernel_size = 1, padding=0)
        self.relu2 = nn.ReLU()
        self.res1 = Residual(planes = 32)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.relu3 = nn.ReLU()
        self.conv3_dim_inc = nn.Conv2d(in_channels=32, out_channels=64, kernel_size = 1, padding=0)
        self.relu4 = nn.ReLU()
        self.res2 = Residual(planes = 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.relu5 = nn.ReLU()
        self.dropout = nn.Dropout()
        self.region_ens = RegionEnsemble(feat_size=feat)
        #class torch.nn.Linear(in_features, out_features, bias=True)
        self.fc1 = nn.Linear(9*2048, args.num_joints)

    def forward(self, x):

        out = self.conv0(x)
        out = self.relu0(out)

        out = self.conv1(out)
        out = self.maxpool1(out)
        out = self.relu1(out)

        out = self.conv2_dim_inc(out)
        out = self.relu2(out)

        out = self.res1(out)

        out = self.maxpool2(out)
        out = self.relu3(out)

        out = self.conv3_dim_inc(out)
        out = self.relu4(out)

        out = self.res2(out)

        out = self.maxpool3(out)
        out = self.relu5(out)        #relu5
        out = self.dropout(out)


        #slice
        out = self.region_ens(out)
        # flatten the output
        out = out.view(out.size(0),-1)

        out = self.fc1(out)
        return out
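
For what it's worth, a minimal smoke test of the full model (assuming args.input_size = 96 and args.num_joints = 42, i.e. 21 joints * 2 coordinates) would look like this:

from types import SimpleNamespace

import torch

args = SimpleNamespace(input_size=96, num_joints=42)
model = REN(args)
out = model(torch.randn(2, 1, 96, 96))   # batch of two 96x96 depth crops
print(out.shape)                         # expected: torch.Size([2, 42])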

This is my training function:

import os
import time

import numpy as np


def train(train_loader, model, criterion, optimizer, epoch, args):

    # switch to train mode
    model.train()
    loss_train = []
    expr_dir = os.path.join(args.save_dir, args.name)
    for i, (input, target) in enumerate(train_loader):

        stime = time.time()
        # move the batch to the GPU
        target = target.float()
        target = target.cuda(non_blocking=False)
        input = input.float()
        input = input.cuda()
        # compute output
        output = model(input)

        loss = criterion(output, target)
        # measure accuracy and record loss
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        loss_train.append(loss.data.item())
        optimizer.step()
        np.savetxt(os.path.join(expr_dir, "_iteration_train_loss.out"), np.asarray(loss_train), fmt='%f')
        # measure elapsed time
        if i % args.print_interval == 0:
            TT = time.time() -stime
            print('epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss:.4f}\t'
                  'Time: {time:.2f}\t'.format(
                   epoch, i, len(train_loader), loss=loss.item(), time= TT))

    return [np.mean(loss_train)]
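
Regarding the overfitting test mentioned above, the check I expect to pass is driving the loss to near zero on one fixed batch. A rough sketch (assuming model and criterion are defined as above and sample_input / sample_target are one fixed batch already on the GPU):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
for step in range(500):
    output = model(sample_input)
    loss = criterion(output, sample_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, loss.item())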

Any input on what could have gone wrong would be greatly appreciated. I have been working on this for weeks but still cannot figure out what is wrong.

The issue was actually a really silly one. When I loaded the ground truths in the dataset, they were stored as a tensor. For example:

self.all_joints = read_joints()
# ... other loading code ...

When I loaded the data in the __getitem__ function, I would perform random translation/scaling/rotation and normalization on the joints:

joint = self.all_joints[index]
joint = translate(joint)  # followed by scaling, rotation and normalization

As I did not know that indexing a tensor gives you a view of (a reference to) the original data rather than a copy, the numbers in self.all_joints[index] kept getting smaller and smaller with every load of that index. When debugging, I only loaded each index once, and hence the issue slipped past me.

As a quick fix, I simply saved self.all_joints as a NumPy array instead.
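
For completeness, I believe an alternative that keeps the joints as tensors would be to copy the indexed row inside __getitem__ before augmenting it, roughly like this:

joint = self.all_joints[index].clone()  # .clone() copies the data, so the stored tensor stays untouched
joint = translate(joint)                # augment the copy as before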

Could you explain a bit more what you mean? Did you somehow delete the joints in self.all_joints?

No, they weren't deleted, but the translated and normalized joints replaced the original joints in self.all_joints.

For example, this is the ground truth returned by the first call to __getitem__(1), after some random augmentation (correct values):

tensor([ 48.5065,  66.6261,  54.1668,  47.8129,  56.8132,  38.7390,
         58.9134,  31.6378,  60.9088,  24.9512,  47.4044,  45.5503,
         48.5107,  34.2860,  49.2876,  26.3807,  50.0234,  18.8972,
         41.2396,  45.7394,  40.1532,  35.9195,  39.3954,  29.1414,
         38.6335,  22.4301,  34.0548,  48.8389,  31.3979,  42.1586,
         29.6963,  37.8616,  27.8077,  33.0767,  54.5841,  63.2553,
         60.8536,  58.3544,  65.4259,  53.5492,  71.3503,  49.6988], dtype=torch.float64)

Calling __getitem__(1) again gives smaller numbers, as the random augmentations were applied to the tensor above instead of to the original ground truth:

tensor([  3.8037,  30.5698,   6.0577,  23.0783,   7.1115,  19.4650,
          7.9478,  16.6372,   8.7423,  13.9746,   3.3648,  22.1773,
          3.8053,  17.6918,   4.1147,  14.5438,   4.4077,  11.5639,
          0.9099,  22.2526,   0.4773,  18.3422,   0.1756,  15.6432,
         -0.1278,  12.9707,  -1.9511,  23.4868,  -3.0091,  20.8267,
         -3.6867,  19.1156,  -4.4387,  17.2102,   6.2238,  29.2275,
          8.7204,  27.2759,  10.5411,  25.3625,  12.9002,  23.8292], dtype=torch.float64)

And calling __getitem__(1) once more:

tensor([-13.9973,  16.2119, -13.0997,  13.2288, -12.6801,  11.7899,
        -12.3471,  10.6639, -12.0307,   9.6036, -14.1720,  12.8700,
        -13.9966,  11.0838, -13.8734,   9.8303, -13.7567,   8.6436,
        -15.1496,  12.9000, -15.3219,  11.3428, -15.4420,  10.2681,
        -15.5628,   9.2038, -16.2889,  13.3914, -16.7102,  12.3322,
        -16.9800,  11.6508, -17.2795,  10.8921, -13.0336,  15.6774,
        -12.0394,  14.9003, -11.3144,  14.1384, -10.3750,  13.5278], dtype=torch.float64)

The ground truths were iteratively replaced with every call for that index, which led to really bad training.

OK, I see! Thanks for the clarification. I misunderstood the meaning of “smaller” in this case. :wink:

Just for completeness, it seems you are modifying your data in-place.
If you can just index your joints and augment them using an assignment, this issue should not occur.
Have a look at the following code snippet:

import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.ones(10, 10)

    def __getitem__(self, index):
        x = self.data[index]
        #x -= 1      # in-place: would modify self.data as well
        x = x - 1    # out-of-place: leaves self.data untouched
        return x

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
dataset[0]
print(dataset.data)

If you run it as is, the original data will stay the same.
Swap the commented lines in __getitem__ (i.e. use the in-place x -= 1) and dataset.data will be modified.
