Get appropriate model input/output dtype programmatically

Hi all,

Of course, the dtype of the data you’re trying to propagate through your network matters.
Feeding your model an input of the wrong type might give you:
RuntimeError: Input type (torch.LongTensor) and weight type (torch.FloatTensor) should be the same
And passing a target of the wrong type to your loss function might yield something along the lines of:
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'
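Both errors are easy to reproduce and fix on a toy model; a minimal sketch (the nn.Linear model and shapes here are just an example):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # parameters are torch.float32 by default

x = torch.randint(0, 10, (3, 4))  # LongTensor: model(x) would raise the first error
out = model(x.float())            # cast the input to match the weight dtype

target = torch.tensor([0, 1, 0])
# CrossEntropyLoss expects Long targets; .long() guards against the second error
loss = nn.CrossEntropyLoss()(out, target.long())
print(loss.dtype)  # torch.float32
```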

To build a more flexible system I was wondering: How can one find the data type that will fit a model programmatically?

For the output I’ve tried list(model.parameters())[-1].dtype, which might be a bit naive.
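The same heuristic works for the input side via the first parameter. A quick sketch (note this assumes parameters are registered in layer order, which is typical but not guaranteed):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

in_dtype = next(model.parameters()).dtype        # first parameter, usually the first layer's weight
out_dtype = list(model.parameters())[-1].dtype   # last parameter, the heuristic above

print(in_dtype, out_dtype)  # torch.float32 torch.float32
```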

Might not be the best way, but for the inputs you could walk down to the first leaf module and inspect its parameters:

current_module = model
# descend into the first child until we reach a leaf module
while list(current_module.children()):
    current_module = list(current_module.children())[0]

for p_name, p in current_module.named_parameters():
    print('{}: {}'.format(p_name, p.dtype))