I’ve run into a similar issue and I’m out of ideas (on AWS, with a g4dn.2xlarge instance). Nearly identical code of mine seemed to work fine.
I also tried running with a batch size of 1, but it still fails. PS: This code works completely fine when not using a GPU.
FIX: For some reason this was an issue with PyTorch 1.8.0. I looked at this post, "RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)` while running fine on the CPU" (#13 by ptrblck), and tried downgrading PyTorch, which worked fine.
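In case it helps anyone else landing here, this is a rough sketch of how one might check for the same situation before and after changing the PyTorch version. It is not taken from the linked post: the helper names `is_reported_version` and `cublas_smoke_test` are made up for illustration, and the only version treated as suspect is 1.8.0, because that is the one I hit. A small matrix multiply on the GPU should be enough to force cuBLAS to create its handle.

```python
import importlib.util


def is_reported_version(version: str, reported: str = "1.8.0") -> bool:
    """Return True if `version` (e.g. '1.8.0+cu111') matches the release reported in this thread."""
    # Drop a local build suffix such as '+cu111' before comparing.
    return version.split("+")[0] == reported


def cublas_smoke_test() -> str:
    """Run a tiny matmul on the GPU and report whether it succeeded."""
    if importlib.util.find_spec("torch") is None:
        return "torch not installed"
    import torch

    if not torch.cuda.is_available():
        return "no CUDA device visible"
    note = " (release reported in this thread)" if is_reported_version(torch.__version__) else ""
    try:
        a = torch.randn(8, 8, device="cuda")
        # .item() copies the result to the host, so asynchronous CUDA errors surface here.
        (a @ a).sum().item()
    except RuntimeError as exc:
        return f"failed on torch {torch.__version__}{note}: {exc}"
    return f"ok on torch {torch.__version__}{note}"


if __name__ == "__main__":
    print(cublas_smoke_test())
```

If this fails on one PyTorch version and passes on another on the same machine, that at least points to the install rather than the model code or the batch size.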
