I'm creating a basic CNN model for binary image classification. When I tried evaluating my model before training, I got a RuntimeError about a dtype mismatch: `F.binary_cross_entropy` expects Float tensors, but it seems that my inputs are tensors containing Long data. I don't believe I ever converted my data to Long, but I cast all the relevant tensors to float anyway in the `validation_step` method definition:
```python
def validation_step(self, batch):
    images, targets = batch
    targets.to(torch.float32)
    targets = targets.view(-1, 1)
    images.to(torch.float32)
    out = self(images)
    out.to(torch.float32)
    loss = F.binary_cross_entropy(out, targets)  # Calculate loss
    score = binary_acc(out, targets)
    return {'val_loss': loss.detach(), 'val_score': score.detach()}
```
However, I'm still receiving the same error message. This is the full notebook: https://jovian.ml/moaaz645/malaria-cnn-basic
And this is the error message:
```
RuntimeError                              Traceback (most recent call last)
in <module>
      1 model = to_device(MalariaCnnModel(), device)
----> 2 evaluate(model, val_dl)

/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     13         def decorate_context(*args, **kwargs):
     14             with self:
---> 15                 return func(*args, **kwargs)
     16         return decorate_context
     17 

in evaluate(model, val_loader)
      2 def evaluate(model, val_loader):
      3     model.eval()
----> 4     outputs = [model.validation_step(batch) for batch in val_loader]
      5     return model.validation_epoch_end(outputs)
      6 

in <listcomp>(.0)
      2 def evaluate(model, val_loader):
      3     model.eval()
----> 4     outputs = [model.validation_step(batch) for batch in val_loader]
      5     return model.validation_epoch_end(outputs)
      6 

in validation_step(self, batch)
     17         out = self(images)
     18         out.to(torch.float32)
---> 19         loss = F.binary_cross_entropy(out, targets) # Calculate loss
     20         score = binary_acc(out, targets)
     21         return {'val_loss': loss.detach(), 'val_score': score.detach() }

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2377 
   2378     return torch._C._nn.binary_cross_entropy(
-> 2379         input, target, weight, reduction_enum)
   2380 
   2381 

RuntimeError: expected dtype Float but got dtype Long
```
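A likely cause, offered as a sketch rather than a confirmed diagnosis: `Tensor.to()` is not an in-place operation. It returns a new tensor and leaves the original unchanged, so a bare `targets.to(torch.float32)` line has no effect and `targets` stays Long. Integer class labels are typically batched by the DataLoader's default collate function into int64 (Long) tensors, which would explain where the Long data comes from. The fix is to reassign the result (`targets = targets.to(torch.float32)`, or equivalently `targets = targets.float()`). The sketch below is self-contained, so `TinyBinaryCnn` and `binary_acc` are hypothetical stand-ins for the notebook's `MalariaCnnModel` and helper, which are not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def binary_acc(out, targets):
    # Hypothetical stand-in for the notebook's binary_acc helper:
    # threshold the sigmoid probabilities at 0.5 and compare to the targets.
    preds = (out > 0.5).float()
    return (preds == targets).float().mean()


class TinyBinaryCnn(nn.Module):
    # Hypothetical minimal model standing in for MalariaCnnModel.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        # Sigmoid so outputs are probabilities, as F.binary_cross_entropy expects.
        return torch.sigmoid(self.fc(x))

    def validation_step(self, batch):
        images, targets = batch
        # Tensor.to() returns a new tensor, so the result must be reassigned.
        images = images.to(torch.float32)
        targets = targets.to(torch.float32).view(-1, 1)
        out = self(images)
        loss = F.binary_cross_entropy(out, targets)  # Calculate loss
        score = binary_acc(out, targets)
        return {'val_loss': loss.detach(), 'val_score': score.detach()}


# Demonstrate that .to() leaves the original tensor untouched.
t = torch.tensor([0, 1])
t.to(torch.float32)
print(t.dtype)  # torch.int64

# Run one validation step on random data with Long (int64) labels.
torch.manual_seed(0)
model = TinyBinaryCnn().eval()
images = torch.rand(8, 3, 16, 16)
labels = torch.randint(0, 2, (8,))  # int64, like collated integer class labels
with torch.no_grad():
    result = model.validation_step((images, labels))
print(result['val_loss'].dtype)  # torch.float32
```

The only substantive change from the original `validation_step` is reassigning the results of `.to()`; the `out.to(torch.float32)` line is dropped because a float32 model already produces float32 outputs.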