Use module with **kwargs in SummaryWriter.add_graph

I would like to add a module graph to tensorboard using SummaryWriter.add_graph. My module’s forward method requires some *args and some **kwargs.

Following the call stack, it looks like add_graph calls torch.jit.get_trace_graph(model, args) without forwarding any kwargs, even though torch.jit.get_trace_graph itself accepts them.
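One possible workaround is to wrap the module in a small adapter whose forward takes only positional arguments and re-attaches a fixed set of kwargs. This sketch assumes the keyword values can stay constant for the trace. KwargsWrapper and ScaledLinear are hypothetical names used only for illustration. The sketch calls torch.jit.trace directly so it runs without TensorBoard installed.

```python
import torch
import torch.nn as nn


class KwargsWrapper(nn.Module):
    """Hypothetical adapter: exposes a positional-only forward for tracing.

    The keyword arguments are captured once at construction time, so the
    traced graph reflects only these particular values.
    """

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        self.module = module  # registered as a submodule, so its parameters are visible
        self.kwargs = kwargs  # plain dict attribute, not a parameter or buffer

    def forward(self, *args):
        # Pass the positional inputs through and re-attach the stored kwargs.
        return self.module(*args, **self.kwargs)


class ScaledLinear(nn.Module):
    """Hypothetical example module whose forward takes a keyword argument."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale


model = ScaledLinear()
wrapped = KwargsWrapper(model, scale=2.0)
x = torch.randn(3, 4)

# The wrapper needs only positional example inputs, so it can be traced as-is.
traced = torch.jit.trace(wrapped, (x,))
```

With a SummaryWriter, the same wrapper would be passed as writer.add_graph(KwargsWrapper(model, scale=2.0), (x,)). Because the keyword values live on the wrapper instead of being passed as inputs, the tracer treats them as fixed. A tensor passed this way would most likely be recorded as a constant rather than as a graph input.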

Is there a way to pass kwargs? The only solution I see right now is to inspect the forward parameters and rearrange the arguments so that they are all converted to args, which I would really like to avoid.
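For reference, the "inspect and rearrange" approach can be kept fairly small with the standard library's inspect.signature. This is a sketch only. kwargs_to_positional is a hypothetical helper name. It binds the call, fills in defaults for skipped parameters so later values keep their slots, and rejects keyword-only parameters and **kwargs because those have no positional slot.

```python
import inspect


def kwargs_to_positional(fn, *args, **kwargs):
    """Hypothetical helper: flatten a call to fn into a positional tuple.

    Arguments are ordered as fn declares its parameters. Keyword-only
    parameters and **kwargs cannot be expressed positionally, so passing
    them explicitly raises TypeError.
    """
    sig = inspect.signature(fn)
    bound = sig.bind(*args, **kwargs)  # raises TypeError for an invalid call
    for name in bound.arguments:
        kind = sig.parameters[name].kind
        if kind in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD):
            raise TypeError(f"{name!r} cannot be passed positionally")
    # Fill in defaults so a skipped middle parameter does not shift later ones.
    bound.apply_defaults()
    return bound.args


def forward(x, mask=None, scale=1.0):
    """Stand-in for a module's forward signature."""
    return x, mask, scale


flat = kwargs_to_positional(forward, "x", scale=2.0)
```

For an nn.Module, the helper would be applied to the bound method model.forward, and the resulting tuple passed as the input to add_graph. Tracing generally expects tensor inputs, though. Non-tensor values that end up in the tuple (such as a None or float default) may still break the trace, which is a case where the wrapper approach fits better.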
