I would like to add a module graph to TensorBoard using SummaryWriter.add_graph. My module’s forward method requires some *args and some **kwargs.
Walking down the stack, it looks like add_graph calls torch.jit.get_trace_graph(model, args), i.e. it does not pass any kwargs, even though torch.jit.get_trace_graph would support them.
Is there a way to pass kwargs? The only solution I see right now is inspecting the forward parameters and rearranging the arguments so that they are all passed as positional args, which I would really like to avoid.
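One workaround that avoids inspecting forward's signature is to wrap the module so the keyword arguments are bound ahead of time and the wrapper's forward only takes positional inputs. This is a minimal sketch, assuming the kwargs can be treated as fixed values for the purpose of drawing the graph; Net, KwargsWrapper and the scale argument are hypothetical names for illustration. It uses torch.jit.trace to check that the wrapper traces with positional inputs only. Tensors passed as kwargs this way are stored on the wrapper, so they would most likely be recorded as constants rather than graph inputs; a tensor kwarg that should show up as an input is better exposed as a positional argument of the wrapper's forward.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical module whose forward takes a positional input and a keyword argument.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale


class KwargsWrapper(nn.Module):
    # Hypothetical helper: stores fixed keyword arguments and forwards only
    # positional inputs, so tracing never needs to see the kwargs.
    def __init__(self, module, **kwargs):
        super().__init__()
        self.module = module
        self.kwargs = kwargs

    def forward(self, *args):
        return self.module(*args, **self.kwargs)


model = Net()
wrapped = KwargsWrapper(model, scale=2.0)
x = torch.randn(3, 4)

# Trace the wrapper with positional inputs only; scale is supplied by the wrapper.
traced = torch.jit.trace(wrapped, (x,))
```

With the wrapper in place, the TensorBoard call only needs positional inputs, e.g. writer.add_graph(wrapped, x) for a SummaryWriter instance named writer.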