Example code to log a matplotlib graph to TensorboardX

While working on time-series data, I wanted to save plots as images in TensorBoard.

Here is the code.

```python
import io

import matplotlib.pyplot as plt
import PIL.Image
from tensorboardX import SummaryWriter
from torchvision.transforms import ToTensor


def gen_plot():
    """Create a pyplot plot and save it to an in-memory buffer."""
    plt.figure()
    plt.plot([1, 2])
    buf = io.BytesIO()
    plt.savefig(buf, format='jpeg')
    plt.close()
    buf.seek(0)  # rewind so PIL reads from the start of the buffer
    return buf


# Prepare the plot
plot_buf = gen_plot()

image = PIL.Image.open(plot_buf)
# ToTensor() returns a (C, H, W) float tensor, which already matches
# add_image's default 'CHW' format, so no batch dimension is added here.
image = ToTensor()(image)

writer = SummaryWriter(comment='hello image')
for n_iter in range(100):
    if n_iter % 10 == 0:
        writer.add_image('Image', image, n_iter)
writer.close()  # flush pending events to disk
```

Thanks, this was helpful.

Also, if you are running from Jupyter, setting the matplotlib backend to ‘Agg’ when importing matplotlib helped. Otherwise, it kept giving me errors such as a KeyError on PNG.
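A minimal sketch of that backend switch, assuming a reasonably recent matplotlib: `matplotlib.use("Agg")` is called before `pyplot` is imported, so figures are rendered off-screen and can be written straight to an in-memory buffer. I have not reproduced the KeyError itself; this only shows the setup, and `plot_to_png_bytes` is a hypothetical helper name.

```python
import io

import matplotlib

# Select the non-interactive Agg backend before importing pyplot.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402


def plot_to_png_bytes():
    """Hypothetical helper: render a simple line plot and return PNG bytes."""
    fig, ax = plt.subplots()
    ax.plot([1, 2])
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)  # free the figure; Agg never displays it anyway
    return buf.getvalue()


png_bytes = plot_to_png_bytes()
print(matplotlib.get_backend())
```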

Thanks for the example.

For me it raised an AssertionError:

AssertionError: size of input tensor and input format are different.         tensor shape: (1, 432, 288, 3), input_format: CHW
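That message says the tensor has four dimensions while `add_image` expects a single image in `CHW` order by default; the extra leading `1` most likely comes from the `.unsqueeze(0)` in the first snippet. The sketch below, assuming only matplotlib and NumPy, renders a figure to a `(3, H, W)` uint8 array with no batch dimension. `figure_to_chw` is a hypothetical helper, not part of tensorboardX.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen rendering, no display needed

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402


def figure_to_chw(fig):
    """Hypothetical helper: render a figure to a uint8 array of shape (3, H, W)."""
    fig.canvas.draw()  # make sure the Agg buffer is up to date
    rgba = np.asarray(fig.canvas.buffer_rgba())  # (H, W, 4), uint8
    rgb = rgba[:, :, :3]  # drop the alpha channel
    return np.transpose(rgb, (2, 0, 1)).copy()  # HWC -> CHW, no batch dim


fig, ax = plt.subplots(figsize=(4, 3), dpi=100)  # 400 x 300 pixels
ax.plot([1, 2])
chw = figure_to_chw(fig)
plt.close(fig)
print(chw.shape)
```

The array can then be passed as `writer.add_image('Image', chw, n_iter)`. Alternatively, keep the `(H, W, 3)` array and pass `dataformats='HWC'`, which recent tensorboardX versions accept.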

I ended up using writer.add_figure(), which is much more concise.