Getting TypeError mismatch for CPU and GPU despite having pushed model and data to GPU

Hello. I’m trying to run my model with some data and am getting the following error:

TypeError: expected Variable[CPUType] (got torch.cuda.FloatTensor)

I’ve checked some of the answers here and it seemed that I hadn’t pushed my model onto the device yet.

However, I checked the code and I have in fact done that, and I even explicitly pushed it onto the device in the Python Debugger interactive shell and am still getting the same error.

The code is as follows:


### Module: main.py
def main():
    config = get_args()
    dataset = Data(config)
    model = GCN(config, dataset.num_features, dataset.num_classes)
    trainer = Trainer(config, model, dataset)

    if torch.cuda.is_available():
        model = model.to('cuda') # I've double checked that torch.cuda.is_available() returns True.

    trainer.train()


### Module: solver.py
import torch.nn as nn

class Trainer():
    def __init__(self, config, model, dataset):
        self.config = config
        self.num_epochs = self.config.num_epochs
        self.model = model
        self.dataset = dataset

        self.features = self.dataset.features
        self.adj_hat = self.dataset.adj_hat

    def train(self):
        self.model.train()

        optimizer = get_optimizer(self.config, self.model)
        loss_train = nn.NLLLoss()

        for epoch in range(self.num_epochs):
            optimizer.zero_grad()
            output = self.model(self.features, self.adj_hat)


### Module: models.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphConv(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight_mat = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x, adj_mat):
        weight_prod = torch.DoubleTensor(self.weight_mat(x))
        output = torch.matmul(adj_mat, weight_prod)

        return output


class GCN(nn.Module):
    def __init__(self, config, num_features, num_classes):
        super().__init__()
        self.config = config
        self.num_hidden = self.config.num_hidden
        self.num_classes = num_classes
        self.num_features = num_features
        self.p = self.config.dropout_rate

        self.graphconv1 = GraphConv(in_features=self.num_features, out_features=self.num_hidden)
        self.graphconv2 = GraphConv(in_features=self.num_hidden, out_features=self.num_classes)

    def forward(self, x, adj_hat):
        x = F.relu(self.graphconv1(x, adj_hat))
        x = F.dropout(input=x, p=self.p, training=self.training)
        output = F.softmax(self.graphconv2(x, adj_hat), dim=1)

        return output

The specific line that triggers the error is x = F.relu(self.graphconv1(x, adj_hat)) inside the GCN model. I don’t understand why: if I put self.model on the device, shouldn’t that take care of this issue?

Thanks in advance!
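One detail that is easy to miss when moving things to the GPU: nn.Module.to() modifies the module in place (and returns it), while Tensor.to() returns a new tensor and leaves the original untouched. So if input tensors such as features and adj_hat are moved, the result has to be assigned back. The sketch below checks this with a dtype change so that it also runs on a CPU-only machine; all variable names are illustrative.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)
same_model = model.to(torch.float64)
# nn.Module.to() converts the parameters in place and returns the same object.
print(same_model is model, model.weight.dtype)

original = torch.randn(4, 3)
converted = original.to(torch.float64)
# Tensor.to() returns a new tensor; `original` keeps its dtype and device.
print(converted is original, original.dtype, converted.dtype)

# Hence tensors must be reassigned after .to(), while modules need not be.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
features = original.to(device=device, dtype=torch.float64)
out = model(features)
print(out.device)
```

This is also why constructing the Trainer before calling model.to('cuda') is harmless: the Trainer holds a reference to the same module object.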

You are creating new tensors on the CPU in this line of code:

weight_prod = torch.DoubleTensor(self.weight_mat(x))

If you want to change the data type, use out = out.double() instead and make sure it’s the right type for all further calculations.
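A minimal sketch of the suggested change, assuming the goal was a float64 result. The cast is applied with .double() to the output of self.weight_mat, so the result stays on whatever device x is on and keeps its autograd history. Because torch.matmul needs both operands to share a dtype, adj_mat is cast as well; the shapes are made up for illustration.

```python
import torch
import torch.nn as nn


class GraphConv(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight_mat = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x, adj_mat):
        # .double() returns a float64 copy on the same device as its input
        # and keeps the result attached to the computation graph.
        weight_prod = self.weight_mat(x).double()
        # Both matmul operands must have the same dtype, so cast adj_mat too.
        output = torch.matmul(adj_mat.double(), weight_prod)
        return output


device = "cuda" if torch.cuda.is_available() else "cpu"
layer = GraphConv(4, 2).to(device)
x = torch.randn(5, 4, device=device)  # 5 nodes with 4 features (illustrative)
adj = torch.eye(5, device=device)     # illustrative adjacency matrix
out = layer(x, adj)
print(out.dtype, out.device, out.grad_fn is not None)
```

If float64 is needed throughout the GCN, calling model.double() once and passing float64 inputs is probably simpler, since the second nn.Linear would otherwise receive float64 input while its own weights are still float32.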


@ptrblck Hi, ptrblck. Could you please explain this in more detail? The error says “expected Variable[CPUType]”, while your comment says the code is “creating new tensors on the CPU”. In addition, I don’t understand why we need to modify the code after calling model = model.to('cuda').

The error message tries to be helpful by stating which type is expected and which type was found.
However, this expectation can be a guess when both types would be valid for the operation. If I’m not mistaken, the error message reports the parameter’s type as the expected one and the input’s type as the wrong one (as happened here), since that mismatch is the more common case.

In this case, each forward pass creates new tensors in the forward method without using a device-agnostic approach:

    def forward(self, x, adj_mat):
        weight_prod = torch.DoubleTensor(self.weight_mat(x))

Since no device was specified, the new weight_prod tensors will be created on the CPU by default.
Note that this approach would also detach weight_prod from the computation graph, so I would stick to my suggestion of using .double() if the dtype should be changed.
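To illustrate the detaching point, the sketch below compares .double() with building a fresh tensor from the same values. torch.tensor(y.tolist(), ...) is used here only as a stand-in for any constructor that copies data into a new tensor; it is not the exact call from the question. Only the cast lets gradients reach the nn.Linear weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(3, 2)
x = torch.randn(4, 3)
y = lin(x)  # float32, attached to the graph through lin.weight and lin.bias

# Casting keeps the autograd history.
cast = y.double()

# Copying the values into a new tensor creates a leaf with no history.
rebuilt = torch.tensor(y.tolist(), dtype=torch.float64)

print("cast has history:", cast.grad_fn is not None)
print("rebuilt has history:", rebuilt.grad_fn is not None)

cast.sum().backward()
print("weight received a gradient:", lin.weight.grad is not None)
```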

Thank you for your help. Based on what I’ve read and your comments, if I’m not mistaken, the error is caused by moving tensors from the CPU to the GPU (model = model.to('cuda')). After this operation, self.weight_mat(x) is on the GPU, but torch.DoubleTensor(self.weight_mat(x)) can only be executed on the CPU.

However, I don’t fully understand your solution. Since the current error comes from weight_prod = torch.DoubleTensor(self.weight_mat(x)), why do you suggest an operation on a different variable (out = out.double())?

How about just using

weight_prod = torch.nn.Parameter(self.weight_mat(x))

This operation explicitly creates a CPU tensor, so the issue is not a limitation of where the operation can be executed, but of what it is used for.

What the user tried to do seems to be a conversion of the float32 tensor to a float64 tensor (you could ask him to double-check).
This is easily done via tensor.double(), which neither changes the device (so there is no device mismatch) nor detaches the tensor.
Your approach creates a new parameter (which won’t be optimized, as it’s depending on the input x, is recreated in each iteration, and is thus unknown to the optimizer), which will also detach the operation from the computation graph. It also won’t change the dtype to float64, which seems to be the original use case.
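A small sketch of what happens when nn.Parameter is created inside forward, using a hypothetical WrapInParameter module. Wrapping the layer output in nn.Parameter produces a new leaf tensor: it is not registered in the module, it keeps the float32 dtype, and no gradient flows back into self.weight_mat.

```python
import torch
import torch.nn as nn


class WrapInParameter(nn.Module):
    """Hypothetical module reproducing the nn.Parameter-in-forward idea."""

    def __init__(self):
        super().__init__()
        self.weight_mat = nn.Linear(3, 2)

    def forward(self, x):
        # nn.Parameter detaches its input: the result is a brand-new leaf
        # that lives only in this local variable, not in the module.
        weight_prod = nn.Parameter(self.weight_mat(x))
        return weight_prod


model = WrapInParameter()
x = torch.randn(4, 3)
out = model(x)

print(out.is_leaf, out.grad_fn is None, out.dtype)
print(len(list(model.parameters())))  # only weight_mat.weight and weight_mat.bias
out.sum().backward()
print(model.weight_mat.weight.grad is None)  # the layer received no gradient
```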


Thank you for your reply. If I have to create a tensor in the forward function, how can I avoid a conflict between CPU and GPU tensors after calling model.cuda()?

In this thread, apaszke recommends using nn.Parameter. However, in the discussion above, you said

creates a new parameter (which won’t be optimized, as it’s depending on the input x, is recreated in each iteration, and is thus unknown to the optimizer), which will also detach the operation from the computation graph

As a PyTorch beginner, I find this somewhat confusing. I understand your comments on the effects of using nn.Parameter in the forward function (the optimizer cannot update it), so what is the right way to create a tensor in the forward function that depends on both the weight parameter and the output at the same time?

Thank you!

If you want to create a trainable parameter, or a tensor that should be registered in the module during its initialization, you could use an nn.Parameter or register the tensor via self.register_buffer in the __init__ method.
This will make sure to transfer the tensors to the appropriate device when model.to() is called.
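A minimal sketch of the __init__-time approach, using a hypothetical ScaledLinear module. The nn.Parameter attribute and the register_buffer tensor are both part of the module's state, so a single model.to(device) call (or model.double()) converts them together with the nn.Linear weights. The module name and shapes are illustrative only.

```python
import torch
import torch.nn as nn


class ScaledLinear(nn.Module):
    """Hypothetical module: one trainable parameter and one fixed buffer."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Trainable: registered automatically because it is an nn.Parameter attribute.
        self.scale = nn.Parameter(torch.ones(out_features))
        # Not trainable, but still part of the module state (state_dict, .to(), ...).
        self.register_buffer("offset", torch.zeros(out_features))

    def forward(self, x):
        return self.linear(x) * self.scale + self.offset


device = "cuda" if torch.cuda.is_available() else "cpu"
model = ScaledLinear(3, 2).to(device)
# Both the parameter and the buffer followed the model to `device`.
print(model.scale.device, model.offset.device)
print([name for name, _ in model.named_parameters()])
print([name for name, _ in model.named_buffers()])
```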
However, if you want to create a new tensor in the forward method (which is different from the original question), you could reuse the .device attribute of a known parameter or the input:

def forward(self, x):
    my_new_tensor = torch.randn(1, device=x.device)
    x = x + my_new_tensor
    return x
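Regarding the earlier question about a tensor that depends on both a weight parameter and the input, a sketch with a hypothetical GatedLayer module: such a tensor usually needs no special handling when it is computed from existing tensors, because operations on them produce results on the same device and keep the graph intact, so the parameter still receives gradients. Only genuinely new tensors, such as the random noise below, need their device (and dtype) taken from the input.

```python
import torch
import torch.nn as nn


class GatedLayer(nn.Module):
    """Hypothetical module that builds intermediate tensors inside forward."""

    def __init__(self, num_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_features))

    def forward(self, x):
        # Derived from existing tensors: lands on x's device automatically
        # and stays attached to the graph, so self.weight gets gradients.
        gate = torch.sigmoid(x * self.weight)
        # Brand-new tensor: take device and dtype from the input.
        noise = torch.randn(x.shape, device=x.device, dtype=x.dtype)
        # torch.zeros_like(x) or x.new_ones(...) would inherit them as well.
        return gate * x + 0.01 * noise


layer = GatedLayer(4)
x = torch.randn(2, 4)
out = layer(x)
out.sum().backward()
print(layer.weight.grad is not None)  # gradients reached the parameter
```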

I’m grateful for your help. Now I think I should avoid using nn.Parameter in the forward function and use your method instead.

Thanks again :tada: :tada: :tada: