RuntimeError: Expected object of backend CPU but got backend CUDA for argument #4 'mat1'

Hi Everyone,

I am getting the following error when I try to run the program on a GPU.

Traceback (most recent call last):
  File "train_model.py", line 219, in <module>
    image_size=64, color=True, device = device)
  File "train_model.py", line 198, in train_model
    outputs = cpc_model(inputx, inputy)
  File "/home/pranavan/anaconda3/envs/capsule/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "train_model.py", line 155, in forward
    x = self.n_pl(x)
  File "/home/pranavan/anaconda3/envs/capsule/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "train_model.py", line 123, in forward
    outputs.append(self.fc1_layers[i](x).unsqueeze(1))
  File "/home/pranavan/anaconda3/envs/capsule/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/pranavan/anaconda3/envs/capsule/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 92, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/pranavan/anaconda3/envs/capsule/lib/python3.7/site-packages/torch/nn/functional.py", line 1406, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #4 'mat1'

This is my source code; I have copied it below.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

from data_utils import SortedNumberGenerator


# Done
# Similar to Keras's Flatten layer
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size()[0], -1)


# Done - needs to be tested
# Autoregressive model
class Autoregressor(nn.Module):

    def __init__(self, no_of_layers=256):
        super(Autoregressor, self).__init__()
        # very bad to hard-code the input shape; change this in the long run
        self.gru = nn.GRU(input_size=128, hidden_size=256, num_layers=no_of_layers, batch_first=True)

    def forward(self, x):
        x, y = self.gru(x)
        x = x[:, -1, :].squeeze()  # take the last timestep's output, like Keras's return_sequences=False
        return x


# Done - needs to be tested
# encoder of the neural network
class Encoder(nn.Module):
    def __init__(self, codesize=128):
        super(Encoder, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2)
        self.batch_norm1 = nn.BatchNorm2d(num_features=64)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2)
        self.batch_norm2 = nn.BatchNorm2d(num_features=64)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2)
        self.batch_norm3 = nn.BatchNorm2d(num_features=64)

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2)
        self.batch_norm4 = nn.BatchNorm2d(num_features=64)

        self.flatten1 = Flatten()
        self.fc1 = nn.Linear(in_features=576, out_features=256)

        self.batch_norm5 = nn.BatchNorm1d(num_features=256)

        self.fc2 = nn.Linear(in_features=256, out_features=codesize)

    def forward(self, x):
        x = F.leaky_relu(self.batch_norm1(self.conv1(x)))
        x = F.leaky_relu(self.batch_norm2(self.conv2(x)))
        x = F.leaky_relu(self.batch_norm3(self.conv3(x)))
        x = F.leaky_relu(self.batch_norm4(self.conv4(x)))

        x = self.flatten1(x)
        x = self.fc1(x)

        x = F.leaky_relu(self.batch_norm5(x))

        x = self.fc2(x)

        return x


# Done - needs to be tested
# Keras TimeDistributed layer alternative in PyTorch - found at https://discuss.pytorch.org/t/any-pytorch-function-can-work-as-keras-timedistributed/1346/4
class TimeDistributed(nn.Module):
    def __init__(self, module):
        super(TimeDistributed, self).__init__()
        self.module = module

    def forward(self, x):
        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-3), x.size(-2), x.size(-1))  # (samples * timesteps, C, H, W)
        y = self.module(x_reshape)
        y = y.contiguous().view(x.size(0), -1, y.size(-1))

        return y


# Done - needs to be tested
class CPCLayer(nn.Module):
    ''' Computes dot product between true and predicted embedding vectors '''

    def __init__(self, **kwargs):
        super(CPCLayer, self).__init__(**kwargs)

    def forward(self, x, y):
        # Compute dot product among vectors
        preds, y_encoded = x, y
        dot_product = torch.mean(y_encoded * preds, dim=-1)
        dot_product = torch.mean(dot_product, dim=-1, keepdim=True)  # average along the temporal dimension

        # Keras loss functions take probabilities
        dot_product_probs = torch.sigmoid(dot_product)

        return dot_product_probs


# Done - needs to be tested
# A thorough check is needed
class NetworkPredictionLayer(nn.Module):

    def __init__(self, code_size, predict_terms):
        super().__init__()
        self.predict_terms = predict_terms
        self.code_size = code_size
        self.fc1_layers = []
        for i in range(predict_terms):
            self.fc1_layers.append(nn.Linear(in_features=256, out_features=self.code_size))

    def forward(self, x):
        outputs = []
        for i in range(self.predict_terms):
            outputs.append(self.fc1_layers[i](x).unsqueeze(1))

        outputs = torch.cat(outputs, dim=1)

        return outputs


# WIP
class CPCModel(nn.Module):

    def __init__(self, code_size=128, terms=4, pred_terms=4):
        super(CPCModel, self).__init__()
        self.code_size = code_size
        self.ne = Encoder(codesize=code_size)
        self.terms = terms
        self.pred_terms = pred_terms
        self.td_1 = TimeDistributed(module=self.ne)
        self.td_2 = TimeDistributed(module=self.ne)
        self.n_ar = Autoregressor(no_of_layers=256)
        self.n_pl = NetworkPredictionLayer(code_size=code_size, predict_terms=terms)
        self.cpc = CPCLayer()

    def forward(self, x, y):
        # print('Input X shape : ', x.shape)

        x = self.td_1(x)
        # print('Output from time distributed - 1', x.shape)

        x = self.n_ar(x)
        # print('Output from autoregressor : ', x.shape)

        x = self.n_pl(x)
        # print('Output from network prediction layer : ', x.shape)

        y = self.td_2(y)
        # print('output from time distributed 2', y.shape)

        z = self.cpc(x, y)
        # print('output from cpc layer', z.shape)

        return z


def train_model(epochs, batch_size, output_dir, code_size, lr=1e-4, terms=4, predict_terms=4, image_size=28,
                color=False, device='cuda'):
    # Prepare data
    train_data = SortedNumberGenerator(batch_size=batch_size, subset='train', terms=terms,
                                       positive_samples=batch_size // 2, predict_terms=predict_terms,
                                       image_size=image_size, color=color, rescale=True)
    validation_data = SortedNumberGenerator(batch_size=batch_size, subset='valid', terms=terms,
                                            positive_samples=batch_size // 2, predict_terms=predict_terms,
                                            image_size=image_size, color=color, rescale=True)

    cpc_model = CPCModel().to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(cpc_model.parameters(), lr=lr)  # use the lr argument rather than a hard-coded value

    for epoch in range(epochs):
        print('Epoch number : \t', (epoch + 1))
        running_loss = 0.0
        print("Started Training")
        for i, data in enumerate(train_data):
            inputs, labels = data
            # print(type(inputs), type(labels))
            labels = torch.tensor(labels).to(device)
            inputx, inputy = torch.tensor(inputs[0]), torch.tensor(inputs[1])
            inputx = inputx.permute((0, 1, 4, 2, 3))
            inputy = inputy.permute((0, 1, 4, 2, 3))
            inputx, inputy = inputx.to(device), inputy.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = cpc_model(inputx, inputy)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if i % 20 == 19:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 20))
                running_loss = 0.0

    print("Finished Training!!!")

    # train_transformer = transforms.Compose([transforms.ToTensor()])

    # print(type(train_data), type(validation_data))


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    train_model(epochs=10, batch_size=4, output_dir='models/64x64', code_size=128, lr=1e-3, terms=4, predict_terms=4,
                image_size=64, color=True, device=device)
    # print(device)
    # x = torch.rand((32, 4, 3, 64, 64))
    # y = torch.rand((32, 4, 3, 64, 64))
    # cpc_model = CPCModel()
    # print('Final shape : ', cpc_model(x, y).shape)

I appreciate your help. Thanks in advance.

Lists of modules are not handled in nn.Module; you need to use either nn.Sequential or nn.ModuleList. Check this thread for more details on the reasoning, and here for the docs on nn.ModuleList, which is probably closer to what you want.
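
For reference, a minimal rewrite of your NetworkPredictionLayer using nn.ModuleList could look like this (keeping the in_features=256 and the unsqueeze/cat stacking from your original code):

import torch
import torch.nn as nn


class NetworkPredictionLayer(nn.Module):

    def __init__(self, code_size, predict_terms):
        super().__init__()
        self.predict_terms = predict_terms
        self.code_size = code_size
        # nn.ModuleList registers each Linear as a submodule, so .to(device),
        # .parameters() and state_dict() all see them
        self.fc1_layers = nn.ModuleList(
            nn.Linear(in_features=256, out_features=self.code_size)
            for _ in range(predict_terms)
        )

    def forward(self, x):
        # one prediction head per future step, stacked along a new time axis
        outputs = [layer(x).unsqueeze(1) for layer in self.fc1_layers]
        return torch.cat(outputs, dim=1)

With the plain Python list, cpc_model.to(device) never saw those Linear layers, so their weights stayed on the CPU while the inputs were on CUDA - exactly the mix that torch.addmm complains about. Registering them in an nn.ModuleList also means their weights appear in cpc_model.parameters(), so the optimizer actually trains them.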

Thanks @alex.veuthey, that is the perfect solution. It resolved the issue. Thanks again for your time.
