I was trying to replicate the examples here, but my graphs were not showing any tensor shapes. I then implemented the change cjsoft suggested, replacing the default
m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
if m:
    shape = m.group(1)
    shape = shape.split(",")
    shape = tuple(map(int, shape))
else:
    shape = None
return shape
in get_shape within pytorch_builder.py with
shape = torch_node.output().type().sizes()
return shape
which solved the issue. However, when I tried to plot my own network (a ResNet U-Net), it raised this error:
RuntimeError: r INTERNAL ASSERT FAILED at "..\\aten\\src\\ATen/core/jit_type.h":171, please report a bug to PyTorch.
So I wrapped the call in a try/except, and it seems to work fine:
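try:
    shape = torch_node.output().type().sizes()
except Exception:
    # No shape available for this node (e.g. the INTERNAL ASSERT RuntimeError above);
    # catching Exception instead of a bare except still lets Ctrl-C through
    shape = None
return shape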
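For reference, here is a minimal sketch of the whole function with that change, assuming the same `torch_node` argument that `get_shape` receives in pytorch_builder.py. It is not the library's code. It converts the result to a tuple to match what the regex version returned, treats a `None` result as "no shape" in case some PyTorch versions report unknown sizes that way (an assumption), and catches `Exception` rather than using a bare `except`. `make_fake_node` is a hypothetical stub, only there to exercise the logic without a traced model.

```python
from types import SimpleNamespace


def get_shape(torch_node):
    """Return the output shape of a traced graph node as a tuple, or None."""
    try:
        # Ask the JIT type of the node's single output for its sizes
        sizes = torch_node.output().type().sizes()
    except Exception:
        # Any failure here (such as the INTERNAL ASSERT RuntimeError reported
        # above) just means "no shape available" for this node
        return None
    if sizes is None:
        # Assumption: treat a missing size list as "unknown shape"
        return None
    # Return a tuple, like the original regex-based version did
    return tuple(sizes)


def make_fake_node(sizes_fn):
    """Hypothetical stand-in for a traced node, for trying get_shape offline."""
    tensor_type = SimpleNamespace(sizes=sizes_fn)
    value = SimpleNamespace(type=lambda: tensor_type)
    return SimpleNamespace(output=lambda: value)


print(get_shape(make_fake_node(lambda: [1, 64, 56, 56])))  # (1, 64, 56, 56)
print(get_shape(make_fake_node(lambda: [][0])))            # None (sizes() raised)
```

With a real model you would keep calling `get_shape(torch_node)` from the builder exactly as before; the stub only makes the fallback path easy to try in isolation.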
Hope it helps.
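In case it is useful to anyone wondering why the default code showed no shapes at all: a plausible explanation (not checked against every PyTorch version) is that newer releases print extra fields such as `strides=[...]` after the sizes in a value's string form. The original pattern only allows digits, spaces and commas before the closing parenthesis, so it would never match. The snippet below reproduces this with made-up strings (`OLD_STYLE` and `NEW_STYLE` are illustrative, not real PyTorch output) and shows a hypothetical, more tolerant pattern. The `sizes()` call above sidesteps the problem by not parsing strings at all.

```python
import re

# The default pattern from get_shape: only digits, whitespace and commas may
# sit between "Float(" and the closing parenthesis.
ORIGINAL_PATTERN = re.compile(r".*Float\(([\d\s\,]+)\).*")

# Hypothetical, more tolerant pattern: capture the leading comma-separated
# sizes and ignore whatever follows them (strides, device, ...).
TOLERANT_PATTERN = re.compile(r"Float\((\d+(?:\s*,\s*\d+)*)")


def parse_shape(output_repr, pattern):
    """Return the shape embedded in a graph-value string as a tuple, or None."""
    m = pattern.search(output_repr)
    if not m:
        return None
    return tuple(int(dim) for dim in m.group(1).split(","))


# Illustrative strings; the second mimics the extra fields newer PyTorch
# versions are assumed to print after the sizes.
OLD_STYLE = "%x : Float(1, 3, 224, 224) = aten::relu(%y)"
NEW_STYLE = ("%x : Float(1, 3, 224, 224, strides=[150528, 50176, 224, 1], "
             "requires_grad=0, device=cpu) = aten::relu(%y)")

print(parse_shape(OLD_STYLE, ORIGINAL_PATTERN))  # (1, 3, 224, 224)
print(parse_shape(NEW_STYLE, ORIGINAL_PATTERN))  # None
print(parse_shape(NEW_STYLE, TOLERANT_PATTERN))  # (1, 3, 224, 224)
```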