Of course, the datatype of the data you’re trying to propagate through your network matters.
Feeding your model input of the wrong type might give you:
RuntimeError: Input type (torch.LongTensor) and weight type (torch.FloatTensor) should be the same
And comparing model output against a target of the wrong type might yield something along the lines of:
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'
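For example (a minimal sketch of my understanding), the second error can come from a loss function such as `CrossEntropyLoss`, which expects a `Long` target; casting the target fixes it:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(3, 5)                 # model output: float
target = torch.tensor([1.0, 0.0, 4.0])     # float target triggers the error

# loss = criterion(logits, target)         # RuntimeError: expected scalar type Long
loss = criterion(logits, target.long())    # cast the target to the expected dtype
print(loss.item())
```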
To build a more flexible system I was wondering: how can one programmatically find the data type that will fit a model?
For the input side I’ve tried
list(model.parameters())[-1], which might be a bit naive.
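To illustrate what I mean (a sketch, assuming all parameters share one dtype, which holds for typical models but not mixed-precision ones): inspecting a parameter's `.dtype` and casting the input to match.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # toy model; weights are float32 by default

# Take the dtype of the first parameter as the model's "expected" dtype.
param_dtype = next(model.parameters()).dtype

x = torch.randint(0, 10, (3, 4))   # a LongTensor input that would raise the error
out = model(x.to(param_dtype))     # works after casting to the parameter dtype
print(param_dtype, out.shape)
```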