Reproducibility breaks down with weighted cross-entropy loss

Hello, the following code stops being reproducible when the class weights passed to the cross-entropy loss are non-integer values. Here’s the example:

import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


# Geometry shared by the data and the model: image height/width, number of
# input channels, and number of target classes (also the model's output channels).
h, w = 32, 32
in_ch, out_ch = 3, 5
class Dtst(Dataset):
    """Fixed in-memory dataset of ``N`` random (image, per-pixel label) pairs.

    Images are float32 tensors of shape (in_ch, h, w) drawn from a standard
    normal; labels are int64 tensors of shape (h, w) with values in
    [0, out_ch). Everything is generated once, in ``__init__``.
    """

    def __init__(self, N=20):
        image_shape = (in_ch, h, w)
        label_shape = (h, w)
        # All images are drawn before any labels, so the global RNG stream is
        # consumed in a fixed order for a given seed.
        self.X = [torch.randn(image_shape, dtype=torch.float32) for _ in range(N)]
        self.Y = [
            torch.randint(0, out_ch, label_shape, dtype=torch.int64)
            for _ in range(N)
        ]

    def __getitem__(self, ix):
        """Return the ``(image, labels)`` pair stored at position ``ix``."""
        image, labels = self.X[ix], self.Y[ix]
        return image, labels

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.Y)


class Network(nn.Module):
    """Two 3x3 same-padding convolutions with channel dropout in between.

    Maps ``in_ch`` input channels to ``out_ch`` output channels through a
    10-channel hidden layer; spatial size is preserved by ``padding=1``.
    """

    def __init__(self):
        super().__init__()
        hidden_ch = 10
        # Construction order is kept (layer1, drop, layer2) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.layer1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1)
        self.drop = nn.Dropout2d(p=0.1)
        self.layer2 = nn.Conv2d(hidden_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        """Return per-pixel logits for input batch ``x``."""
        hidden = self.layer1(x)
        hidden = self.drop(hidden)
        return self.layer2(hidden)

seed = 4
torch.backends.cudnn.enabled = True
# With benchmark=True cuDNN auto-tunes and may select a different conv
# algorithm from run to run; PyTorch's reproducibility notes recommend
# benchmark=False together with deterministic=True.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

dtst = Dtst()
model = Network()

device = 'cuda'
model.to(device)
# Non-integer class weights: square roots of 1..out_ch.
class_weights = ((torch.arange(out_ch) + 1).float() ** 0.5).to(device)

# NOTE(review): the reported non-reproducibility appears only with
# non-integer weights on CUDA. A plausible explanation (confirm in pytorch
# issue #46024): the weighted per-pixel NLL reduction on CUDA accumulates
# with atomic adds whose order varies between runs. Float sums of
# integer-valued weights are exact in any order, while fractional weights
# expose rounding differences. A deterministic workaround is to compute the
# unreduced loss (reduction='none'), multiply it by class_weights[y], and
# divide the sum by the summed weights.
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
opt = torch.optim.Adam(model.parameters())

# Counter.update adds to existing counts (plain dict.update would overwrite
# them), so this accumulates per-class prediction totals over all steps.
preds_dict = Counter()
# The dataset is fixed and the loader does not shuffle, so one loader can be
# reused for every epoch.
dtldr = DataLoader(dtst, batch_size=4)
for e in range(1500):
    for x, y in dtldr:
        # Clear the previous step's gradients; otherwise backward() keeps
        # summing into .grad and Adam steps on the accumulated total.
        opt.zero_grad()
        preds = model(x.to(device))
        loss = loss_fn(preds, y.to(device))
        loss.backward()
        opt.step()

        preds_argmax = preds.argmax(dim=1).flatten()
        preds_dict.update(preds_argmax.tolist())

print(sorted(preds_dict.items(), key=lambda kv: kv[1]))
print(model.layer1.weight.data.norm(2).item())

It’s a very simple network with a basic Dataset and a simple training loop.
This code is not reproducible, but when I remove the `**0.5` from `class_weights`, it becomes reproducible. In other words, if the class weights are genuinely fractional floats rather than integers cast to float, the results are not reproducible.

Also, the problem only occurs on CUDA: if the device is set to ‘cpu’, the code is reproducible again.
I run this on Ubuntu 18 with the following environment:
pytorch 1.6.0
cudatoolkit 10.1.243
numpy 1.19.1

1 Like

I’ve also opened this issue on GitHub, where it seems to be getting more attention: https://github.com/pytorch/pytorch/issues/46024