Take torch.mm(a, b) as an example: a and b are on the GPU, in global memory (I think?), the underlying cuBLAS implementation uses shared memory or something similar internally, and the result is then written back to global memory. Is that correct? I'm not sure. So I am wondering: is there something like a.device() in Python that shows where the data currently is? Thank you!
Yes, your data lives in the GPU's global memory. Shared memory and the L1 and L2 caches are used inside the CUDA kernels.
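On the Python side, a tensor exposes `device` as an attribute (not a method) plus an `is_cuda` flag. As far as I know, these only report which device holds the tensor. They say nothing about shared memory or caches, which only come into play while a kernel is running. A minimal sketch, where the shapes and the CPU fallback are my own choices for illustration:

```python
import torch

# Fall back to the CPU so the snippet also runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.randn(4, 3, device=device)
b = torch.randn(3, 5, device=device)

# The result of torch.mm is allocated on the same device as its inputs.
c = torch.mm(a, b)

print(a.device)   # e.g. cuda:0 on a GPU machine, cpu otherwise
print(c.device)   # same device as a and b
print(c.is_cuda)  # True only when the tensor lives in GPU memory
```

If `c.device` prints a CUDA device, the result tensor sits in that GPU's global memory, which matches the description above.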