ValueError: Expected input batch_size (512) to match target batch_size (6815744)

Please help me fix this issue. Thanks in advance.

My code:
import torch
import torch.nn as nn
import torch.optim as optim

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        print(x.shape)
        return x

# Load features and labels

features_list = []
for i in range(13):
    xx = "{:02d}".format(i)
    feature_file = f'C:/Users/prave/AppData/Local/Programs/Python/Python39/Scripts/assigment_data/assigment/feats/feats_{xx}.pt'
    feature_tensor = torch.load(feature_file)
    features_list.append(feature_tensor)

labels_list = []
for i in range(13):
    xx = "{:02d}".format(i)
    label_file = f'C:/Users/prave/AppData/Local/Programs/Python/Python39/Scripts/assigment_data/assigment/labels/y_{xx}.pt'
    label_tensor = torch.load(label_file)
    labels_list.append(label_tensor)

print(labels_list[0].shape)

# Concatenate the labels into a single tensor

all_labels = torch.cat(labels_list, dim=0)
print(all_labels.shape)

# Find the number of unique labels (the number of classes)

num_classes = len(torch.unique(all_labels))
print(f"Number of classes: {num_classes}")

# Calculate input dimension

input_dim = 524288
print("Input dimension:", input_dim)

# Define the model and optimizer

model = MLP(input_dim=input_dim, hidden_dim=128, output_dim=18)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model

criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    running_loss = 0.0
    total_pixels = 0
    correct_pixels = 0

    for i in range(len(features_list)):
        features = features_list[i]
        labels = labels_list[i]

        # Upscale the features to match the size of the labels
        features = features.unsqueeze(0)[:64]  # Add a batch dimension
        features = nn.functional.interpolate(features, scale_factor=4, mode='bilinear', align_corners=False)
        features = features.squeeze(0)  # Remove the batch dimension
        features = features[:labels.size(0)]  # Match the batch size of the labels

        # Reshape the features tensor to (batch_size, number_channels*height*width)
        features = features.view(features.size(0), -1)
        assert features.size(0) == labels.size(0)

        print(f"Shape of the feature tensor before pass: {features.shape}")
        print(f"Shape of the label tensor before pass: {labels.shape}")

        batch_size = features.size(0)
        labels = labels.view(batch_size, -1)

        # Flatten the labels tensor
        labels = torch.cat(labels_list, dim=0).view(-1)

        # Print the shape of the flattened labels tensor
        print(f"Shape of the label tensor after flattening: {labels.shape}")

        # Forward pass
        outputs = model(features)

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        predicted = torch.max(outputs.data, 1)[1]
        total_pixels += labels.numel()
        correct_pixels += (predicted == labels).sum().item()

        running_loss += loss.item()

    epoch_loss = running_loss / len(features_list)
    epoch_acc = correct_pixels / total_pixels

    print('Epoch [{}/{}], Loss: {:.4f}, Pixel Accuracy: {:.4f}'.format(epoch + 1, 5, epoch_loss, epoch_acc))
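
As a sanity check, the model on its own produces the shape I expect. With a random tensor of the same shape as one flattened feature file (dummy data, just for illustration):

dummy = torch.randn(512, 524288)  # same shape as one feature file after flattening
out = model(dummy)
print(out.shape)  # torch.Size([512, 18]) -- one row of 18 class scores per sample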

Error:

ValueError                                Traceback (most recent call last)
Cell In[2], line 94
     91 outputs = model(features)
     93 # Compute loss
---> 94 loss = criterion(outputs, labels)
     96 # Backward and optimize
     97 optimizer.zero_grad()

File c:\users\prave\appdata\local\programs\python\python39\lib\site-packages\torch\nn\modules\module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File c:\users\prave\appdata\local\programs\python\python39\lib\site-packages\torch\nn\modules\loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
   1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174     return F.cross_entropy(input, target, weight=self.weight,
   1175                            ignore_index=self.ignore_index, reduction=self.reduction,
   1176                            label_smoothing=self.label_smoothing)

File c:\users\prave\appdata\local\programs\python\python39\lib\site-packages\torch\nn\functional.py:3026, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3024 if size_average is not None or reduce is not None:
   3025     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

ValueError: Expected input batch_size (512) to match target batch_size (6815744).
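
From the docs, nn.CrossEntropyLoss with class-index targets expects input of shape (N, C) and target of shape (N), and raises this ValueError whenever the two N values differ. A minimal sketch with made-up shapes reproduces the same message:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(512, 18)                   # (N, C) = (512, 18), like my model output
targets_ok = torch.randint(0, 18, (512,))       # (N,) = (512,): one class index per sample
print(criterion(logits, targets_ok))            # works

targets_bad = torch.randint(0, 18, (6815744,))  # N mismatch, like my flattened labels
criterion(logits, targets_bad)                  # ValueError: Expected input batch_size (512) ...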

Some of the print statement output:
torch.Size([512, 1024])  <- labels_list[0].shape
Number of classes: 18
Input dimension: 524288
Shape of the feature tensor before pass: torch.Size([512, 524288])
Shape of the label tensor before pass: torch.Size([512, 1024])
Shape of the label tensor after flattening: torch.Size([6815744])
torch.Size([512, 18])  <- the print(x.shape) inside forward()
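
One thing I noticed while staring at these shapes: 6,815,744 = 13 × 512 × 1024, i.e. the target is all thirteen label files flattened together, because I rebuild labels from the full labels_list inside the loop. But even if I flatten only the current file's labels, that gives 512 × 1024 = 524,288 targets against 512 model outputs, so I'm still unsure how the per-pixel labels are meant to line up with the model's per-sample predictions. What is the right way to shape the outputs and labels here?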