# How does nn.Linear work in C++ for multi-dimensional input? (torch._C._nn.linear)

Hi! I am wondering about the inner implementation: how does nn.Linear cope with an input like (2, 33, 44), as opposed to a normal (33, 44)? I am developing a CUDA kernel for matmul at the moment.
So I went into my PyTorch install and found this:

``````
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.

    Shape:

        - Input: :math:`(N, *, in\_features)` where `N` is the batch size and
          `*` means any number of additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if has_torch_function_variadic(input, weight, bias):
        return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
    return torch._C._nn.linear(input, weight, bias)
``````

But I really cannot find where torch._C._nn.linear is implemented!
Can anyone tell me where it lives? Thank you!!!
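For context, the symbol itself is callable at runtime even though there is no Python source for it (it is bound in from C++). Here is a quick check I ran of the multi-dimensional behaviour, assuming the weight has shape (out_features, in_features):

```python
import torch

x = torch.randn(2, 33, 44)   # (N, *, in_features)
w = torch.randn(55, 44)      # (out_features, in_features)
b = torch.randn(55)

# The symbol exists and is callable; it is bound from C++, so there is
# no Python definition to jump to.
y = torch._C._nn.linear(x, w, b)
print(y.shape)  # torch.Size([2, 33, 55])

# Numerically it matches flattening the leading dims into one batch
# dimension, doing a 2-D matmul, and reshaping back.
y2 = (x.reshape(-1, 44) @ w.T + b).reshape(2, 33, 55)
print(torch.allclose(y, y2, atol=1e-4))  # True
```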

Thank you very much for your answer!!! But I still have some small questions:
In pytorch/Blas.cpp at c1c9be16c4d0648fc134d04f30c8463575df7ada · pytorch/pytorch · GitHub I cannot find a function named addmm or matmul, only ones like addmm_out_cuda_impl. I know they should be closely related, but how are they wired together?

Given that I am implementing a linear layer myself, I also wonder how I can use these functions in my own code: can I actually include and call them?
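In the meantime I found that from Python the ATen-level entry points are reachable via torch.ops.aten, which at least lets me compare my own kernel against them (a small sketch I tried):

```python
import torch

x = torch.randn(4, 3)
w = torch.randn(5, 3)   # (out_features, in_features), as nn.Linear stores it
b = torch.randn(5)

# aten::addmm computes b + x @ mat2; linear passes the weight transposed.
y_addmm = torch.ops.aten.addmm(b, x, w.t())
y_linear = torch.nn.functional.linear(x, w, b)
print(torch.allclose(y_addmm, y_linear, atol=1e-5))  # True
```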

Thank you!!!

I am wondering: if we have a tensor that is already transposed, it will not be contiguous. In your link, addmm first makes the tensor contiguous and then computes. So I compared two versions: 1. use the built-in nn.Linear; 2. call .contiguous() myself. I measured the time of each, and strangely, the second version is much slower! Why??
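As a sanity check on the premise (a transposed view really is non-contiguous, and .contiguous() materialises a copy):

```python
import torch

a = torch.randn(3, 4)
print(a.T.is_contiguous())               # False: .T only swaps the strides
print(a.T.contiguous().is_contiguous())  # True: a real copy is made
```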

Below I attached the code for reproduction:

``````
# This is the built-in version.
import torch

torch_time = 0
for i in range(50):
    torch.manual_seed(i)
    start = torch.cuda.Event(enable_timing=True)  # CUDA events for timing
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    a = torch.randn(12000, 9000, device="cuda:0")
    a1 = a.T
    li = torch.nn.Linear(12000, 8000, device="cuda:0")
    output = li(a1)

    end.record()
    torch.cuda.synchronize()
    tim = start.elapsed_time(end)

    if i > 10:
        torch_time += tim

print(torch_time / 40)
# 1352.7901000976562
``````
``````
# This is my contiguous version.
import torch

torch_time = 0
for i in range(50):
    torch.manual_seed(i)
    start = torch.cuda.Event(enable_timing=True)  # CUDA events for timing
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    a = torch.randn(12000, 9000, device="cuda:0")
    a2 = a.T
    a1 = a2.contiguous()
    li = torch.nn.Linear(12000, 8000, device="cuda:0")
    output = li(a1)

    end.record()
    torch.cuda.synchronize()
    tim = start.elapsed_time(end)

    if i > 10:
        torch_time += tim

print(torch_time / 40)
# 3949.7130920410154
``````
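If it helps, I also tried a variant that moves the allocation and Linear construction outside the timed region, so the timing covers only building the input (view or copy) plus the forward pass. The helper time_forward, the device fallback, and the reduced sizes are my own choices so the script also runs on CPU:

```python
import time
import torch

# Allocation and module construction happen once, outside the timed region.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
n_in, n_out, n_batch = 512, 256, 384   # reduced sizes (my choice)
li = torch.nn.Linear(n_in, n_out, device=device)
a = torch.randn(n_in, n_batch, device=device)

def time_forward(make_input, iters=20, warmup=5):
    """Average time of make_input(a) + forward pass, after warm-up."""
    total = 0.0
    for i in range(iters):
        if device.startswith("cuda"):
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        x = make_input(a)   # the .contiguous() copy, if any, is timed too
        li(x)
        if device.startswith("cuda"):
            torch.cuda.synchronize()
        if i >= warmup:
            total += time.perf_counter() - t0
    return total / (iters - warmup)

print("transposed view:", time_forward(lambda t: t.T))
print(".contiguous(): ", time_forward(lambda t: t.T.contiguous()))
```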