Tensor Type for torch._C.Node

I am taking in a trace of a model and would like to check what dtype each tensor is supposed to have. Is there a way to grab this info from torch._C.Node (i.e., get float, double, etc., as listed at https://pytorch.org/docs/stable/tensors.html)? Or, more generally, is there any way to get that info starting from the trace?

You can grab it in this way:

import torch

def some_fn(x):
    return x + x

some_input = torch.randn(2, 2)
my_traced_function = torch.jit.trace(some_fn, [some_input])

for graph_input in my_traced_function.graph.inputs():
    traced_tensor_type = graph_input.type()
    # Prints "Float", the dtype of the example input used for tracing
    print(traced_tensor_type.scalarType())

# Note, however, that the recorded type is not enforced: the traced
# function still runs with tensors of a different dtype
my_traced_function(torch.ones(2, 2, dtype=torch.long))
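To make both points concrete, here is a minimal sketch: the type stored in the graph follows whatever example input was used for tracing, and it is not checked when the traced function is called later. The `SCALAR_TYPE_TO_DTYPE` table and the `recorded_input_dtypes` helper are hypothetical names for illustration, not PyTorch APIs; the table only covers a few common `scalarType()` strings.

```python
import torch

# Hypothetical lookup table (not a PyTorch API): maps a few common
# scalarType() strings to the corresponding torch.dtype.
SCALAR_TYPE_TO_DTYPE = {
    "Float": torch.float32,
    "Double": torch.float64,
    "Half": torch.float16,
    "Long": torch.int64,
    "Int": torch.int32,
    "Bool": torch.bool,
}


def some_fn(x):
    return x + x


def recorded_input_dtypes(traced):
    """Return the dtype the tracer recorded for each tensor input of the graph."""
    dtypes = []
    for graph_input in traced.graph.inputs():
        input_type = graph_input.type()
        if isinstance(input_type, torch._C.TensorType):
            # Unknown strings (or a missing scalar type) map to None
            dtypes.append(SCALAR_TYPE_TO_DTYPE.get(input_type.scalarType()))
    return dtypes


# Trace with a float64 example: the graph records "Double" for the input.
traced_double = torch.jit.trace(some_fn, (torch.randn(2, 2, dtype=torch.float64),))
print(recorded_input_dtypes(traced_double))  # [torch.float64]

# Calling with an int64 tensor still runs; the recorded type is not enforced.
out = traced_double(torch.ones(2, 2, dtype=torch.long))
print(out.dtype)  # torch.int64
```

If you need the graph's types to hold at runtime, one option under these assumptions is to compare the incoming tensors' `dtype` against `recorded_input_dtypes(...)` yourself before calling the traced function.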

I see that this gets the type for the graph inputs, but what about intermediate nodes in the graph? It seems the inputs are torch._C.Value objects, which carry a type, but torch._C.Node doesn't.

A Node represents a whole operation: some inputs, the operation itself, and some outputs. The inputs and outputs are represented as Values, which you can get like so:

import torch

def some_fn(x):
    return x + x

some_input = torch.randn(2, 2)
my_traced_function = torch.jit.trace(some_fn, [some_input])
print(my_traced_function.graph)

for node in my_traced_function.graph.nodes():
    # A Node's inputs are torch._C.Value objects; only tensor-typed
    # Values have a scalar type (constants may be ints, bools, etc.)
    for node_input in node.inputs():
        if isinstance(node_input.type(), torch._C.TensorType):
            print(node_input.type().scalarType())
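Since the question is about intermediate results, the outputs of each Node are usually the more direct answer: every output Value carries the type the tracer observed for it. Below is a minimal sketch assuming a function whose dtype changes mid-graph; `node_output_types` is a hypothetical helper name, and the exact node kinds (e.g. whether `.double()` is recorded as an `aten::to` node) may vary between PyTorch versions, so the example prints them rather than relying on specific names.

```python
import torch


def some_fn(x):
    # The add stays float32; the explicit cast produces float64.
    return (x + x).double()


traced = torch.jit.trace(some_fn, (torch.randn(2, 2),))


def node_output_types(graph):
    """Collect (node kind, [output scalar types]) for nodes with tensor outputs."""
    result = []
    for node in graph.nodes():
        output_types = [
            out.type().scalarType()
            for out in node.outputs()
            if isinstance(out.type(), torch._C.TensorType)
        ]
        # Skip nodes such as int/bool constants that produce no tensors
        if output_types:
            result.append((node.kind(), output_types))
    return result


# Expect a "Float" entry for the add and a "Double" entry for the cast
for kind, types in node_output_types(traced.graph):
    print(kind, types)
```

Walking outputs rather than inputs means each intermediate tensor is reported once, at the node that produces it.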

You can read more about the internal representations here.


Thanks for the link, that's really helpful; somehow I didn't find it in my initial search. Got it, this makes a lot of sense now.
