I trained a model about a year ago. Now, when I make predictions, I can't retrieve the class names, so I don't know which class each prediction refers to.
```python
import torch

def predict_image(img, model):
    device = torch.device('cpu')
    # Convert to a batch of 1 and move it to the target device
    xb = img.unsqueeze(0).to(device)
    # Get predictions from the model (eval mode, no gradient tracking)
    model.eval()
    with torch.no_grad():
        yb = model(xb)
    # Pick the index with the highest score
    _, preds = torch.max(yb, dim=1)
    # Retrieve the class label
    return train_ds.classes[preds[0].item()]
```
The problem is that I no longer have access to train_ds, so I can't map the predicted index back to a class name.
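If, and this is an assumption since the post doesn't say, train_ds was a torchvision.datasets.ImageFolder, the class list is not stored anywhere special. ImageFolder takes the names of the subdirectories of the dataset root and sorts them, and each class index is simply the position of the folder name in that sorted list. The sketch below rebuilds the list using that same rule. The helper names find_image_folder_classes and index_to_class are hypothetical and not part of any library.

```python
import os
from typing import List


def find_image_folder_classes(root: str) -> List[str]:
    """Rebuild the class list the way ImageFolder does: every
    subdirectory of `root` is one class, sorted by name, and the
    class index is its position in that sorted list."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {root}.")
    return classes


def index_to_class(pred_index: int, classes: List[str]) -> str:
    # Map the integer index produced by torch.max(...) back to a class name
    return classes[pred_index]
```

For example, `classes = find_image_folder_classes("data/train")` (the path is hypothetical) would replace `train_ds.classes` in predict_image. Constructing `torchvision.datasets.ImageFolder("data/train")` and reading its `.classes` attribute gives the same list. This only works if the training folder still contains exactly the same class subdirectories it had when the model was trained. Adding, removing, or renaming a folder shifts the indices.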