Hi there!
I am having some trouble training a model from scratch. My dataset contains images, each with an assigned label of 0 or 1. I aim to classify them using a ResNet architecture that outputs the two class probabilities. Hence, for each training batch I obtain a tensor of size (batch_size, 2). As the ground truth, I have a tensor of size (batch_size) with the true label of each image.
As the training loss function, I am using the focal loss from torchvision (torchvision.ops.focal_loss, Torchvision 0.15 documentation). To compute the loss, I need to convert my predicted probabilities into binary labels, so I used argmax() to get the label with the higher probability. However, argmax() returns integer indices and is not differentiable, so this detaches the loss calculation from the computation graph (as read here: Element 0 of tensors does not require grad and does not have a grad_fn How can i fix this - #2 by ptrblck).
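To illustrate the point about argmax(), here is a small sketch with a random tensor standing in for the model output (an assumption, not the real model). The indices returned by argmax() form an integer tensor with no grad_fn, so a loss computed only from them cannot be backpropagated.

```python
import torch

# Random stand-in for the (batch_size, 2) model output; it is part of the autograd graph
logits = torch.randn(4, 2, requires_grad=True)
probs = torch.softmax(logits, dim=1)

# argmax returns integer class indices, which autograd cannot differentiate through
pred_labels = probs.argmax(dim=1)
print(pred_labels.dtype)          # torch.int64
print(pred_labels.requires_grad)  # False
print(pred_labels.grad_fn)        # None

# A loss built only from pred_labels has no path back to `logits`,
# so backward() raises a RuntimeError ("element 0 of tensors does not require grad ...")
loss = pred_labels.float().mean()
try:
    loss.backward()
    backward_failed = False
except RuntimeError:
    backward_failed = True
print(backward_failed)  # True
```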
My issue, then, is that I don't know how to convert my output probabilities into labels without breaking the gradient flow.
Here is the code:
```python
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torchvision.ops import sigmoid_focal_loss

# model, optimizer, device, train_loader, total_epochs, lr_1 and lr_2 are defined earlier
sm = nn.Softmax(dim=1)

# Create a custom learning rate scheduler.
# LambdaLR multiplies the optimizer's initial lr by the returned value, so return a
# factor rather than an absolute lr (this assumes the optimizer was created with lr=lr_1).
def lr_lambda(epoch):
    return lr_2 / lr_1 if epoch > 80 else 1.0

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# Training loop
for epoch in range(total_epochs):
    print(f"Epoch {epoch + 1}")
    # Set the model to training mode
    model.train()

    # Iterate over the training data
    for i, data in enumerate(train_loader):
        inputs, labels = data
        # Move the batch inputs and labels to the device
        inputs = inputs.float().to(device)
        labels = labels.float().to(device)

        # Forward pass: outputs has shape (batch_size, 2)
        outputs = model(inputs)
        predicted_probs = sm(outputs)
        predicted_probs = predicted_probs.squeeze(-1).squeeze(-1)

        # Compute loss
        loss = sigmoid_focal_loss(predicted_probs, labels, gamma=3)  # CONFLICT!!

        # Clear parameter gradients
        optimizer.zero_grad()
        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    # Update the learning rate once per epoch, based on the custom scheduler
    scheduler.step()
```
I get the following error at the line marked CONFLICT:
ValueError: Target size (torch.Size([32])) must be the same as input size (torch.Size([32, 2]))
Thanks in advance for all your responses.