In PyTorch 2.4, when I run torch.compile (with the Triton matmul backend) on a ResNet model, both matmul inputs are transposed in the generated Triton kernels.
Eg:
M = 196
N = 1024
K = 256
if M * N == 0:
    # early exit due to zero-size input(s)
    return
stride_am = 1
stride_ak = 196
stride_bk = 1
stride_bn = 256
In PyTorch 2.5, only the second matmul input is transposed.
Eg:
M = 196
N = 1024
K = 256
if M * N == 0:
    # early exit due to zero-size input(s)
    return
stride_am = 256
stride_ak = 1
stride_bk = 1
stride_bn = 256
As you can see, stride_am=1 in the 2.4 kernel (A is read as if transposed, i.e. unit stride along M) while stride_am=256 in the 2.5 kernel (A is read row-major, with unit stride along K).
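For reference, here is a quick sketch of how those stride pairs map onto PyTorch tensor layouts, assuming A has logical shape (M, K); the stride values below are taken from the two kernels above:

```python
import torch

M, K = 196, 256

# Contiguous (M, K) tensor: row-major, strides (K, 1).
a_row_major = torch.empty(M, K)
# Matches the 2.5 kernel: stride_am=256, stride_ak=1.
print(a_row_major.stride())  # (256, 1)

# Transposed view of a contiguous (K, M) tensor: strides (1, M).
a_transposed = torch.empty(K, M).t()
# Matches the 2.4 kernel: stride_am=1, stride_ak=196.
print(a_transposed.stride())  # (1, 196)
```

So the 2.4 kernel expects A in a transposed (column-major) layout, while the 2.5 kernel expects it contiguous.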
I'm trying to identify which PyTorch commit/PR changed this behavior so I can understand the change better.