I was reading a tutorial about torch.compile to understand what happens under the hood and how it converts PyTorch code into optimized kernels.
Can someone tell me which topics I should learn more about?
For example, the tutorial says the compiler can fuse pointwise operations, replacing code that does 2 reads and 2 writes with a single fused kernel that does 1 read and 1 write:
import torch

def fn(x, y):
    a = torch.cos(x).cuda()  # .cuda() is a no-op here, since the inputs are already on the GPU
    b = torch.sin(y).cuda()
    return a + b

new_fn = torch.compile(fn, backend="inductor")
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn(input_tensor, input_tensor)
So, can somebody tell me how these optimizations are categorized, and where I can learn more about this topic? Is there an equivalent optimized kernel for every PyTorch op, like nn.Linear()? Or is there somewhere to apply custom optimizations?