How does nn.Linear work in C++ for multi-dimensional input? (torch._C._nn.linear)

Hi! I am wondering about the inner implementation, specifically how nn.Linear copes with an input like (2, 33, 44) as opposed to the usual (33, 44). I am developing a CUDA kernel for matmul right now.
So I looked into the PyTorch code and found this:

def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.

    Shape:

        - Input: :math:`(N, *, in\_features)` N is the batch size, `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if has_torch_function_variadic(input, weight, bias):
        return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
    return torch._C._nn.linear(input, weight, bias)

But I really cannot find where torch._C._nn.linear is defined!
Can anyone tell me where it lives? Thank you!!!
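
By the way, here is my current guess for how the multi-dimensional case is handled, checked only from Python and not from the C++ source: the extra leading dimensions seem to be folded into one big batch dimension, a plain 2-D matmul is done, and the result is reshaped back.

import torch
import torch.nn.functional as F

x = torch.randn(2, 33, 44)   # (*, in_features) with extra leading dimensions
w = torch.randn(10, 44)      # (out_features, in_features)
b = torch.randn(10)

# Fold the leading dimensions into one batch dimension, do a plain 2-D
# matmul, then restore the original leading shape.
manual = (x.reshape(-1, 44) @ w.T + b).reshape(2, 33, 10)

print(torch.allclose(F.linear(x, w, b), manual, atol=1e-6))  # expected: True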

This thread points to the CUDA implementations in case that’s helpful.

Thank you very much for your answer!!! But I still have some small questions:
In pytorch/Blas.cpp at c1c9be16c4d0648fc134d04f30c8463575df7ada · pytorch/pytorch · GitHub I cannot find a function named addmm or matmul, only something like addmm_out_cuda_impl. I know they should be closely related, but how are they connected?
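
For what it is worth, at the op level I can at least see the relationship from Python: for a 2-D input, F.linear computes the same thing as torch.addmm(bias, input, weight.T), so my guess (not verified in the source) is that the linear op reaches one of those addmm*_cuda_impl functions through the dispatcher.

import torch
import torch.nn.functional as F

x = torch.randn(33, 44)
w = torch.randn(10, 44)
b = torch.randn(10)

# addmm(bias, mat1, mat2) computes bias + mat1 @ mat2
print(torch.allclose(F.linear(x, w, b), torch.addmm(b, x, w.t())))  # expected: True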

Given that I am implementing a linear layer myself, I also wonder how I can use these functions in my own code, that is, can I really include and call them?
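
For example, from Python I can already call the dispatcher-level op directly through torch.ops.aten, which I assume reaches those CUDA implementations when the tensors live on the GPU (again just my assumption):

import torch

x = torch.randn(33, 44, device="cuda:0")
w = torch.randn(10, 44, device="cuda:0")
b = torch.randn(10, device="cuda:0")

# Call the dispatcher-level op directly instead of going through nn.Linear.
y = torch.ops.aten.addmm(b, x, w.t())
print(y.shape)  # torch.Size([33, 10])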

Thank you!!!

I am wondering about the case where a tensor has already been transposed, so it is not contiguous. In the code you linked, addmm first makes the tensor contiguous and then computes. So I compared two versions: 1. use the built-in nn.Linear directly, 2. call .contiguous() myself first, and measured the time of both. Strangely, the second version is much slower! Why??

I have attached the code for reproduction below:

# This is the built-in version
import torch

torch_time = 0
for i in range(50):
    torch.manual_seed(i)
    start = torch.cuda.Event(enable_timing=True)  # CUDA events for timing
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    a = torch.randn(12000, 9000, device="cuda:0")
    a1 = a.T  # transposed view, not contiguous
    li = torch.nn.Linear(12000, 8000, device="cuda:0")
    output = li(a1)

    end.record()
    torch.cuda.synchronize()
    tim = start.elapsed_time(end)

    if i >= 10:  # skip the first 10 iterations as warm-up
        torch_time += tim

print(torch_time / 40)
# 1352.7901000976562
# This is my contiguous version.
import torch

torch_time = 0
for i in range(50):
    torch.manual_seed(i)
    start = torch.cuda.Event(enable_timing=True)  # CUDA events for timing
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    a = torch.randn(12000, 9000, device="cuda:0")
    a2 = a.T  # transposed view, not contiguous
    a1 = a2.contiguous()  # explicit copy into contiguous memory
    li = torch.nn.Linear(12000, 8000, device="cuda:0")
    output = li(a1)

    end.record()
    torch.cuda.synchronize()
    tim = start.elapsed_time(end)

    if i >= 10:  # skip the first 10 iterations as warm-up
        torch_time += tim

print(torch_time / 40)
# 3949.7130920410154
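
As a side note on the measurement itself: the timed region above also includes torch.randn and the construction of the nn.Linear layer. Below is just a sketch of how I would isolate the two cases, with warm-up and the allocations kept outside the timed loop (the numbers above were produced with the code as posted, not with this version):

import torch

a = torch.randn(12000, 9000, device="cuda:0")
li = torch.nn.Linear(12000, 8000, device="cuda:0")

def bench(fn, iters=40, warmup=10):
    # Warm-up so one-time costs (cuBLAS handle creation, allocator growth)
    # are not counted in the measurement.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("built-in on transposed view:", bench(lambda: li(a.T)))
print("explicit .contiguous() first:", bench(lambda: li(a.T.contiguous())))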