Problem setting param.requires_grad=False for BatchNorm layers, only when using the mps device on M1

Hello! I have been struggling with this weird error for a few days now and can't find a solution. The code below works perfectly on the CPU, but on the mps device of a 14-inch MacBook Pro (M1 Pro) it throws the error shown further down. It is definitely not a memory issue: the code runs when all parameters are trainable, and it breaks only when the BatchNorm parameters are frozen and everything else is left trainable.
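To rule out ResNet and my trainer, here is a minimal sketch of that exact condition: a single BatchNorm2d whose weight and bias are frozen while its input still needs a gradient (inside ResNet the input comes from a trainable conv layer, so it does). The function name and tensor sizes are my own choices for illustration, not taken from the model.

```python
import torch
import torch.nn as nn


def frozen_bn_backward(device: str) -> torch.Tensor:
    """Run backward through a BatchNorm2d whose affine parameters are frozen."""
    bn = nn.BatchNorm2d(4).to(device)
    bn.train()  # training mode, as in the training loop

    # Freeze only the BatchNorm affine parameters.
    bn.weight.requires_grad_(False)
    bn.bias.requires_grad_(False)

    # The input still requires a gradient, like the output of a trainable conv layer,
    # so backward must compute grad_input but not grad_weight / grad_bias.
    x = torch.randn(2, 4, 8, 8, device=device, requires_grad=True)
    bn(x).sum().backward()
    return x.grad


if __name__ == "__main__":
    print(frozen_bn_backward("cpu").shape)  # torch.Size([2, 4, 8, 8])
```

On the M1 the interesting call is frozen_bn_backward("mps"). If the failure is in the MPS batch-norm backward itself, it should abort the same way; because the NSException terminates the interpreter, it cannot be caught with try/except, so it is best run in a separate process.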

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models.resnet import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# Freeze the BatchNorm parameters, matched by parameter name
for name, param in model.named_parameters():
    if "bn" in name or "batchnorm" in name.lower():
        param.requires_grad = False

# Adding a custom classification head
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(num_ftrs, 1024),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    nn.Linear(1024, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    nn.Linear(512, 3),
)

# Then I have a standard training loop stored in a class
# (ModelTrainer is my own class and is not shown here)
trainer = ModelTrainer(model, train_loader, val_loader, num_epochs=1)
model = trainer.train()
history = trainer.history()
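A side note on the name filter above: in torchvision's ResNet, the BatchNorm inside each downsample branch is registered as e.g. layer1.0.downsample.1, so its parameters (layer1.0.downsample.1.weight and .bias) do not contain "bn" and stay trainable. A sketch that freezes by module type instead; the helper name freeze_batchnorm is my own and not a PyTorch API.

```python
import torch.nn as nn

# The standard BatchNorm layer types.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_batchnorm(model: nn.Module) -> list[str]:
    """Set requires_grad=False on the weight/bias of every BatchNorm layer.

    Returns the qualified names of the parameters that were frozen.
    """
    frozen = []
    for module_name, module in model.named_modules():
        if isinstance(module, BN_TYPES):
            # recurse=False: only this layer's own weight and bias
            for param_name, param in module.named_parameters(recurse=False):
                param.requires_grad = False
                frozen.append(f"{module_name}.{param_name}")
    return frozen
```

Note that requires_grad=False only freezes weight and bias; in train() mode the layers still update running_mean and running_var. Calling .eval() on those modules after model.train() would freeze the statistics as well.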
Error:

Using mps device
Epoch 1/1
----------
train:   0%|                                                                           | 0/38 [00:00<?, ? batch/s]2023-04-07 17:41:27.555 python[8858:237219] *** Terminating app due to uncaught exception 'NSInvalidArgumentException', reason: '*** -[__NSDictionaryM setObject:forKeyedSubscript:]: key cannot be nil'
*** First throw call stack:
(
        0   CoreFoundation                      0x00000001b2a78418 __exceptionPreprocess + 176
        1   libobjc.A.dylib                     0x00000001b25c2ea8 objc_exception_throw + 60
        2   CoreFoundation                      0x00000001b2b5dcc4 -[__NSCFString characterAtIndex:].cold.1 + 0
        3   CoreFoundation                      0x00000001b2b6ae4c -[__NSDictionaryM setObject:forKeyedSubscript:].cold.2 + 0
        4   CoreFoundation                      0x00000001b29c5a0c -[__NSDictionaryM setObject:forKeyedSubscript:] + 928
        5   libtorch_cpu.dylib                  0x0000000157ecd280 _ZN2at6native23batch_norm_backward_mpsERKNS_6TensorES3_RKN3c108optionalIS1_EES8_S8_S8_S8_bdNSt3__15arrayIbLm3EEE + 4380
        6   libtorch_cpu.dylib                  0x000000015463284c _ZN2at4_ops26native_batch_norm_backward10redispatchEN3c1014DispatchKeySetERKNS_6TensorES6_RKNS2_8optionalIS4_EESA_SA_SA_SA_bdNSt3__15arrayIbLm3EEE + 200
        7   libtorch_cpu.dylib                  0x00000001563d99b4 _ZN3c104impl28wrap_kernel_functor_unboxed_INS0_6detail24WrapFunctionIntoFunctor_INS_26CompileTimeFunctionPointerIFNSt3__15tupleIJN2at6TensorES8_S8_EEENS_14DispatchKeySetERKS8_SC_RKNS_8optionalIS8_EESG_SG_SG_SG_bdNS5_5arrayIbLm3EEEEXadL_ZN5torch8autograd12VariableType12_GLOBAL__N_126native_batch_norm_backwardESA_SC_SC_SG_SG_SG_SG_SG_bdSI_EEEES9_NS_4guts8typelist8typelistIJSA_SC_SC_SG_SG_SG_SG_SG_bdSI_EEEEESJ_E4callEPNS_14OperatorKernelESA_SC_SC_SG_SG_SG_SG_SG_bdSI_ + 2392
        8   libtorch_cpu.dylib                  0x00000001546324ec _ZN2at4_ops26native_batch_norm_backward4callERKNS_6TensorES4_RKN3c108optionalIS2_EES9_S9_S9_S9_bdNSt3__15arrayIbLm3EEE + 468
        9   libtorch_cpu.dylib                  0x00000001560bcce8 _ZN5torch8autograd9generated24NativeBatchNormBackward05applyEONSt3__16vectorIN2at6TensorENS3_9allocatorIS6_EEEE + 884
        10  libtorch_cpu.dylib                  0x00000001570a2a50 _ZN5torch8autograd4NodeclEONSt3__16vectorIN2at6TensorENS2_9allocatorIS5_EEEE + 120
        11  libtorch_cpu.dylib                  0x000000015709983c _ZN5torch8autograd6Engine17evaluate_functionERNSt3__110shared_ptrINS0_9GraphTaskEEEPNS0_4NodeERNS0_11InputBufferERKNS3_INS0_10ReadyQueueEEE + 2932
        12  libtorch_cpu.dylib                  0x00000001570986e0 _ZN5torch8autograd6Engine11thread_mainERKNSt3__110shared_ptrINS0_9GraphTaskEEE + 640
        13  libtorch_cpu.dylib                  0x00000001570973c4 _ZN5torch8autograd6Engine11thread_initEiRKNSt3__110shared_ptrINS0_10ReadyQueueEEEb + 336
        14  libtorch_python.dylib               0x000000014867df38 _ZN5torch8autograd6python12PythonEngine11thread_initEiRKNSt3__110shared_ptrINS0_10ReadyQueueEEEb + 112
        15  libtorch_cpu.dylib                  0x00000001570a5bb0 _ZNSt3__1L14__thread_proxyINS_5tupleIJNS_10unique_ptrINS_15__thread_structENS_14default_deleteIS3_EEEEMN5torch8autograd6EngineEFviRKNS_10shared_ptrINS8_10ReadyQueueEEEbEPS9_aSC_bEEEEEPvSJ_ + 76
        16  libsystem_pthread.dylib             0x00000001b291e06c _pthread_start + 148
        17  libsystem_pthread.dylib             0x00000001b2918e2c thread_start + 8
)
libc++abi: terminating with uncaught exception of type NSException
[1]    8858 abort      /Users/dimitardimitrov/miniconda3/envs/pytorch2/bin/python 
/Users/dimitardimitrov/miniconda3/envs/pytorch2/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
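The trace shows the abort happens in at::native::batch_norm_backward_mps, i.e. during backward, while the same model runs when every parameter is trainable. One workaround I have not verified on MPS would be to leave the BatchNorm parameters with requires_grad=True (the configuration that works) and simply not hand them to the optimizer, so they are never updated. A sketch under that assumption; build_optimizer, the SGD choice, and the learning rate are placeholders, not part of my trainer.

```python
import torch.nn as nn
import torch.optim as optim

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def build_optimizer(model: nn.Module, lr: float = 1e-3) -> optim.Optimizer:
    """Optimize every parameter except BatchNorm weight/bias (hypothetical helper)."""
    # Identify BatchNorm parameters by object identity.
    bn_param_ids = {
        id(p)
        for m in model.modules()
        if isinstance(m, BN_TYPES)
        for p in m.parameters(recurse=False)
    }
    # Everything else, including the new classification head, is optimized.
    trainable = [p for p in model.parameters() if id(p) not in bn_param_ids]
    return optim.SGD(trainable, lr=lr, momentum=0.9)
```

Because the BatchNorm gradients are still computed and accumulated, clear them with model.zero_grad(set_to_none=True) rather than optimizer.zero_grad(), which only touches the parameters the optimizer owns.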

Here is my environment information:

Versions

PyTorch version: 2.0.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.10 (main, Mar 21 2023, 13:41:05) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.0
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.15.0
[conda] numpy 1.23.5 py310hb93e574_0
[conda] numpy-base 1.23.5 py310haf87e8b_0
[conda] pytorch 2.0.0 py3.10_0 pytorch
[conda] torchaudio 2.0.0 py310_cpu pytorch
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.15.0 py310_cpu pytorch