How to handle imbalanced classes

I don’t know, but you could use this code as a small example to see how a WeightedRandomSampler is used to create balanced batches for a binary use case.

Pass the sampler to the DataLoader whose dataset matches the targets used to compute its weights.
That is, if the sampler used the training targets to calculate its weights, it should be used together with the training dataset in the training DataLoader.
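A minimal sketch of that pairing, using NumPy and a hypothetical 90/10 binary dataset with a random train/validation split (the helper name `make_sample_weights` and the split are made up for illustration). The point is that the weights are built from the training targets only, so there is exactly one weight per sample of the dataset the sampler will index into:

```python
import numpy as np

def make_sample_weights(targets):
    """Return one inverse-frequency weight per sample, computed from `targets` only."""
    targets = np.asarray(targets)
    class_counts = np.bincount(targets)      # number of samples per class
    class_weights = 1.0 / class_counts       # rarer classes get larger weights
    return class_weights[targets]            # look up the weight of each sample

# Hypothetical 90/10 binary dataset, split into 80 training and 20 validation samples.
rng = np.random.default_rng(0)
all_targets = np.array([0] * 90 + [1] * 10)
perm = rng.permutation(len(all_targets))
train_idx, val_idx = perm[:80], perm[80:]

# Build the weights from the training targets only, so that there is exactly
# one weight per sample of the training dataset the sampler will index into.
train_targets = all_targets[train_idx]
train_weights = make_sample_weights(train_targets)
assert len(train_weights) == len(train_idx)

# Each class contributes the same total weight (1.0), which is why the sampler
# draws every class with roughly equal probability.
print(train_weights.sum())
```

These weights would then be passed as e.g. `WeightedRandomSampler(torch.from_numpy(train_weights), num_samples=len(train_weights))` to the DataLoader wrapping the training subset (for instance `Subset(dataset, train_idx)`), not to the validation loader.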

I am referring to your code, but only for the multiclass case.
For binary it gives approximately equal samples, but for multiclass it's not giving equal samples.
I guess for multiclass it gives more samples for the classes having lower weights.

That’s not the case and you can easily extend the example to a multiclass use case, which still yields balanced examples:

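import numpy as np
import torch
from torch.utils.data import DataLoader

numDataPoints = 10000
data_dim = 5
bs = 1000

# Create dummy data with class imbalance 5 to 1 (class 0 vs. each of classes 1-5)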
data = torch.randn(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.5), dtype=np.int32),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32) * 2,
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32) * 3,
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32) * 4,
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32) * 5))

class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
print(class_sample_count)
# [5000 1000 1000 1000 1000 1000]

weight = 1. / class_sample_count
print(weight)
# [0.0002 0.001  0.001  0.001  0.001  0.001 ]
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
    print("batch index: {}, class count: {}".format(
        i, [len((target == i).nonzero()) for i in range(len(target.unique()))]))
# batch index: 0, class count: [170, 170, 159, 176, 144]
# batch index: 1, class count: [163, 166, 177, 155, 177]
# batch index: 2, class count: [187, 171, 158, 153, 175]
# batch index: 3, class count: [157, 153, 188, 162, 187]
# batch index: 4, class count: [158, 166, 161, 167, 182]
# batch index: 5, class count: [176, 168, 158, 169, 158]
# batch index: 6, class count: [160, 159, 159, 169, 182]
# batch index: 7, class count: [165, 158, 180, 154, 169]
# batch index: 8, class count: [164, 160, 174, 168, 151]
# batch index: 9, class count: [157, 194, 157, 169, 174]
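If you want to sanity-check the balancing without PyTorch, a rough NumPy stand-in for the draw can help. With `replacement=True` (its default), WeightedRandomSampler picks each index with probability proportional to its weight; this sketch imitates that with `numpy.random.Generator.choice` on the same 5000/1000/1000/1000/1000/1000 layout. It is an approximation for illustration, not PyTorch's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Same layout as the example above: class 0 has 5000 samples, classes 1-5 have 1000 each.
target = np.repeat(np.arange(6), [5000, 1000, 1000, 1000, 1000, 1000])

# Inverse-frequency weight for every sample.
class_sample_count = np.bincount(target)
samples_weight = (1.0 / class_sample_count)[target]

# Draw len(target) indices with replacement, with probability proportional to the weights.
probs = samples_weight / samples_weight.sum()
indices = rng.choice(len(target), size=len(target), replace=True, p=probs)

# Split the drawn indices into batches of 1000 and count the classes in each batch.
bs = 1000
batch_counts = [np.bincount(target[indices[i:i + bs]], minlength=6)
                for i in range(0, len(indices), bs)]
for i, counts in enumerate(batch_counts):
    print(f"batch index: {i}, class count: {counts.tolist()}")
```

Using `minlength=6` keeps all six classes in every count, even if a class happens to be missing from a batch.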

Could you check with a smaller batch size? I can't fit a larger batch size,
and that's the only difference I can see between your code and mine.

With a larger batch size (8000) I get approximately equal samples:

OrderedDict([(0, 172), (1, 170), (2, 183), (3, 165), (4, 168), (5, 192), (6, 176), (7, 168), (8, 187), (9, 174), (10, 172), (11, 188), (12, 186), (13, 176), (14, 175), (15, 139), (16, 178), (17, 159), (18, 162), (19, 168), (20, 177), (21, 176), (22, 160), (23, 184), (24, 196), (25, 189), (26, 183), (27, 184), (28, 178), (29, 201), (30, 184), (31, 160), (32, 196), (33, 182), (34, 197), (35, 179), (36, 175), (37, 176), (38, 175), (39, 179), (40, 194), (41, 178), (42, 177), (43, 194), (44, 168)])

Smaller batch sizes create more noise, since the weighted sampling is a random process.
If you accumulate the class counts over all batches, the epoch as a whole still shows a balanced usage:

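import numpy as np
import torch
from torch.utils.data import DataLoader

numDataPoints = 1000
data_dim = 5
bs = 5

# Create dummy data with class imbalance 5 to 1 (class 0 vs. each of classes 1-5)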
data = torch.randn(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.5), dtype=np.int32),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32) * 2,
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32) * 3,
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32) * 4,
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32) * 5))

class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
print(class_sample_count)
# [500 100 100 100 100 100]

weight = 1. / class_sample_count
print(weight)
# [0.002 0.01  0.01  0.01  0.01  0.01 ]
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

freqs = np.zeros(len(target.unique()))
for i, (data, t) in enumerate(train_loader):
    f = [len((t == i).nonzero()) for i in range(len(target.unique()))]
    print("batch index: {}, class count: {}".format(i, f))
    freqs += np.array(f)

print(freqs)
# [164. 185. 185. 139. 159. 168.]
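The batch-size effect can be estimated with a back-of-the-envelope calculation. Assuming each draw lands in any of the K classes with probability p = 1/K (the ideal balanced case), the count of one class in a batch of size bs is binomially distributed, so its relative spread is sqrt((1 - p) / (bs * p)) and only shrinks with the square root of the batch size. The helper name below is hypothetical:

```python
import math

def per_batch_class_stats(batch_size, num_classes):
    """Mean, standard deviation and relative spread of one class's count in a batch,
    assuming every draw picks each class with probability 1 / num_classes."""
    p = 1.0 / num_classes
    mean = batch_size * p
    std = math.sqrt(batch_size * p * (1.0 - p))   # binomial standard deviation
    return mean, std, std / mean

for bs in (5, 1000):
    mean, std, rel = per_batch_class_stats(bs, num_classes=6)
    print(f"bs={bs}: expected {mean:.2f} +/- {std:.2f} per class ({rel:.0%} relative)")
```

For the 6-class example this gives a relative spread of about 100% at bs=5 versus roughly 7% at bs=1000, which is in line with the fluctuations of about ±12 around ~167 in the batches printed earlier.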

IndexError with WeightedRandomSampler in PyTorch DataLoader for LSTM Model

Hello @ptrblck

I’m working on an LSTM model for anomaly detection using PyTorch. I’m trying to use a WeightedRandomSampler to handle class imbalance in my dataset. However, I’m encountering an IndexError when iterating over the DataLoader. Here is my code:

CODE:

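import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import Accuracy, Precision, Recall, F1Score

# Assumes X (a pandas DataFrame of features) and y (an array of 0/1 labels) are already loaded

# Normalize the features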
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X.values)

# Prepare data for LSTM
def create_sequences(X, y, time_steps=1):
    Xs, ys = [], []
    for i in range(len(X) - time_steps):
        Xs.append(X[i:(i + time_steps)])
        ys.append(y[i + time_steps])
    return np.array(Xs), np.array(ys)

time_steps = 10
X_seq, y_seq = create_sequences(X_scaled, y, time_steps)

# Convert to PyTorch tensors
X_tensor = torch.tensor(X_seq, dtype=torch.float32)
y_tensor = torch.tensor(y_seq, dtype=torch.float32)

# Create DataLoader
dataset = TensorDataset(X_tensor, y_tensor)
class_counts = np.bincount(y)
class_weights = 1. / class_counts
weights = class_weights[y]
weights = torch.tensor(weights, dtype=torch.float32)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights), replacement=True)
print(len(weights))
print(len(dataset))
dataloader = DataLoader(dataset, batch_size=32,sampler =sampler)
# Print some sampled indices for debugging
for i, (X_batch, y_batch) in enumerate(dataloader):
    print(f"Batch {i}: X_batch size: {X_batch.size()}, y_batch size: {y_batch.size()}")
    if i == 2:  # Print only the first 3 batches for brevity
        break

# Define LSTM model
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Zero initial hidden and cell states: (num_layers, batch, hidden_size)
        h_0 = torch.zeros(1, x.size(0), self.hidden_size, device=x.device)
        c_0 = torch.zeros(1, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.lstm(x, (h_0, c_0))
        out = self.fc(out[:, -1, :])
        out = self.sigmoid(out)
        return out

input_size = X_tensor.shape[2]
hidden_size = 50
output_size = 1

model = LSTMModel(input_size, hidden_size, output_size)

# Define weighted BCE loss: reduction='none' keeps the per-sample losses,
# so the class weights can be applied before averaging
criterion = nn.BCELoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=0.001)
class_weights_t = torch.tensor(class_weights, dtype=torch.float32)

# Initialize metrics
accuracy = Accuracy(task='binary')
precision = Precision(task='binary')
recall = Recall(task='binary')
f1_score = F1Score(task='binary')

# Train the model
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    for X_batch, y_batch in dataloader:
        optimizer.zero_grad()
        outputs = model(X_batch)
        # Look up the weight of each sample's class and weight the per-sample losses
        batch_weights = class_weights_t[y_batch.long()]
        loss = (criterion(outputs.squeeze(1), y_batch) * batch_weights).mean()
        loss.backward()
        optimizer.step()

    # Evaluate metrics
    model.eval()
    with torch.no_grad():
        outputs = model(X_tensor)
        predictions = (outputs.squeeze() > 0.5).float()
        acc = accuracy(predictions, y_tensor)
        prec = precision(predictions, y_tensor)
        rec = recall(predictions, y_tensor)
        f1 = f1_score(predictions, y_tensor)
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Accuracy: {acc}, Precision: {prec}, Recall: {rec}, F1 Score: {f1}')

# Final evaluation
model.eval()
with torch.no_grad():
    outputs = model(X_tensor)
    predictions = (outputs.squeeze() > 0.5).float()
    acc = accuracy(predictions, y_tensor)
    prec = precision(predictions, y_tensor)
    rec = recall(predictions, y_tensor)
    f1 = f1_score(predictions, y_tensor)
    print(f'Final Metrics - Accuracy: {acc}, Precision: {prec}, Recall: {rec}, F1 Score: {f1}')

OUTPUT:

429394
429384
Batch 0: X_batch size: torch.Size([32, 10, 17]), y_batch size: torch.Size([32])
Batch 1: X_batch size: torch.Size([32, 10, 17]), y_batch size: torch.Size([32])
Batch 2: X_batch size: torch.Size([32, 10, 17]), y_batch size: torch.Size([32])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[104], line 72
     70 for epoch in range(num_epochs):
     71     model.train()
---> 72     for X_batch, y_batch in dataloader:
     73         optimizer.zero_grad()
     74         outputs = model(X_batch)

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
    672 def _next_data(self):
    673     index = self._next_index()  # may raise StopIteration
--> 674     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    675     if self._pin_memory:
    676         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51, in <listcomp>(.0)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataset.py:208, in TensorDataset.__getitem__(self, index)
    207 def __getitem__(self, index):
--> 208     return tuple(tensor[index] for tensor in self.tensors)

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataset.py:208, in <genexpr>(.0)
    207 def __getitem__(self, index):
--> 208     return tuple(tensor[index] for tensor in self.tensors)

IndexError: index 429389 is out of bounds for dimension 0 with size 429384

I checked the lengths of the input data and the weights to confirm they are consistent, which they are, but I’m still getting the error.

Also, if you have any suggestions for building a better LSTM model for anomaly detection, please add your input. I am very new to deep learning.

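But they are not: the printed lengths are 429394 for the weights and 429384 for the dataset.

The weights have 10 more values and are thus most likely causing the issue. The sampler draws indices from the range of the weights, so it can return indices such as 429389 that don't exist in the dataset.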
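A minimal sketch of the fix, reusing the same `create_sequences` helper as in the question on hypothetical random data. Since each window's label is `y[i + time_steps]`, `y_seq` is exactly `y[time_steps:]` and has `time_steps` fewer entries than `y` (429394 - 10 = 429384 in the question). Computing the weights from `y_seq` instead of `y` keeps one weight per sequence:

```python
import numpy as np

def create_sequences(X, y, time_steps=1):
    # Same windowing as in the question: the target of each window is the
    # label right after it, so the first `time_steps` labels are never used.
    Xs, ys = [], []
    for i in range(len(X) - time_steps):
        Xs.append(X[i:(i + time_steps)])
        ys.append(y[i + time_steps])
    return np.array(Xs), np.array(ys)

# Hypothetical toy data: 100 time steps, 3 features, roughly 10% positive labels.
rng = np.random.default_rng(0)
X = rng.random((100, 3))
y = (rng.random(100) < 0.1).astype(np.int64)
y[-1] = 1  # make sure the positive class survives the windowing
time_steps = 10
X_seq, y_seq = create_sequences(X, y, time_steps)

# Compute the class weights from y_seq (the targets the dataset actually uses),
# not from the original y.
class_counts = np.bincount(y_seq)
class_weights = 1.0 / class_counts
weights = class_weights[y_seq]

assert len(weights) == len(X_seq)   # one weight per sequence
print(len(y), len(y_seq), len(weights))  # 100 90 90
```

The weights can then go into `WeightedRandomSampler(torch.from_numpy(weights), num_samples=len(weights), replacement=True)` as before. As a side note, the sampler already balances the batches, so additionally scaling the loss by the same inverse class frequencies would typically over-weight the minority class; usually one of the two is enough.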

Wow, I can’t believe I missed that 🥲. I stared at it for hours.

Thanks