WeightedRandomSampler "number of categories cannot exceed 2^24" with only three categories

Hi all,

I am currently trying to set up a neural network that classifies a target variable with three categories, and the classes are severely imbalanced. I therefore use WeightedRandomSampler so that all classes are drawn with equal probability. On a small sample of the data this works exactly as intended. However, when I run the model on the full dataset I keep getting the error: "number of categories cannot exceed 2^24". The training data consist of 27,000,000 observations of 36 'x' variables and one 'y' which is either 0, 1 or 2.

I can’t figure out why I get this error and I have tried to implement everything I can find on this forum regarding this error without any luck. Any help would be greatly appreciated.

My code is set up as the following:

### getting data ###
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

def load_dataset_as_numpy(path):
    # selected_columns is defined earlier in the script
    dataset = pd.read_csv(path, usecols=selected_columns)
    arr = dataset.to_numpy()
    y_np = arr[:, -1]    # last column is the target
    x_np = arr[:, :-1]   # remaining columns are the features

    data = TensorDataset(torch.tensor(x_np, dtype=torch.float32),
                         torch.tensor(y_np, dtype=torch.long))
    return data

train_dataset = load_dataset_as_numpy(r'E:\filepath\crsp_train.csv')

### setting up WeightedRandomSampler and DataLoader ###

# for a TensorDataset the targets are available directly, which avoids
# looping over all 27,000,000 rows one sample at a time
target_list = train_dataset.tensors[1]

class_count = torch.bincount(target_list)

# float weights: one weight per class, inversely proportional to its count
class_weights = 1.0 / class_count.float()

# per-sample weight = weight of that sample's class
class_weights_all = class_weights[target_list]

print(class_weights_all)

weighted_sampler = WeightedRandomSampler(
    weights=class_weights_all,
    num_samples=len(class_weights_all),
    replacement=True
)

# shuffle must stay False (the default) when a sampler is provided
loader_train = DataLoader(train_dataset, batch_size=batch_size, sampler=weighted_sampler)


### NN model ### 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'using {device} device')



class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden1)
        self.layer2 = nn.Linear(hidden1, hidden2)
        self.layer3 = nn.Linear(hidden2, hidden3)
        self.out = nn.Linear(hidden3, 3)

        self.relu = nn.ReLU()
        self.drop = nn.Dropout(p=p2)
        self.batchnorm1 = nn.BatchNorm1d(hidden1)
        self.batchnorm2 = nn.BatchNorm1d(hidden2)
        self.batchnorm3 = nn.BatchNorm1d(hidden3)

    def forward(self, x):
        x = self.drop(x)

        x = self.layer1(x)
        x = self.batchnorm1(x)
        x = self.relu(x)
        x = self.drop(x)

        x = self.layer2(x)
        x = self.batchnorm2(x)
        x = self.relu(x)
        x = self.drop(x)

        x = self.layer3(x)
        x = self.batchnorm3(x)
        x = self.relu(x)
        x = self.drop(x)


        x = self.out(x)

        return x

model = Model().to(device)
print(model)

def multi_acc(y_hat, y):
    y_pred_softmax = torch.log_softmax(y_hat, dim=1)
    _, y_pred_tags = torch.max(y_pred_softmax, dim=1)

    correct_pred = (y_pred_tags == y).float()
    acc = correct_pred.sum() / len(correct_pred)

    acc = torch.round(acc * 100)

    return acc

test_stats = {
    'loss': [],
    "acc": []
}
train_stats = {
    'loss': [],
    "acc": []
}

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def train(dataloader, model, loss_fn, optimizer, multi_acc):
    model.train()
    train_loss = 0
    train_acc = 0

    for i, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        y_hat = model(x)
        loss = loss_fn(y_hat, y)
        train_loss += loss.item()
        acc = multi_acc(y_hat, y)
        train_acc += acc.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    num_batches = len(dataloader)
    train_loss = train_loss / num_batches
    train_acc = train_acc / num_batches

    train_stats['loss'].append(train_loss)
    train_stats['acc'].append(train_acc)

    print(f'Epoch {epoch + 1:03}: | Train Loss: {train_loss:.5f} | Train Acc: {train_acc:.3f} |')


for epoch in range(epochs):
    # print(f"Epoch {epoch+1}:")
    start_time = time.time()

    train(loader_train, model, loss_fn, optimizer, multi_acc)

    print("--- %s seconds ---" % (time.time() - start_time))

The full error message I get is:

Traceback (most recent call last):
  File "C:\Users\swlli\NeuralNetworkIndicator.py", line 251, in <module>
    train(loader_train, model, loss_fn, optimizer, multi_acc)
  File "C:\Users\swlli\NeuralNetworkIndicator.py", line 200, in train
    for i,(x,y) in enumerate(dataloader):
  File "E:\Venv\lib\site-packages\torch\utils\data\dataloader.py", line 628, in __next__
    data = self._next_data()
  File "E:\Venv\lib\site-packages\torch\utils\data\dataloader.py", line 670, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "E:\Venv\lib\site-packages\torch\utils\data\dataloader.py", line 618, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "E:\Venv\lib\site-packages\torch\utils\data\sampler.py", line 254, in __iter__
    for idx in self.sampler:
  File "E:\Venv\lib\site-packages\torch\utils\data\sampler.py", line 203, in __iter__
    rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
RuntimeError: number of categories cannot exceed 2^24

I don't see any obvious errors in your code. Could you check the length of class_weights_all?

Thank you for the reply. The length of class_weights_all is 27,000,000 (the same as the number of observations in the training data), which is indeed above 2^24 = 16,777,216.

The issue is raised in these lines of code because torch.multinomial uses float32 internally and relies on integer values being exactly representable; float32 has a 24-bit significand, so consecutive integers are only exactly representable up to 2^24. The limit appears to be hard-coded, and casting the inputs to float64 doesn't avoid the error. Could you create a feature request on GitHub and explain your use case a bit more, please, as this limitation might be relaxed?
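
To see the limit concretely, here is a quick standalone check (my own snippet, not from the thread) showing where float32 stops being able to count:

import torch

# 2**24 is the last point at which float32 can still resolve consecutive
# integers: 16_777_216.0 and 16_777_217.0 round to the same float32 value.
n = 2 ** 24
a = torch.tensor(n, dtype=torch.float32)
b = torch.tensor(n + 1, dtype=torch.float32)
print(a.item(), b.item())  # 16777216.0 16777216.0
print(torch.equal(a, b))   # True

So once num_samples (the number of "categories" multinomial draws from, i.e. the number of per-sample weights) passes 16,777,216, indices can no longer be represented exactly and the operation refuses to run.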


I too have a similar error, as my number of class instances overshoots the limit…
@ptrblck

For anyone who lands here from a Google search: this is an old issue with WeightedRandomSampler. There is a workaround on GitHub for the case when the number of samples is small: CUDA multinomial is limited to 2^24 categories · Issue #2576 · pytorch/pytorch · GitHub. It does not apply to the OP's case, but it works for a reasonable epoch size (hundreds of thousands of samples, as opposed to tens of millions). A sketch of one possible alternative that also covers larger datasets follows below.
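
Since the cap lives in torch.multinomial, one alternative (my own sketch, not the code from the linked issue) is a drop-in sampler that draws indices with NumPy's float64 machinery, which has no 2^24 category limit. The class name NumpyWeightedRandomSampler is made up for illustration:

import numpy as np
from torch.utils.data import Sampler

class NumpyWeightedRandomSampler(Sampler):
    # Stand-in for WeightedRandomSampler(weights, num_samples, replacement=True)
    # that samples indices with np.random.choice instead of torch.multinomial.
    def __init__(self, weights, num_samples):
        # normalise per-sample weights to probabilities, as np.random.choice requires
        probs = np.asarray(weights, dtype=np.float64)
        self.probs = probs / probs.sum()
        self.num_samples = num_samples

    def __iter__(self):
        indices = np.random.choice(len(self.probs), size=self.num_samples,
                                   replace=True, p=self.probs)
        return iter(indices.tolist())

    def __len__(self):
        return self.num_samples

# usage, assuming class_weights_all from the code above:
# weighted_sampler = NumpyWeightedRandomSampler(class_weights_all, len(class_weights_all))

Note that np.random.choice rebuilds the cumulative distribution on every epoch, which is O(n), but that is workable even for tens of millions of entries.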