How to obrain 3D mesgrid in pytorch

I have a code to obrain 2D mesgrid in pytorch such as

import torch
x = torch.randn(1,2,3,4)
B, C, H, W = x.size()
# mesh grid 
xx = torch.arange(0, W).view(1,-1).repeat(H,1)
yy = torch.arange(0, H).view(-1,1).repeat(1,W)
xx = xx.view(1,1,H,W).repeat(B,1,1,1)
yy = yy.view(1,1,H,W).repeat(B,1,1,1)
grid = torch.cat((xx,yy),1).float()

For 3D, we have x size of BxCxDxHxW. What should I add to obtain 3D mesgrid? Thanks
This is what I tried

import torch
x = torch.randn(1,2,3,4,5)
B, C, D, H, W = x.size()
# mesh grid 
xx = torch.arange(0, W).view(1,-1).repeat(D,H,1)
yy = torch.arange(0, H).view(-1,1).repeat(D,1,W)
zz = torch.arange(0, D).view(1,-1).repeat(1,H,W)
print (xx.shape,yy.shape,zz.shape)
xx = xx.view(1,1,D,H,W).repeat(B,1,1,1,1)
yy = yy.view(1,1,D,H,W).repeat(B,1,1,1,1)
zz = zz.view(1,1,D,H,W).repeat(B,1,1,1,1)
grid = torch.cat((xx,yy,zz),1).float()
print (xx.shape,yy.shape,zz.shape)
print (grid)

The z direction looks wrong

torch.Size([3, 4, 5]) torch.Size([3, 4, 5]) torch.Size([1, 4, 15])
torch.Size([1, 1, 3, 4, 5]) torch.Size([1, 1, 3, 4, 5]) torch.Size([1, 1, 3, 4, 5])
tensor([[[[[0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.]],

          [[0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.]],

          [[0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.],
           [0., 1., 2., 3., 4.]]],


         [[[0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1.],
           [2., 2., 2., 2., 2.],
           [3., 3., 3., 3., 3.]],

          [[0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1.],
           [2., 2., 2., 2., 2.],
           [3., 3., 3., 3., 3.]],

          [[0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1.],
           [2., 2., 2., 2., 2.],
           [3., 3., 3., 3., 3.]]],


         [[[0., 1., 2., 0., 1.],
           [2., 0., 1., 2., 0.],
           [1., 2., 0., 1., 2.],
           [0., 1., 2., 0., 1.]],

          [[2., 0., 1., 2., 0.],
           [1., 2., 0., 1., 2.],
           [0., 1., 2., 0., 1.],
           [2., 0., 1., 2., 0.]],

          [[1., 2., 0., 1., 2.],
           [0., 1., 2., 0., 1.],
           [2., 0., 1., 2., 0.],
           [1., 2., 0., 1., 2.]]]]])

How about using torch.meshgrid(torch.arange(4), torch.arange(5), torch.arange(6))? It will return three tensors that you can produce 3d coordinates.

2 Likes