I am taking in a trace of a model and would like to check what type each tensor is supposed to be. Is there a way to get this information from a torch._C.Node (i.e., get float, double, etc., as listed at https://pytorch.org/docs/stable/tensors.html)? Or, more generally, is there any way to get that information starting from the trace?
You can grab it in this way:
```python
import torch

def some_fn(x):
    return x + x

some_input = torch.randn(2, 2)
my_traced_function = torch.jit.trace(some_fn, (some_input,))

for graph_input in my_traced_function.graph.inputs():
    traced_tensor_type = graph_input.type()
    # Prints "Float"
    print(traced_tensor_type.scalarType())

# However, note that the interpreter will still run with differently typed
# tensors: the recorded type is not enforced when the function is called.
my_traced_function(torch.ones(2, 2, dtype=torch.long))
```
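If you want a torch.dtype rather than the scalar-type name string, a minimal sketch is below. It assumes a hand-written mapping table (the names `SCALAR_TYPE_TO_DTYPE` and `input_dtypes` are hypothetical, not part of PyTorch) that covers only common types. It also shows that the recorded type simply follows whatever example input was used for tracing.

```python
import torch

# Hypothetical, hand-written mapping from the names returned by
# TensorType.scalarType() to torch.dtype objects (common types only).
SCALAR_TYPE_TO_DTYPE = {
    "Float": torch.float32,
    "Double": torch.float64,
    "Half": torch.float16,
    "BFloat16": torch.bfloat16,
    "Long": torch.int64,
    "Int": torch.int32,
    "Short": torch.int16,
    "Char": torch.int8,
    "Byte": torch.uint8,
    "Bool": torch.bool,
}


def input_dtypes(traced):
    """Return the dtype recorded at trace time for each tensor graph input."""
    dtypes = []
    for graph_input in traced.graph.inputs():
        t = graph_input.type()
        # Skip non-tensor inputs (e.g. `self` for a traced nn.Module).
        if isinstance(t, torch._C.TensorType) and t.scalarType() is not None:
            # Unknown names map to None rather than raising.
            dtypes.append(SCALAR_TYPE_TO_DTYPE.get(t.scalarType()))
    return dtypes


def some_fn(x):
    return x + x


# Same function, traced with two differently typed example inputs.
traced_float = torch.jit.trace(some_fn, (torch.randn(2, 2),))
traced_long = torch.jit.trace(some_fn, (torch.ones(2, 2, dtype=torch.long),))
print(input_dtypes(traced_float))  # [torch.float32]
print(input_dtypes(traced_long))   # [torch.int64]
```

Since torch._C classes are internal, the exact scalar-type names could change between releases; the `.get()` lookup keeps the helper from failing on a name it does not know.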
I see that this gets the type for the graph inputs, but what about intermediate nodes in the graph? It seems the inputs are of type torch._C.Value, which has a type() method, but torch._C.Node doesn’t.
A Node represents a whole operation (some inputs, an operation, and some outputs). Its inputs and outputs are represented as Values, which you can get like so:
```python
import torch

def some_fn(x):
    return x + x

some_input = torch.randn(2, 2)
my_traced_function = torch.jit.trace(some_fn, (some_input,))
print(my_traced_function.graph)

for node in my_traced_function.graph.nodes():
    for node_input in node.inputs():
        # Only tensor-typed Values carry a scalar type
        if isinstance(node_input.type(), torch._C.TensorType):
            print(node_input.type().scalarType())
```
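For the intermediate values specifically, a sketch that walks `node.outputs()` instead of `node.inputs()`, so each entry is the type an operation produced. `value_scalar_types` is a hypothetical helper name; `Node.kind()` and `Value.debugName()` are existing methods on the internal torch._C classes, which are not a stable public API. The example function adds a `.double()` call (an assumption for illustration) so the intermediate and final types differ.

```python
import torch


def value_scalar_types(graph):
    """List (node kind, value name, scalar type) for each tensor output Value.

    Non-tensor outputs, such as the integer `alpha` constant used by
    aten::add, are skipped.
    """
    result = []
    for node in graph.nodes():
        for out in node.outputs():
            t = out.type()
            if isinstance(t, torch._C.TensorType):
                result.append((node.kind(), out.debugName(), t.scalarType()))
    return result


def some_fn(x):
    y = x + x          # intermediate value, recorded as Float
    return y.double()  # dtype conversion, so the output is recorded as Double


traced = torch.jit.trace(some_fn, (torch.randn(2, 2),))
entries = value_scalar_types(traced.graph)
for kind, name, scalar_type in entries:
    print(kind, name, scalar_type)
```

Here the aten::add output should report Float and the node produced by `.double()` should report Double; the exact constant nodes and debug names in the graph may vary with the PyTorch version.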
You can read more about the internal representations here.
Thanks for the link, really helpful; somehow I didn’t find it in my initial search. Got it, this makes a lot of sense now.