LSTM "aten::mkldnn_rnn_layer" not implemented for CUDA backend

Hello everyone,

I am using the LSTM class from PyTorch to process some features from a CNN. When I run it on the CPU there is no error and the results are good, but when I run it on the GPU I get the following error:

NotImplementedError: Could not run 'aten::mkldnn_rnn_layer' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::mkldnn_rnn_layer' is only available for these backends: [CPU, Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].

CPU: registered at aten/src/ATen/RegisterCPU.cpp:31034 [kernel]
Meta: registered at /dev/null:241 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:491 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:280 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:63 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradHIP: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradMPS: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradIPU: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradVE: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradMeta: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradMTIA: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_2.cpp:17472 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_2.cpp:16726 [kernel]
AutocastCPU: registered at ../aten/src/ATen/autocast_mode.cpp:492 [kernel]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:354 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:815 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1073 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:210 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:152 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:487 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]

Here is my model:

class RecurrentCNN(nn.Module):

    """
    Adaptation of the recurrent neural network with 2 LSTM blocks presented in multiple papers.
    - "Camera Configuration Models for Machine Vision Based Force Estimation in Robot-Assisted Soft Body Manipulation" by Wenjun Liu et al. (doi: https://doi.org/10.1109/ISMR48347.2022.9807587)
    - "A recurrent convolutional neural network approach for sensorless force estimation in robotic surgery" by Arturo Marban et al. (doi: https://doi.org/10.1016/j.bspc.2019.01.011)
    """
    def __init__(self, num_layers: int = 18, pretrained: bool = True, include_depth: bool = True, 
                 att_type: str = None, embed_dim: int = 512, hidden_size: int = 12, num_blocks: int = 2,
                 include_rs: bool = True):
        
        super(RecurrentCNN, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.num_blocks = num_blocks
        final_ch = 512 if num_layers == 18 else 2048

        self.encoder = ResnetEncoder(num_layers=num_layers, pretrained=pretrained, include_depth=include_depth, att_type=att_type)
        self.linear = nn.Linear(final_ch * 8 * 8, embed_dim)
        self.lstm1 = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim, num_layers=num_blocks, batch_first=True, dropout=0.)
        self.lstm2 = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, num_layers=num_blocks, batch_first=True, dropout=0.)
        self.fc = nn.Linear(hidden_size, 3)
    
    def forward(self, imgs: torch.Tensor, robot_state: torch.Tensor = None) -> torch.Tensor:
        batch_size = imgs[0].shape[0]
        rec_size = len(imgs)

        x = torch.zeros(batch_size, rec_size, self.embed_dim) 

        for i in range(batch_size):
            inp = torch.cat([img[i].unsqueeze(0) for img in imgs], dim=0)
            out = self.encoder(inp)
            out = out.view(rec_size, -1)
            x[i] = self.linear(out)

        if robot_state is not None:
            rs_size = robot_state.shape[-1]
            padding_dim = (512 - rs_size - 1)
            padded_state = F.pad(robot_state, (1, padding_dim), 'constant', 0)
            x = torch.cat([x, padded_state], dim=1)

        x = x.reshape(batch_size, -1, self.embed_dim) # reshape the input in case there is a mismatch

        # recurrent part
        h_0 = torch.autograd.Variable(torch.randn(self.num_blocks, batch_size, self.embed_dim).cuda())
        c_0 = torch.autograd.Variable(torch.randn(self.num_blocks, batch_size, self.embed_dim).cuda())
        x, (h_n, c_n) = self.lstm1(x, (h_0, c_0))
        x, _ = self.lstm2(x, (h_n, c_n))
        x = x[:, -1, :]
        pred = self.fc(x)

        return pred

The ResNet encoder is a custom ResNet-18 or ResNet-50 built on top of torchvision.models. Any idea where the error is coming from? Thank you for your help.

Based on the error message, it seems you are trying to run mkldnn_rnn_layer, which uses the MKLDNN (oneDNN) CPU backend, on a GPU.
Could you explain whether and how you are forcing the use of MKLDNN? Did you transform the inputs via to_mkldnn() or use a backend-specific context manager?

I am not sure about that layer; I am only using the model from my original message. Is mkldnn_rnn_layer used inside the standard PyTorch LSTM module? I just followed the PyTorch documentation to define the LSTM blocks.
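On whether nn.LSTM can reach mkldnn_rnn_layer: judging from the traceback, and from the reports later in this thread that torch 1.13.1 does not show the error, PyTorch 2.x appears to route float32 CPU LSTM calls through the MKLDNN (oneDNN) kernel whenever MKLDNN is available and enabled. The sketch below, assuming a PyTorch build that exposes torch.backends.mkldnn, shows how to inspect that switch and temporarily turn it off with the torch.backends.mkldnn.flags context manager; on a CPU-only run both calls should agree up to floating-point tolerance.

```python
import torch
import torch.nn as nn

# Is this build compiled with MKLDNN (oneDNN), and is it currently switched on?
print("MKLDNN available:", torch.backends.mkldnn.is_available())
print("MKLDNN enabled:", torch.backends.mkldnn.enabled)

torch.manual_seed(0)
lstm = nn.LSTM(input_size=16, hidden_size=8, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 16)  # float32 CPU input, i.e. eligible for the MKLDNN path

with torch.no_grad():
    out_default, _ = lstm(x)  # may go through aten::mkldnn_rnn_layer on PyTorch 2.x
    # Temporarily disable MKLDNN so the same call falls back to the native CPU kernels.
    with torch.backends.mkldnn.flags(enabled=False):
        out_native, _ = lstm(x)

print("max abs difference:", (out_default - out_native).abs().max().item())
```

Disabling MKLDNN only changes which CPU kernel is selected; it does not make a CPU input work with CUDA weights, so the actual fix is still to keep the input, the hidden states and the LSTM weights on one device.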

Could you post a minimal and executable code snippet to reproduce the issue, please?

Here is a small snippet that runs the model from my first message:

device = torch.device("cuda")
model = RecurrentCNN(num_layers=50, pretrained=False, include_depth=True,
                     att_type=None, embed_dim=512, hidden_size=12, num_blocks=2)

model.to(device)
imgs = [torch.randn(2, 4, 256, 256).to(device) for _ in range(5)]

out = model(imgs)

print(out.shape)

ResnetEncoder is undefined. Make sure you can copy/paste the code and it runs in a new terminal in your setup.

I have debugged the code a bit, and the problem seems to come from the output of my ResNet encoder. I am not sure why, but when I run the LSTM block alone on the GPU there is no error, whereas running the whole encoder+LSTM pipeline raises the error.

This is the ResNet encoder I am using, so you can try it on your side:

import torch
import torch.nn as nn
from models.utils import FcBlock
import numpy as np
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
from models.bam import BAM
import torch.nn.functional as F
from typing import List

class ResNetMultiImageInput(models.ResNet):

    def __init__(self, block, layers, num_classes=1000, num_input_images=1,
                 input_channel=3, att_type=None):
        super(ResNetMultiImageInput, self).__init__(block, layers)
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            num_input_images * input_channel, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if att_type=='BAM':
            self.bam1 = BAM(64*block.expansion)
            self.bam2 =  None # BAM(128*block.expansion)
            self.bam3 = None # BAM(256*block.expansion)
        else:
            self.bam1, self.bam2, self.bam3 = None, None, None

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1,
                            include_depth=True, att_type=None):
    """Constructs a ResNet model.

    Args:
        num_layers (int): Number of resnet layers. Must be 18 or 50
        pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Defaults to False.
        num_input_images (int, optional): Number of frames stacked as input. Defaults to 1.
    """
    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images, 
                                  input_channel=4 if include_depth else 3, att_type=att_type)

    if pretrained:
        loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
        loaded['conv1.weight'] = torch.cat(
            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
        model.load_state_dict(loaded)
    return model


class ResnetEncoder(nn.Module):
    """Pytorch module for a resnet encoder
    """
    def __init__(self, num_layers, pretrained, num_input_images=1, include_depth=True, att_type=None):
        super(ResnetEncoder, self).__init__()

        self.att_type = att_type

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        resnets = {
            18: models.resnet18,
            34: models.resnet34,
            50: models.resnet50,
            101: models.resnet101,
            152: models.resnet152
        }

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
        
        if include_depth or att_type is not None:
            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images, include_depth, att_type)
        else:
            self.encoder = resnets[num_layers](pretrained)
        
        if num_layers > 34:
            self.num_ch_enc[1:] *= 4
        
    def forward(self, input_image):
        x = input_image

        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        x = self.encoder.relu(x)
        x = self.encoder.maxpool(x)
        x = self.encoder.layer1(x)

        if self.att_type is not None:
            if not self.encoder.bam1 is None:
                x = self.encoder.bam1(x)

        x = self.encoder.layer2(x)
        if self.att_type is not None:
            if not self.encoder.bam2 is None:
                x = self.encoder.bam2(x)

        x = self.encoder.layer3(x)
        if self.att_type is not None:
            if not self.encoder.bam3 is None:
                x = self.encoder.bam3(x)

        x = self.encoder.layer4(x)

        return x
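To narrow down where the pipeline breaks, as described above (the LSTM alone runs on the GPU, encoder+LSTM does not), one option is to attach forward pre-hooks that compare every submodule's inputs with the device of its own parameters. The helper name add_device_checks is hypothetical; the sketch only relies on the standard nn.Module.named_modules, parameters(recurse=False) and register_forward_pre_hook APIs.

```python
import torch
import torch.nn as nn


def add_device_checks(model: nn.Module) -> list:
    """Hypothetical helper: attach forward pre-hooks that raise as soon as a
    submodule receives a tensor on a different device than its own parameters."""
    handles = []
    for name, module in model.named_modules():
        # Only check modules that directly own parameters (Linear, Conv2d, LSTM, ...).
        if next(module.parameters(recurse=False), None) is None:
            continue

        def check(mod, args, name=name):
            # Look up the device at call time, so .to(device) after registration is respected.
            param_device = next(mod.parameters(recurse=False)).device
            label = name or type(mod).__name__
            for arg in args:
                # nn.LSTM may also receive an (h_0, c_0) tuple as a second argument.
                tensors = arg if isinstance(arg, (tuple, list)) else (arg,)
                for t in tensors:
                    if isinstance(t, torch.Tensor) and t.device != param_device:
                        raise RuntimeError(
                            f"{label}: input on {t.device}, parameters on {param_device}"
                        )

        handles.append(module.register_forward_pre_hook(check))
    return handles  # call handle.remove() on each entry to detach the checks
```

Calling add_device_checks(model) before model(imgs) on the RecurrentCNN above should report the first submodule whose input is on the wrong device; given the forward pass as posted, that would be expected to be lstm1 receiving a CPU input.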

I had this error too: my code started raising it after upgrading to torch 2.

Not sure if this will help, but I downgraded to torch 1.13.1 and it works. Hope this helps someone.

This worked! Thank you.

I am also having the same issue when using LSTM modules in PyTorch 2.0. It gives me this error:

NotImplementedError: Could not run 'aten::mkldnn_rnn_layer' with arguments from the 'CUDA' backend.

Hi @amir_mir93, downgrading torch to version 1.13.1, as suggested by Saika, worked for me.

Has anyone found an option that does not involve downgrading? I am running on a shared resource and cannot downgrade.

A likely cause is that the model and the input are on different devices. Make sure that they are on the same device, and also that no tensors have been created with the legacy constructor torch.Tensor.
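In the model posted at the top of the thread, this points at forward: x = torch.zeros(batch_size, rec_size, self.embed_dim) is created on the CPU (the default device), and x[i] = self.linear(out) silently copies each CUDA result into it, so lstm1 receives a CPU input while its weights and h_0/c_0 are on CUDA. The minimal fix is to allocate on the input's device, e.g. torch.zeros(batch_size, rec_size, self.embed_dim, device=imgs[0].device), and to create h_0/c_0 with device=x.device instead of .cuda() (wrapping in torch.autograd.Variable is no longer needed). Separately, lstm2 has hidden_size=12, so passing it lstm1's (h_n, c_n), whose last dimension is embed_dim, would raise a hidden-size error once the device issue is fixed. The sketch below is one possible restructuring under stated assumptions, not the original author's code: RecurrentCNNSketch and its stand-in encoder are hypothetical, the robot_state branch is omitted, all frames are encoded in one batched call, and both LSTMs use their default zero initial states instead of random ones.

```python
import torch
import torch.nn as nn
from typing import List


class RecurrentCNNSketch(nn.Module):
    """Device-agnostic restructuring of RecurrentCNN.forward (robot_state branch omitted).

    `encoder` is a hypothetical stand-in for ResnetEncoder: any module mapping
    (N, C, H, W) images to (N, ...) feature maps.
    """

    def __init__(self, encoder: nn.Module, feat_dim: int, embed_dim: int = 512,
                 hidden_size: int = 12, num_blocks: int = 2):
        super().__init__()
        self.encoder = encoder
        self.linear = nn.Linear(feat_dim, embed_dim)
        self.lstm1 = nn.LSTM(embed_dim, embed_dim, num_layers=num_blocks, batch_first=True)
        self.lstm2 = nn.LSTM(embed_dim, hidden_size, num_layers=num_blocks, batch_first=True)
        self.fc = nn.Linear(hidden_size, 3)

    def forward(self, imgs: List[torch.Tensor]) -> torch.Tensor:
        # imgs: T tensors of shape (B, C, H, W), already on the model's device.
        seq = torch.stack(imgs, dim=1)                     # (B, T, C, H, W)
        b, t = seq.shape[:2]
        feats = self.encoder(seq.flatten(0, 1))            # all B*T frames in one call
        # Every intermediate is computed from the inputs, so nothing is allocated on the CPU.
        x = self.linear(feats.flatten(1)).view(b, t, -1)   # (B, T, embed_dim)

        # Explicit initial states, if wanted, must be created next to x, e.g.
        # h_0 = torch.zeros(self.lstm1.num_layers, b, self.lstm1.hidden_size,
        #                   device=x.device, dtype=x.dtype)
        x, _ = self.lstm1(x)  # no state passed: zeros on x's device
        # lstm2 has a different hidden size, so lstm1's (h_n, c_n) cannot be reused here.
        x, _ = self.lstm2(x)
        return self.fc(x[:, -1, :])                        # (B, 3)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Hypothetical tiny encoder for 4-channel (RGB-D) input: 8 channels x 2 x 2 = 32 features.
    enc = nn.Sequential(nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1), nn.AdaptiveAvgPool2d(2))
    model = RecurrentCNNSketch(enc, feat_dim=8 * 2 * 2).to(device)
    imgs = [torch.randn(2, 4, 32, 32, device=device) for _ in range(5)]
    print(model(imgs).shape)  # torch.Size([2, 3])
```

One behavioral difference from the per-sample loop: in training mode, BatchNorm layers inside a real ResNet encoder see B*T frames per call instead of T, so their batch statistics will differ slightly.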