Meshgrid in PyTorch

In TensorFlow, creating a meshgrid is pretty easy:

x_t, y_t = tf.meshgrid(tf.linspace(0.0, _width_f - 1.0, _width),
                       tf.linspace(0.0, _height_f - 1.0, _height))

How can I create a meshgrid in PyTorch?

My attempt:

a = torch.linspace(0.0, _width_f - 1.0, _width)
b = torch.linspace(0.0, _height_f - 1.0, _height)
# Both results have shape (_width, _height)
x_t = a.view(-1, 1).repeat(1, b.size(0))
y_t = b.view(1, -1).repeat(a.size(0), 1)

This seems to work:

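a = torch.linspace(0.0, _width_f - 1.0, _width)
b = torch.linspace(0.0, _height_f - 1.0, _height)
# Both results are flattened to 1-D, with length _height * _width
x_t = a.repeat(_height)                                 # x values tiled once per row
y_t = b.repeat(_width, 1).t().contiguous().view(-1)     # each y value repeated _width times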
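
To see what the snippet above produces, here is a minimal sketch with small, made-up sizes (width = 3 and height = 2 are assumptions chosen only for illustration). It reshapes the flattened outputs back to (height, width) so the grid layout is visible:

```python
import torch

# Hypothetical small sizes, chosen only for illustration.
width, height = 3, 2

a = torch.linspace(0.0, width - 1.0, width)    # x coordinates: 0, 1, 2
b = torch.linspace(0.0, height - 1.0, height)  # y coordinates: 0, 1

# Flattened grids, built the same way as in the snippet above.
x_flat = a.repeat(height)
y_flat = b.repeat(width, 1).t().contiguous().view(-1)

# Reshape to (height, width) to recover the 2-D grids.
x_t = x_flat.view(height, width)
y_t = y_flat.view(height, width)

print(x_t)  # every row holds the x coordinates
print(y_t)  # every column holds the y coordinates
```

If you need the 2-D grids rather than the flattened vectors, the final view(height, width) step is all that is required.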

You can also use the built-in torch.meshgrid to get the X and Y values of a meshgrid (here xv and yv both have shape (5, 10)):

xv, yv = torch.meshgrid([torch.arange(0,5), torch.arange(0,10)])

I am using PyTorch 0.4.1.
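
A note on axis order, as a hedged sketch: this assumes PyTorch 1.10 or later, where torch.meshgrid accepts an indexing keyword (0.4.1 does not have it). With indexing="ij" the outputs have shape (5, 10); with indexing="xy" they have shape (10, 5), which is the layout tf.meshgrid returns by default.

```python
import torch

xs = torch.arange(0, 5)
ys = torch.arange(0, 10)

# "ij" indexing: outputs have shape (len(xs), len(ys)).
xv_ij, yv_ij = torch.meshgrid(xs, ys, indexing="ij")

# "xy" indexing: outputs have shape (len(ys), len(xs)),
# matching the default layout of tf.meshgrid.
xv_xy, yv_xy = torch.meshgrid(xs, ys, indexing="xy")

print(xv_ij.shape, xv_xy.shape)
```

The two variants are transposes of each other, so on an older release you can call .t() on the 2-D outputs to switch between them.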
