I have the model below, which returns a list of tensors (logits):
```python
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self, n_classes):
        super(MyNet, self).__init__()
        ...

    def forward(self, x):
        ...
        return [side_5, side_6, side_7, side_8]
```
How can I properly use the tensorboardX function `writer.add_graph(model, input)`?
I tried something like:
```python
dummy_input = Variable(torch.zeros([1, 3, im_rows, im_cols]).to(self.device), requires_grad=True)
writer.add_graph(model, dummy_input)
```
but it seems that `add_graph` expects the model to return a tensor rather than a list of tensors. Any ideas?
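One possible workaround, sketched under assumptions: `add_graph` builds its graph by tracing the model with the PyTorch JIT, and tracing handles a tuple of tensors more reliably than a list (depending on the PyTorch/tensorboardX version, a list output may produce an error or a warning). A thin wrapper module that converts the list into a tuple leaves `MyNet` itself untouched. `ToyNet`, `TupleOutput` and `log_graph` are hypothetical names; `ToyNet` is only a runnable stand-in for `MyNet`, whose layers are elided above.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for MyNet: forward() returns a *list* of four side outputs."""

    def __init__(self, n_classes: int = 2):
        super().__init__()
        self.body = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # Four 1x1 "side" heads, each producing per-pixel logits.
        self.sides = nn.ModuleList([nn.Conv2d(8, n_classes, kernel_size=1) for _ in range(4)])

    def forward(self, x):
        feats = torch.relu(self.body(x))
        return [side(feats) for side in self.sides]


class TupleOutput(nn.Module):
    """Wraps any model and converts its list output into a tuple, which the JIT tracer accepts."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return tuple(self.model(x))


def log_graph(writer, model: nn.Module, dummy_input: torch.Tensor) -> None:
    """Log the wrapped model's graph; `writer` is assumed to be a SummaryWriter."""
    writer.add_graph(TupleOutput(model), dummy_input)


# Demo: trace the wrapped model directly with the JIT tracer.
model = ToyNet(n_classes=2).eval()
dummy_input = torch.zeros(1, 3, 32, 32)
traced = torch.jit.trace(TupleOutput(model), dummy_input)
outputs = traced(dummy_input)
print(type(outputs).__name__, len(outputs), tuple(outputs[0].shape))  # tuple 4 (1, 2, 32, 32)
```

With tensorboardX this would be called as `log_graph(SummaryWriter(), model, dummy_input)` after `from tensorboardX import SummaryWriter`; `torch.utils.tensorboard.SummaryWriter` exposes the same `add_graph(model, input_to_model)` method, so the wrapper should work there too. A plain tensor is enough for the dummy input, since the tracer does not need `Variable` or `requires_grad=True`.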