Got CUDA out of memory while implementing LeNet-5

I have 4 GB of RAM and a 2 GB GPU, and when I try LeNet-5 on the Kaggle facial keypoints dataset I get RuntimeError: CUDA error: out of memory. What is causing this and what should I do?

import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 32 output channels, 3x3 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(6, 64, 2)
        self.conv3 = nn.Conv2d(64, 128, 2)
        self.fc1 = nn.Linear(128 * 5 * 5, 500)
        self.fc2 = nn.Linear(500, 500)
        self.fc3 = nn.Linear(500, 30)
    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
dtype = torch.float
device = torch.device("cuda:0")
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 2140, 9216, 100, 30
def train(model,x,y,criterion,optimizer):
    model.train()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print('train-loss',t, loss.item(),end=' ')
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
def valid(model,x_valid,y_valid,criterion):
    model.eval()
    y_pred = model(x_valid)
    loss = criterion(y_pred, y_valid)
    print('test-loss',t, loss.item(),end=' ')
    return loss.item()
# Create tensors from the numpy arrays and move them to the GPU
X_train=X_train.reshape(-1, 1, 96, 96)
X_valid=X_valid.reshape(-1, 1, 96, 96)
x_train =  torch.tensor(torch.from_numpy(X_train),device=device,dtype=dtype)
y_train =   torch.tensor(torch.from_numpy(Y_train),device=device,dtype=dtype)
x_valid =  torch.tensor(torch.from_numpy(X_valid),device=device,dtype=dtype)
y_valid =   torch.tensor(torch.from_numpy(Y_valid),device=device,dtype=dtype)
model = LeNet().to(device)
loss_train=[]
loss_valid=[]
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(400):
    loss_train.append(train(model,x_train,y_train,criterion,optimizer))
    loss_valid.append(valid(model,x_valid,y_valid,criterion))
    print()

Your code might have some typos.
self.conv2 should accept 32 in_channels based on your forward method.
Also, are you resizing the images?
If not, I think your fc1 layer should accept 128 * 11 * 11 features. At least that’s working with input images of 96x96.
Using a batch size of 128 the training takes approx. 1.6 GB, so your GPU should have enough memory.
Maybe other processes are filling up your GPU?


In the tutorial "Using convolutional neural nets to detect facial keypoints" on Daniel Nouri's blog, he used the following network architecture:

net2 = NeuralNet(
    layers=[
        ('input', layers.InputLayer),
        ('conv1', layers.Conv2DLayer),
        ('pool1', layers.MaxPool2DLayer),
        ('conv2', layers.Conv2DLayer),
        ('pool2', layers.MaxPool2DLayer),
        ('conv3', layers.Conv2DLayer),
        ('pool3', layers.MaxPool2DLayer),
        ('hidden4', layers.DenseLayer),
        ('hidden5', layers.DenseLayer),
        ('output', layers.DenseLayer),
    ],
    input_shape=(None, 1, 96, 96),
    conv1_num_filters=32, conv1_filter_size=(3, 3), pool1_pool_size=(2, 2),
    conv2_num_filters=64, conv2_filter_size=(2, 2), pool2_pool_size=(2, 2),
    conv3_num_filters=128, conv3_filter_size=(2, 2), pool3_pool_size=(2, 2),
    hidden4_num_units=500, hidden5_num_units=500,
    output_num_units=30, output_nonlinearity=None,
    # ...
)

and for reshaping he used

def load2d(test=False, cols=None):
    X, y = load(test=test)
    X = X.reshape(-1, 1, 96, 96)
    return X, y

where he is not resizing it to 32 channels

Thanks for the reply.

Sorry for not making that clear.
I meant that in_channels of conv2 is currently set to 6, although conv1 returns 32 channels.
So your model is currently not working.
Try the following and you’ll get an error for conv2:

model = LeNet()
x = torch.randn(1, 1, 96, 96)
output = model(x)
> RuntimeError: Given groups=1, weight of size [64, 6, 2, 2], expected input[1, 32, 47, 47] to have 6 channels, but got 32 channels instead

Also, after fixing this issue, you'll get another size mismatch error for your fc1 layer, as the shape of x right before flattening is [1, 128, 11, 11]. So fc1 should take in_features=128*11*11.

After fixing this issue, you’ll get another error as you are reusing fc3. Removing the x = F.relu(self.fc3(x)) line finally makes your model run.

I’m curious why the model seems to be working for you.
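
For reference, here is a sketch of the model with those three fixes applied (the class name is just a placeholder; it assumes 1x96x96 inputs as in your code, which is where the 128 * 11 * 11 comes from):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNetFixed(nn.Module):
    def __init__(self):
        super(LeNetFixed, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 2)          # in_channels matches conv1's 32 output channels
        self.conv3 = nn.Conv2d(64, 128, 2)
        self.fc1 = nn.Linear(128 * 11 * 11, 500)   # activation is [128, 11, 11] before flattening
        self.fc2 = nn.Linear(500, 500)
        self.fc3 = nn.Linear(500, 30)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                            # fc3 is applied only once, without relu
        return x

model = LeNetFixed()
x = torch.randn(1, 1, 96, 96)
print(model(x).shape)                              # torch.Size([1, 30])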


Thanks a lot, now I understand how to calculate the number of input features for the next conv layer. I also tried to implement a data loader class:

from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, x,y,dtype):
        self.x=torch.tensor(torch.from_numpy(x),dtype=dtype)
        self.y=torch.tensor(torch.from_numpy(y),dtype=dtype)
        self.dtype=dtype
        self.data_len=len(x)
    def __getitem__(self, index):
        # stuff
        img=self.x[index]
        label=self.y[index]
        return (img, label)

    def __len__(self):
        return self.data_len

Can I implement data augmentation inside this data loader class for each batch of 128, like the FlipBatchIterator class in the tutorial that augments 50% of the data? I also have a doubt about how he calculates the training and validation loss: does he calculate the loss over each batch and then average it over one epoch, or does he just sum up the losses of the batches?

Sure, there are several ways to implement data augmentation.

One way would be to use torchvision.transforms on the images. As these transformations often only work on PIL images, we would need to transform the numpy arrays into images first, augment the data, and finally transform them back to tensors.

Another way would be to just implement these transformations by ourselves, as we already have the tensor data.

Anyway, in both cases we would have to take care of the targets as well, since the keypoints would have to be e.g. flipped to match the image.

In the first case, we could use torchvision's functional API. Here is a small example I wrote a while ago.
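
A rough sketch of that first approach (the helper name is hypothetical; it assumes the images are numpy arrays in [0, 1] and the keypoints are normalized to [-1, 1] as in the blog post, and it omits the left/right keypoint swap):

import random
import numpy as np
import torchvision.transforms.functional as TF
from PIL import Image

def flip_augment(img_np, label):
    # img_np: (96, 96) numpy array in [0, 1], label: tensor of 30 keypoints in [-1, 1]
    pil_img = Image.fromarray((img_np * 255).astype(np.uint8))  # numpy -> PIL image
    if random.random() < 0.5:
        pil_img = TF.hflip(pil_img)            # horizontal flip via the functional API
        label = label.clone()
        label[::2] = label[::2] * -1           # mirror the x-coordinates
        # the left/right keypoint pairs would also have to be swapped here
    img = TF.to_tensor(pil_img)                # PIL image -> [1, 96, 96] tensor in [0, 1]
    return img, label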

You could just reimplement the blog post's data augmentation, as the flip indices etc. are already provided:

import random
from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, x,y,dtype):
        self.x=torch.from_numpy(x).to(dtype=dtype).clone()
        self.y=torch.from_numpy(y).to(dtype=dtype).clone()
        self.dtype=dtype
        self.data_len=len(x)
        self.flip_indices = [
            (0, 2), (1, 3),
            (4, 8), (5, 9), (6, 10), (7, 11),
            (12, 16), (13, 17), (14, 18), (15, 19),
            (22, 24), (23, 25),
        ]
        
    def __getitem__(self, index):
        img = self.x[index].clone()
        label = self.y[index].clone()
        # Flip approx. 50% of the samples
        if random.randint(0, 1) == 1:
            print('Flipping image')
            img = img.flip(2)
            label[::2] = label[::2] * -1
            
            for a, b in self.flip_indices:
                label[a], label[b] = label[b].clone(), label[a].clone()
        
        return (img, label)

    def __len__(self):
        return self.data_len

You have to add some clone() calls, as otherwise the original data will be modified or the label swap won’t work.
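
For example, a quick demonstration of why the clones are needed for the swap (indexing returns views into the same tensor, so an in-place swap would otherwise overwrite the value you still need):

import torch

label = torch.tensor([1., 2.])
label[0], label[1] = label[1], label[0]                  # views, not copies
print(label)                                             # tensor([2., 2.]) -> swap failed

label = torch.tensor([1., 2.])
label[0], label[1] = label[1].clone(), label[0].clone()
print(label)                                             # tensor([2., 1.]) -> swap works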

Using the Dataset you just have to handle a single sample. The DataLoader will automatically create batches using the Dataset.
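
E.g. something like this (just a sketch, assuming X_train and Y_train are the numpy arrays from your script):

from torch.utils.data import DataLoader

dataset = MyCustomDataset(X_train.reshape(-1, 1, 96, 96), Y_train, dtype=torch.float)
loader = DataLoader(dataset, batch_size=128, shuffle=True)  # yields batches of 128 samples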

I'm not sure how Lasagne calculates the losses, but I assume that in both cases the per-batch losses are averaged over the epoch.
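
I.e. something along these lines (a sketch of what I mean, using the loader from above):

epoch_losses = []
for data, target in loader:
    data, target = data.to(device), target.to(device)    # move only the current batch to the GPU
    output = model(data)
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_losses.append(loss.item())                     # loss of the current batch
train_loss = sum(epoch_losses) / len(epoch_losses)       # mean batch loss for this epoch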


Is it right to define the optimizer at every epoch in order to vary the learning rate and momentum?

lr_arr=np.linspace(0.03, 0.0001, epochs)
momentum_arr=np.linspace(0.9, 0.999, epochs)
for epoch in range(1, epochs + 1):
    optimizer = torch.optim.SGD(model_two.parameters(), lr=lr_arr[epoch-1],
                                momentum=momentum_arr[epoch-1], nesterov=True)
    train_loss_net2.append(train(model_two, device, train_loader, optimizer, criterion, epoch))
    test_loss_net2.append(test(model_two, device, criterion, test_loader))
end = datetime.datetime.now()
print(end-start)

For optim.SGD this should work. However, if you use another optimizer with internal states, e.g. running estimates, recreating it will clear out these buffers and you will most likely see a spiking loss curve.
In that case I would recommend using optimizer.param_groups to manipulate the internal values.
Also, have a look at optim.lr_scheduler. These schedulers allow an easy manipulation of the learning rate using different methods.


Thanks a lot. Is this implementation correct?

lambda1 = lambda epoch: 0.03 - epoch * (0.03 - 0.0001) / epochs
optimizer = torch.optim.SGD(model_three.parameters(),nesterov=True)
scheduler = LambdaLR(optimizer, lr_lambda=lambda1)
for epoch in range(100):
    scheduler.step()
    train(...)
    validate(...)

How can I vary the momentum with this scheduler? I have to vary both the learning rate and the momentum.

The implementation won't work, as LambdaLR uses a multiplicative factor to manipulate the learning rate.
Your optimizer is also missing the lr argument.
In your use case, I think the easiest approach would be to use your initial lr_arr and momentum_arr and manipulate optimizer.param_groups directly.
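
A rough sketch of that (reusing your lr_arr and momentum_arr; train/validate as in your snippet):

optimizer = torch.optim.SGD(model_three.parameters(), lr=lr_arr[0],
                            momentum=momentum_arr[0], nesterov=True)

for epoch in range(epochs):
    # set the learning rate and momentum for this epoch directly
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_arr[epoch]
        param_group['momentum'] = momentum_arr[epoch]
    train(...)
    validate(...)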