Getting leaf nodes of graph directly from torch.autograd.graph.Node

Hello, I am trying to directly access specific tensors stored in the computation graph; however, each grad_fn/Node seems to store its saved tensors under differently named attributes:

import torch

a = torch.randn(4, 8, requires_grad=True)
lin = torch.nn.Linear(8, 16)
b = lin(a)
dir(b.grad_fn)
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_raw_saved_mat1', '_raw_saved_mat2', '_register_hook_dict', '_saved_alpha', '_saved_beta', '_saved_mat1', '_saved_mat1_sym_sizes', '_saved_mat1_sym_strides', '_saved_mat2', '_saved_mat2_sym_sizes', '_saved_mat2_sym_strides', 'metadata', 'name', 'next_functions', 'register_hook', 'register_prehook', 'requires_grad']

The tensor input a is stored as _saved_mat1.
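You can confirm that mapping directly; a quick check, continuing the snippet above (Linear dispatches to addmm here, so the input shows up as mat1):

print(b.grad_fn.name())                      # AddmmBackward0
print(torch.equal(b.grad_fn._saved_mat1, a)) # True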

However, for this example:

c = torch.randn(3, 32, 32)  # unbatched (3D) input
conv = torch.nn.Conv2d(3, 16, 1)
d = conv(c)
# with an unbatched input, the convolution node sits one hop below d.grad_fn:
dir(d.grad_fn.next_functions[0][0])
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_raw_saved_input', '_raw_saved_weight', '_register_hook_dict', '_saved_bias_sym_sizes_opt', '_saved_dilation', '_saved_groups', '_saved_input', '_saved_output_padding', '_saved_padding', '_saved_stride', '_saved_transposed', '_saved_weight', 'metadata', 'name', 'next_functions', 'register_hook', 'register_prehook', 'requires_grad']

Now the tensor input is stored under _saved_input.
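Since the attribute names are generated per node type (_saved_mat1 on AddmmBackward0, _saved_input on ConvolutionBackward0, and so on), one way to gather them generically is to filter dir() for the _saved_ prefix. A rough sketch (saved_tensors_of is my own helper, not a PyTorch API):

import torch

def saved_tensors_of(node):
    # Collect every tensor this backward node saved, keyed by attribute name.
    # Note: reading _saved_* attributes after backward() raises a RuntimeError,
    # because the saved tensors are freed by default.
    saved = {}
    for attr in dir(node):
        if attr.startswith("_saved_"):
            value = getattr(node, attr)
            if isinstance(value, torch.Tensor):
                saved[attr] = value
    return saved

conv = torch.nn.Conv2d(3, 16, 1)
d = conv(torch.randn(1, 3, 32, 32))  # batched input, so d.grad_fn is the conv node itself
print({k: tuple(v.shape) for k, v in saved_tensors_of(d.grad_fn).items()})
# {'_saved_input': (1, 3, 32, 32), '_saved_weight': (16, 3, 1, 1)}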

For educational purposes I am trying to build a system that allows distributed (model parallel) training across remote devices. I am storing disconnected nn.Modules on each of the devices, but in order to complete the backward pass I must pass gradients between the devices, so I need an easy way to access each input's gradient. If there is an easier way to recover the inputs than traversing the computation graph, that would also help. However, I want to change as little as possible about the existing nn.Modules, so I do not want to modify their forward passes.
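That said, if all you need at each device boundary is the gradient flowing into a module's input, you may not have to traverse the graph at all: a hook on the boundary tensor (or a full backward hook on the module) fires with exactly that gradient and requires no change to the module's forward. A minimal sketch, assuming a simple single-stage split (stage, boundary, and grads are my own names):

import torch

stage = torch.nn.Linear(8, 4)  # the sub-module living on this device
grads = {}

# Option 1: hook the activation tensor that crosses the device boundary.
boundary = torch.randn(2, 8, requires_grad=True)  # received from the upstream device
boundary.register_hook(lambda g: grads.__setitem__("tensor", g))

# Option 2: a full backward hook on the module; grad_input holds the gradients
# w.r.t. the module's inputs, again without touching forward().
def capture(module, grad_input, grad_output):
    grads["module"] = grad_input[0]

stage.register_full_backward_hook(capture)

stage(boundary).sum().backward()
print(grads["tensor"].shape, grads["module"].shape)  # both torch.Size([2, 8])

In the distributed setting you would then ship grads["tensor"] to the upstream device and resume the backward pass there with something like upstream_output.backward(gradient=received_grad).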

Thanks (sorry for any typos)