Training becomes slower and slower as epochs continue, without any error, when the MLP-Mixer model is used

Recently, I trained an MLP-based model (MLP-Mixer) on a GeForce RTX 4090 with PyTorch 2.3.1 and CUDA 11.8. When I run this script several times simultaneously, something strange happens. The first run becomes slower and slower as the epochs continue, and the other runs pause. No error is reported in any of them. Other models, e.g., a CNN or a plain MLP, work fine. I have pasted the reproduction code in the gist “reproduction code for pytorch issue: "The running is paused but without error when the specific model is used" · GitHub”. I have been stuck on this problem for two days. If anyone could tell me the reason, I would be very grateful. Thanks!
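One way to make the slowdown concrete is to log per-epoch wall-clock time in each of the simultaneous runs. The sketch below assumes the gist's training loop can be wrapped in a function, here called `train_one_epoch`. That name, and the helpers `time_epochs` and `slowdown_ratio`, are hypothetical and not taken from the gist. Passing `torch.cuda.synchronize` as `sync` makes queued asynchronous CUDA work count toward the epoch it belongs to.

```python
import time
from typing import Callable, List, Optional


def time_epochs(
    train_one_epoch: Callable[[int], None],  # hypothetical wrapper around one epoch of training
    num_epochs: int,
    sync: Optional[Callable[[], None]] = None,  # e.g. torch.cuda.synchronize
) -> List[float]:
    """Run train_one_epoch once per epoch and return wall-clock seconds per epoch."""
    durations = []
    for epoch in range(num_epochs):
        if sync is not None:
            sync()  # drain pending GPU work so it is not billed to this epoch
        start = time.perf_counter()
        train_one_epoch(epoch)
        if sync is not None:
            sync()  # wait for this epoch's kernels before stopping the clock
        elapsed = time.perf_counter() - start
        durations.append(elapsed)
        print(f"epoch {epoch}: {elapsed:.2f} s")
    return durations


def slowdown_ratio(durations: List[float]) -> float:
    """Ratio of the last epoch's time to the first; a value > 1 means training got slower."""
    return durations[-1] / durations[0]
```

In the real script this would be called as `time_epochs(train_one_epoch, num_epochs, sync=torch.cuda.synchronize)`. Comparing `slowdown_ratio` across the simultaneously launched runs would show whether only the first run degrades.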

Versions

pytorch: 2.3.1
cuda: 11.8
gpu: RTX 4090
Environment info collected is as follows:
Collecting environment information…
PyTorch version: 2.3.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 professional
GCC version: (MinGW-W64 x86_64-msvcrt-posix-seh, built by Brecht Sanders) 13.1.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.11 (tags/v3.10.11:7d4cc5a, Apr 5 2023, 00:38:17) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 535.98
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cudnn_ops_train64_8.dll
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=3200
DeviceID=CPU0
Family=207
L2CacheSize=16384
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=3200
Name=13th Gen Intel(R) Core™ i9-13900KS
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.3.1+cu118
[pip3] torchaudio==2.3.1+cu118
[pip3] torchvision==0.18.1+cu118
[conda] Could not collect