Segmentation Fault for loss.backward() with batch_size > 1 on M1 Mac

Hello! I have the following simple model and training procedure. When I call loss.backward() on an M1 Mac with any batch size greater than 1, the process crashes with a segmentation fault.

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, classes):
        super().__init__()
        self.classes = classes
        
        self.ff1 = nn.Linear(2, 20)
        self.ff2 = nn.Linear(20, classes + 1)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        out = self.relu(self.ff1(x))
        out = self.ff2(out)
        
        return out


def train(dataloader, models, loss_fn, optimizer):
    for model in models:
        model.train()
    
    total_loss = 0
    batches = 0
    
    for batch, (X, Y) in enumerate(dataloader):
        loss = 0
        optimizer.zero_grad()
        for group in range(len(models)):
            # Keep the class label for samples that belong to this group; map all
            # other samples to the extra "not this group" class at index `classes`.
            current_y = torch.where(Y[:, 0] == group, Y[:, 1], models[group].classes)
            
            preds = models[group](X)
            loss = loss + loss_fn(preds, current_y)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss
        batches += 1
        
        if batch % 1000 == 0:
            print(f'Current Loss: {total_loss / batches}')
            total_loss = 0
            batches = 0
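
train_dataset and val_dataset are defined elsewhere in my script; for reference, a stand-in with the same shapes would look roughly like this (the random data is purely illustrative): X has 2 input features to match nn.Linear(2, 20), Y[:, 0] is the group index selecting one of the two models, and Y[:, 1] is the class label.

n_samples = 1000
X = torch.randn(n_samples, 2)                 # 2 input features, matching nn.Linear(2, 20)
groups = torch.randint(0, 2, (n_samples, 1))  # group index: 0 or 1, one per model
labels = torch.randint(0, 3, (n_samples, 1))  # class label within a group: 0, 1 or 2
Y = torch.cat([groups, labels], dim=1)        # Y[:, 0] = group, Y[:, 1] = class
train_dataset = torch.utils.data.TensorDataset(X, Y)
val_dataset = torch.utils.data.TensorDataset(X[:100], Y[:100])

With that in place, the rest of my script is:
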
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset)

group_1 = MLP(3)
group_2 = MLP(3)
lr = 1e-4
epochs = 15

loss_fn = nn.CrossEntropyLoss(reduction='sum')

params = list(group_1.parameters()) + list(group_2.parameters())

optimizer = torch.optim.SGD(params, lr)

for e in range(epochs):
    print(f'Epoch {e}:')
    train(train_dataloader, [group_1, group_2], loss_fn, optimizer)

Interestingly, the same code does not produce a segmentation fault when run on an Ubuntu server, which leads me to believe the issue is related to the M1 chip rather than to the code itself. Any advice on how to diagnose or resolve this would be greatly appreciated.
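
In case it helps with diagnosis, the only extra instrumentation I know to add from the Python side is something like the following (just the standard faulthandler and platform modules, nothing PyTorch-specific), which at least reports which Python line the crash happens on and records the exact PyTorch build:

import faulthandler
import platform
import torch

faulthandler.enable()  # dump a Python-level traceback to stderr if the process receives SIGSEGV
print(f'PyTorch {torch.__version__} on {platform.platform()}')  # record the exact build and OS/arch

# ... then run the epoch loop from above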