Issue with torch.batch_norm(): Cannot convert a MPS Tensor to float64

Hi! I'm new to the forums, so I hope I'm describing the issue correctly.
I'm working with the PyTorch Mask R-CNN model, which contains a number of BatchNorm2d layers.
When I try to run the model on the MPS device of my M1 MacBook, I get this error:
```
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```
The problem seems to lie in torch.batch_norm(), since I'm fairly sure all the tensors I'm using are torch.float32.
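One quick way to back up the float32 assumption is to list every floating-point parameter and buffer whose dtype is not torch.float32. The helper below, `find_non_float32`, is a hypothetical name used only for illustration, not a PyTorch API. It relies only on `nn.Module.named_parameters()` and `named_buffers()`. It is shown here on a single `nn.BatchNorm2d(256)`, but it could be called the same way on the whole Mask R-CNN model.

```python
import torch
from torch import nn


def find_non_float32(module: nn.Module) -> list:
    """Hypothetical helper: names of floating-point params/buffers that are not float32."""
    bad = []
    tensors = list(module.named_parameters()) + list(module.named_buffers())
    for name, tensor in tensors:
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) are skipped on purpose
        if tensor.is_floating_point() and tensor.dtype != torch.float32:
            bad.append(f"{name}: {tensor.dtype}")
    return bad


print(find_non_float32(nn.BatchNorm2d(256)))  # → []
```

An empty list means no float64 parameters or buffers are present, which points the search toward the operation itself rather than the model's weights.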
Here is a minimal snippet that reproduces the issue:

```python
import torch

device = "mps"

# Empty batch: the first (batch) dimension is 0
test_input = torch.zeros([0, 256, 14, 14], dtype=torch.float32, device=device)
weight = torch.zeros([256], dtype=torch.float32, device=device)
bias = torch.zeros([256], dtype=torch.float32, device=device)
running_mean = torch.zeros([256], dtype=torch.float32, device=device)
running_var = torch.zeros([256], dtype=torch.float32, device=device)
momentum = 0.1
eps = 0.1
training = False  # eval mode: normalize with running_mean / running_var

# The last positional argument is cudnn_enabled
torch.batch_norm(test_input, weight, bias, running_mean, running_var, training, momentum, eps, True)
```
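Because the repro uses an empty batch (batch dimension 0), it might help to narrow things down by running the same call with batch sizes 0 and 1, on the CPU and, when it is available, on MPS. This is only a diagnostic sketch: `run_batch_norm` is a hypothetical helper, and the argument values (eval mode, momentum 0.1, eps 0.1, cudnn_enabled True) are copied from the snippet above.

```python
import torch


def run_batch_norm(device: str, batch_size: int) -> tuple:
    """Hypothetical helper: run the repro's torch.batch_norm call, return the output shape."""
    x = torch.zeros([batch_size, 256, 14, 14], dtype=torch.float32, device=device)
    # weight, bias, running_mean, running_var: one float32 value per channel
    w, b, mean, var = (torch.zeros(256, dtype=torch.float32, device=device) for _ in range(4))
    # args: training=False, momentum=0.1, eps=0.1, cudnn_enabled=True
    out = torch.batch_norm(x, w, b, mean, var, False, 0.1, 0.1, True)
    return tuple(out.shape)


devices = ["cpu"]
if torch.backends.mps.is_available():
    devices.append("mps")

for device in devices:
    for batch_size in (0, 1):
        try:
            print(device, batch_size, run_batch_norm(device, batch_size))
        except (TypeError, RuntimeError) as err:
            # Report the failure instead of stopping, so every case is printed
            print(device, batch_size, "failed:", err)
```

If only the `mps` / batch size 0 case fails, that would suggest the error is tied to empty inputs on MPS rather than to the dtypes of the tensors being passed in.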

Thanks in advance, and let me know if I can improve the post in any way.