Cannot convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first

Hey, I am getting TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

I looked through the forum but could not resolve this.

Code:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTNet(nn.Module):
    
    def __init__(self):
        super(LSTNet, self).__init__()
        self.num_features = torch.tensor(5).cuda()
        self.conv1_out_channels = torch.tensor(32).cuda() 
        self.conv1_kernel_height = torch.tensor(7).cuda()
        self.recc1_out_channels = torch.tensor(64).cuda()
        self.skip_steps = torch.stack([torch.tensor(4), torch.tensor(24)], dim=0).cuda()
        self.skip_reccs_out_channels = torch.stack([torch.tensor(4), torch.tensor(4)], dim=0).cuda()
        self.prediction_len = torch.tensor(30).cuda() # Future prediction values
        self.output_out_features = torch.tensor(5*self.prediction_len).cuda()
        self.ar_window_size = torch.tensor(7).cuda()
        self.dropout = nn.Dropout(p = 0.2)
       
        
        self.conv1 = nn.Conv2d(1, self.conv1_out_channels, 
                               kernel_size=(self.conv1_kernel_height, self.num_features))
        self.recc1 = nn.GRU(self.conv1_out_channels, self.recc1_out_channels, batch_first=True)
        self.skip_reccs = {}
        for i in range(len(self.skip_steps)):
            self.skip_reccs[i] = nn.GRU(self.conv1_out_channels, self.skip_reccs_out_channels[i], batch_first=True)
        self.output_in_features = self.recc1_out_channels + np.dot(self.skip_steps, self.skip_reccs_out_channels)
        self.output = nn.Linear(self.output_in_features, self.output_out_features) # prediction
        if self.ar_window_size > 0:
            self.ar = nn.Linear(self.ar_window_size, 1)
        
    def forward(self, X):
        """
        Parameters:
        X (tensor) [batch_size, time_steps, num_features]
        """
        batch_size = X.size(0)
        
        # Convolutional Layer
        C = X.unsqueeze(1) # [batch_size, num_channels=1, time_steps, num_features]
        C = F.relu(self.conv1(C)) # [batch_size, conv1_out_channels, shrinked_time_steps, 1]
        C = self.dropout(C)
        C = torch.squeeze(C, 3) # [batch_size, conv1_out_channels, shrinked_time_steps]
        
        # Recurrent Layer
        R = C.permute(0, 2, 1) # [batch_size, shrinked_time_steps, conv1_out_channels]
        out, hidden = self.recc1(R) # [batch_size, shrinked_time_steps, recc_out_channels]
        R = out[:, -1, :] # [batch_size, recc_out_channels]
        R = self.dropout(R)
        #print(R.shape)
        
        # Skip Recurrent Layers
        shrinked_time_steps = C.size(2)
        for i in range(len(self.skip_steps)):
            skip_step = self.skip_steps[i]
            skip_sequence_len = shrinked_time_steps // skip_step
            # shrinked_time_steps shrinked further
            S = C[:, :, -skip_sequence_len*skip_step:] # [batch_size, conv1_out_channels, shrinked_time_steps]
            S = S.view(S.size(0), S.size(1), skip_sequence_len, skip_step) # [batch_size, conv1_out_channels, skip_sequence_len, skip_step=num_skip_components]
            # note that num_skip_components = skip_step
            S = S.permute(0, 3, 2, 1).contiguous() # [batch_size, skip_step=num_skip_components, skip_sequence_len, conv1_out_channels]
            S = S.view(S.size(0)*S.size(1), S.size(2), S.size(3))  # [batch_size*num_skip_components, skip_sequence_len, conv1_out_channels]
            out, hidden = self.skip_reccs[i](S) # [batch_size*num_skip_components, skip_sequence_len, skip_reccs_out_channels[i]]
            S = out[:, -1, :] # [batch_size*num_skip_components, skip_reccs_out_channels[i]]
            S = S.view(batch_size, skip_step*S.size(1)) # [batch_size, num_skip_components*skip_reccs_out_channels[i]]
            S = self.dropout(S)
            R = torch.cat((R, S), 1) # [batch_size, recc_out_channels + skip_reccs_out_channels * num_skip_components]
            #print(S.shape)
        #print(R.shape)
        
        # Output Layer
        O = F.relu(self.output(R)) # [batch_size, output_out_features=1]
        
        if self.ar_window_size > 0:
            # set dim3 based on output_out_features
            AR = X[:, -self.ar_window_size:, 3:4] # [batch_size, ar_window_size, output_out_features=1]
            AR = AR.permute(0, 2, 1).contiguous() # [batch_size, output_out_features, ar_window_size]
            AR = self.ar(AR) # [batch_size, output_out_features, 1]
            AR = AR.squeeze(2) # [batch_size, output_out_features]
            O = O + AR
        
        return O

If I keep self.skip_steps and self.skip_reccs_out_channels on the CPU, the error is gone, but then I get this error during training:

RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at cpu

Did you try to call .cpu() on the tensor first, as suggested in the error message?
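For reference, the pattern the error message asks for looks roughly like this minimal sketch. The tensor `t` and its values are made up for illustration, and the code falls back to the CPU when no GPU is present. `detach()` is included because `numpy()` also refuses tensors that require grad:

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t = torch.tensor([4.0, 24.0], device=device, requires_grad=True)  # illustrative tensor

# t.numpy() raises for a CUDA tensor (and for any tensor that requires grad),
# so detach it from the autograd graph and copy it to host memory first.
arr = t.detach().cpu().numpy()
print(type(arr), arr.dtype)
```

For a CUDA tensor, `.cpu()` makes a host copy, so later edits to `arr` do not affect `t`; for a tensor already on the CPU, `.cpu()` returns the same tensor and `arr` shares memory with it.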

Yes, I tried .cpu() on self.skip_steps and self.skip_reccs_out_channels.
As I stated, I then get this error during training:

RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at CPU

I also called model.cuda(). See here:

for n, p in model.named_parameters():
    print(p.device, '', n)

cuda:0 conv1.weight
cuda:0 conv1.bias
cuda:0 recc1.weight_ih_l0
cuda:0 recc1.weight_hh_l0
cuda:0 recc1.bias_ih_l0
cuda:0 recc1.bias_hh_l0
cuda:0 output.weight
cuda:0 output.bias
cuda:0 ar.weight
cuda:0 ar.bias
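Notice that none of the `skip_reccs` GRUs appear in this listing. A minimal sketch that reproduces just that symptom (the class `PlainDictNet` and its layer sizes are hypothetical, chosen only to mirror how `LSTNet` stores `skip_reccs`):

```python
import torch.nn as nn


class PlainDictNet(nn.Module):
    """Hypothetical toy model: one registered GRU plus GRUs hidden in a plain dict."""

    def __init__(self):
        super().__init__()
        self.recc1 = nn.GRU(4, 8, batch_first=True)  # registered: assigned directly as a module attribute
        self.skip_reccs = {}  # plain dict: nn.Module does not look inside it
        for i in range(2):
            self.skip_reccs[i] = nn.GRU(4, 2, batch_first=True)


names = [n for n, _ in PlainDictNet().named_parameters()]
print(names)
# ['recc1.weight_ih_l0', 'recc1.weight_hh_l0', 'recc1.bias_ih_l0', 'recc1.bias_hh_l0']
```

Because `model.to(device)` and `model.parameters()` only walk registered submodules, GRUs stored this way stay on the CPU and are never handed to the optimizer.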

I don’t fully understand where exactly you are calling .numpy() on a CUDA tensor, so could you explain your use case a bit more and post a minimal, executable code snippet that reproduces the issue, please?

Here is what I run after defining the model class (using .cpu() for self.skip_steps and self.skip_reccs_out_channels):

train_dataset = MarketDataset(train_data, history_len=history_len, prediction_len = prediction_len)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

model = LSTNet()
model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
train_loss_list = []
model_path = '/content/drive/MyDrive/azid/LSTNet/'
for epoch in tqdm(range(epochs)):
    
    epoch_loss_train = 0
    for i, batch in tqdm(enumerate(train_data_loader, start=1), 
                         leave=False, desc="Train", total=len(train_data_loader)):
            
        X, Y = batch
        X = X.to(device)
        Y = Y.to(device)
        optimizer.zero_grad()
        Y_pred = model(X)
        Y_pred = Y_pred.view(Y.shape)
        loss = criterion(Y_pred, Y)
        loss.backward()
        optimizer.step()

        with open('/content/drive/MyDrive/azid/LSTNet/Log/Running-Loss.txt', 'a+') as file:
            file.write(f'{loss.item()}\n')
        epoch_loss_train += loss.item()
        
    epoch_loss_train = epoch_loss_train / len(train_data_loader)
    train_loss_list.append(epoch_loss_train)
    
    with open('/content/drive/MyDrive/azid/LSTNet/Log/Epoch-Loss.txt', 'a+') as file:
        file.write(f'{epoch_loss_train}\n')
    
    # Save model at each epoch
    torch.save(model.state_dict(), f'{model_path}-{epoch}.pth')

    metrics = {"train/train_loss": loss}
    # wandb
    wandb.log(metrics)

Error:

/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:52: UserWarning:

__rfloordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-67-be5135f098c0> in <module>()
     11         Y = Y.to(device)
     12         optimizer.zero_grad()
---> 13         Y_pred = model(X)
     14         Y_pred = Y_pred.view(Y.shape)
     15         loss = criterion(Y_pred, Y)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    941         if batch_sizes is None:
    942             result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
--> 943                              self.dropout, self.training, self.bidirectional, self.batch_first)
    944         else:
    945             result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,

RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at CPU
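The `UserWarning` printed before the traceback most likely comes from line 52 of the cell, `skip_sequence_len = shrinked_time_steps // skip_step`: `shrinked_time_steps` is a Python int and `skip_step` is a tensor, so Python falls back to `Tensor.__rfloordiv__`. If a tensor divisor is really needed, the warning's own suggestion is to state the rounding mode explicitly. A small sketch with made-up values showing how the two modes differ for negative numbers:

```python
import torch

a = torch.tensor([7, -7])  # illustrative values only

trunc = torch.div(a, 2, rounding_mode="trunc")  # rounds toward zero (the behaviour the warning describes)
floor = torch.div(a, 2, rounding_mode="floor")  # rounds toward negative infinity (true floor division)
print(trunc.tolist(), floor.tolist())  # [3, -3] [3, -4]
```

For the non-negative sizes used in `LSTNet` both modes give the same result, so the warning does not change the outcome here; using plain Python ints for `skip_steps` avoids it altogether.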

I’ll try to get some shareable data for reproducibility.

The NumPy operation is on this line:

self.output_in_features = self.recc1_out_channels + np.dot(self.skip_steps, self.skip_reccs_out_channels)

I used .cpu() to resolve that, but how do I resolve this error now?

RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at CPU
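An alternative to converting with `.cpu()`: none of these hyperparameters needs to be a tensor. `nn.Conv2d`, `nn.GRU` and `nn.Linear` take plain Python ints for their sizes, so keeping them as ints sidesteps both the NumPy conversion and any device placement. A minimal sketch of the `output_in_features` computation from the `np.dot` line above, using the values from the question (the names mirror the attributes in `LSTNet.__init__`, but this is a standalone illustration, not the author's code):

```python
# Hyperparameters as plain Python ints/lists (values copied from LSTNet.__init__)
recc1_out_channels = 64
skip_steps = [4, 24]
skip_reccs_out_channels = [4, 4]
prediction_len = 30
output_out_features = 5 * prediction_len

# Same result as recc1_out_channels + np.dot(skip_steps, skip_reccs_out_channels),
# computed without NumPy and without touching any device
output_in_features = recc1_out_channels + sum(
    step * ch for step, ch in zip(skip_steps, skip_reccs_out_channels)
)
print(output_in_features, output_out_features)  # 176 150
```

These ints can be passed straight to `nn.Linear(output_in_features, output_out_features)`, and with `skip_step` an int, `shrinked_time_steps // skip_step` in `forward` becomes ordinary integer division.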

OK, I resolved it by moving each layer in __init__ to the GPU, adding .cuda() to each one individually.

Good to hear it’s resolved, but note that calling .cuda() on each layer in __init__ is a workaround, not a proper fix.
Currently, self.skip_reccs is created as a plain Python dict and will thus not register the modules properly (which also explains the device mismatch you were seeing, since these layers were never pushed to the GPU). Besides the device mismatch, model.parameters() will also not return the parameters of these layers, so your optimizer will never train them.
Use nn.ModuleDict instead to fix these issues properly and remove the cuda() calls in the __init__.
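A minimal sketch of that fix, assuming the rest of `LSTNet` stays as posted. One detail: `nn.ModuleDict` only accepts string keys, so the integer keys from the original dict become `str(i)` and the lookup in `forward` has to change to match (alternatively, `nn.ModuleList` keeps integer indexing). The class name `SkipGRUs` and its default sizes are hypothetical, chosen to match `conv1_out_channels=32` and `skip_reccs_out_channels=[4, 4]`:

```python
import torch
import torch.nn as nn


class SkipGRUs(nn.Module):
    """Hypothetical excerpt: only the skip-recurrent part of LSTNet."""

    def __init__(self, in_channels=32, skip_out_channels=(4, 4)):
        super().__init__()
        # nn.ModuleDict registers every GRU as a submodule, so .to(device),
        # .parameters() and state_dict() all see them. Keys must be strings.
        self.skip_reccs = nn.ModuleDict({
            str(i): nn.GRU(in_channels, out_ch, batch_first=True)
            for i, out_ch in enumerate(skip_out_channels)
        })

    def forward(self, S, i):
        # S: [batch, seq_len, in_channels]; return the output at the last time step
        out, _ = self.skip_reccs[str(i)](S)
        return out[:, -1, :]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkipGRUs().to(device)  # no per-layer .cuda() calls needed

print([n for n, _ in model.named_parameters()][:4])
# ['skip_reccs.0.weight_ih_l0', 'skip_reccs.0.weight_hh_l0', 'skip_reccs.0.bias_ih_l0', 'skip_reccs.0.bias_hh_l0']
S = torch.randn(3, 5, 32, device=device)
print(model(S, 1).shape)  # torch.Size([3, 4])
```

With the GRUs registered this way, `optim.Adam(model.parameters(), ...)` also receives their weights, and `torch.save(model.state_dict(), ...)` includes them.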

Thanks a lot. This was very useful.