I am taking in a trace of a model and would like to check what type each tensor is supposed to be. Is there a way to get this information from a torch._C.Node (i.e., get float, double, etc., as listed at https://pytorch.org/docs/stable/tensors.html)? Or, more generally, is there any way to get that information starting from the trace?
You can grab it in this way:
```python
import torch

def some_fn(x):
    return x + x

some_input = torch.randn(2, 2)
my_traced_function = torch.jit.trace(some_fn, [some_input])

# Graph inputs are torch._C.Value objects; their type() carries the dtype
for graph_input in my_traced_function.graph.inputs():
    traced_tensor_type = graph_input.type()
    # Prints "Float"
    print(traced_tensor_type.scalarType())

# However, note that the interpreter will still run with differently typed
# tensors
my_traced_function(torch.ones(2, 2, dtype=torch.long))
```
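Since `scalarType()` returns a string such as "Float" rather than a `torch.dtype`, a minimal sketch for converting it is a lookup table. The table `SCALAR_TYPE_TO_DTYPE` below is hypothetical and only covers some common scalar type names; unknown names (or a `None` from `scalarType()`) map to `None`. The last lines also show the point made in the comment above: the dtype recorded at trace time is not enforced when the function is called.

```python
import torch

# Hypothetical lookup table: covers only some common scalar type names,
# extend it as needed.
SCALAR_TYPE_TO_DTYPE = {
    "Float": torch.float32,
    "Double": torch.float64,
    "Half": torch.float16,
    "BFloat16": torch.bfloat16,
    "Long": torch.int64,
    "Int": torch.int32,
    "Short": torch.int16,
    "Char": torch.int8,
    "Byte": torch.uint8,
    "Bool": torch.bool,
}

def some_fn(x):
    return x + x

traced = torch.jit.trace(some_fn, [torch.randn(2, 2)])

# dtype recorded for each tensor graph input at trace time (None if unknown)
recorded = [
    SCALAR_TYPE_TO_DTYPE.get(v.type().scalarType())
    for v in traced.graph.inputs()
    if isinstance(v.type(), torch._C.TensorType)
]
print(recorded)

# The recorded dtype is not checked at call time: a long input still runs
out = traced(torch.ones(2, 2, dtype=torch.long))
print(out.dtype)
```

The recorded list reflects the float example input used for tracing, while `out.dtype` is `torch.int64`, so the graph's types describe the trace, not a contract on later inputs.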
I see that this gets the types of the graph inputs, but what about intermediate nodes in the graph? The inputs are of type torch._C.Value, which has a type() method, but torch._C.Node doesn’t.
A `Node` represents the whole operation (some inputs, an operation, and some outputs). The inputs and outputs are represented as `Value`s, which you can get like so:
```python
import torch

def some_fn(x):
    return x + x

some_input = torch.randn(2, 2)
my_traced_function = torch.jit.trace(some_fn, [some_input])
print(my_traced_function.graph)

for node in my_traced_function.graph.nodes():
    for node_input in node.inputs():
        # Skip non-tensor inputs such as the integer `alpha` constant of aten::add
        if isinstance(node_input.type(), torch._C.TensorType):
            print(node_input.type().scalarType())
```
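Building on that, here is a minimal sketch that records the scalar type of every tensor `Value` in a traced graph, keyed by `debugName()`. The helper name `collect_scalar_types` is hypothetical. It visits the graph inputs plus the outputs of every node, because every intermediate value is produced by exactly one node; iterating node inputs instead would visit the same value once per use.

```python
import torch

def some_fn(x):
    return x + x

traced = torch.jit.trace(some_fn, [torch.randn(2, 2)])

def collect_scalar_types(graph):
    """Map each tensor Value's debug name to its recorded scalar type string."""
    types = {}
    # Graph inputs are not produced by any node, so visit them separately
    for value in graph.inputs():
        if isinstance(value.type(), torch._C.TensorType):
            types[value.debugName()] = value.type().scalarType()
    # Each intermediate (and final) value is the output of exactly one node
    for node in graph.nodes():
        for value in node.outputs():
            if isinstance(value.type(), torch._C.TensorType):
                types[value.debugName()] = value.type().scalarType()
    return types

types = collect_scalar_types(traced.graph)
print(types)
```

This only walks the top-level block of the graph; nodes nested inside sub-blocks (for example in a `prim::If`) would need a recursive walk over `node.blocks()`, although tracing usually unrolls control flow so such blocks are uncommon in traced graphs.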
You can read more about the internal representations here.
Thanks for the link, it was really helpful; I somehow missed it in my initial search. Got it, this makes a lot of sense now.