PyTorch is not using the GPU's full performance

Profile your code to find where the bottleneck is, using the native PyTorch profiler or an external tool such as Nsight Systems.
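Here is a minimal sketch of the PyTorch-profiler route. It assumes a standard training loop. The toy model, the `profile_training` helper, and the `"data_loading"` / `"train_step"` labels are hypothetical placeholders for your own code. Wrapping each phase in `record_function` makes it show up as a named row in the profiler table. If `data_loading` takes a large share of the time, the GPU is probably being starved by the input pipeline rather than being slow at compute.

```python
import torch
from torch.profiler import profile, ProfilerActivity, record_function


def train_step(model, optimizer, loss_fn, inputs, targets):
    # One ordinary optimization step.
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss


def profile_training(num_steps=5):
    # Falls back to CPU so the sketch also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical toy model; replace with your own model and data pipeline.
    model = torch.nn.Sequential(
        torch.nn.Linear(512, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 10)
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Record CUDA kernels as well as CPU ops when a GPU is present.
    activities = [ProfilerActivity.CPU]
    if device == "cuda":
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities) as prof:
        for _ in range(num_steps):
            # Labeled region: batch creation plus host-to-device copy.
            with record_function("data_loading"):
                inputs = torch.randn(64, 512).to(device)
                targets = torch.randint(0, 10, (64,)).to(device)
            # Labeled region: forward, backward and optimizer update.
            with record_function("train_step"):
                train_step(model, optimizer, loss_fn, inputs, targets)
        if device == "cuda":
            # Wait for queued kernels so they fall inside the profiled window.
            torch.cuda.synchronize()
    return prof, device


if __name__ == "__main__":
    prof, device = profile_training()
    sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=10))
```

The first iterations include one-time warm-up costs such as CUDA context creation and memory allocation, so profile a few steps rather than just one. For a timeline view, `prof.export_chrome_trace("trace.json")` writes a trace you can open in a trace viewer. Gaps between kernels on that timeline show that the GPU is sitting idle. Nsight Systems gives a similar timeline at the system level.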