Can I use PyTorch like a GPU shader? (Run a function on each x, y, c index.)

I’m working on a library to display cameras and image functions nicely, and I have an example of Conway’s Game of Life here. It runs decently with NumPy, but I remember running it at much larger sizes on ten-year-old computers.
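The Game of Life example itself is only linked, but here is a rough sketch of how its update step can move to the GPU with PyTorch. It counts each cell's neighbors with a 3x3 convolution and then applies the rules elementwise. `life_step` is a hypothetical helper name, and the wrap-around edges are an assumption; the original example may treat edges differently.

```python
import torch
import torch.nn.functional as F


def life_step(board: torch.Tensor) -> torch.Tensor:
    """Advance a 2D Game of Life board (values 0/1) by one generation.

    Edges wrap around (toroidal board); this is an assumption for the sketch.
    """
    # conv2d expects (batch, channels, height, width)
    x = board.float()[None, None]
    # Circular padding gives every cell eight neighbors, including at the edges
    x = F.pad(x, (1, 1, 1, 1), mode="circular")
    # A 3x3 kernel of ones with a zero center counts the live neighbors
    kernel = torch.ones(1, 1, 3, 3, device=board.device)
    kernel[0, 0, 1, 1] = 0
    neighbors = F.conv2d(x, kernel)[0, 0]
    alive = board.bool()
    # Birth on exactly 3 neighbors; survival on 2 or 3
    new = (neighbors == 3) | (alive & (neighbors == 2))
    return new.to(board.dtype)


device = "cuda" if torch.cuda.is_available() else "cpu"
board = (torch.rand(1080, 1920, device=device) < 0.3).to(torch.uint8)
for _ in range(10):
    board = life_step(board)
```

Because the whole board is updated with a single convolution and a few comparisons, no Python loop runs per cell, which is what lets it scale to large boards on a GPU.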

I can check for PyTorch and CUDA, but is there a way to pass the x, y, channel (or just x, y) index into a function and run that function for every pixel and color?
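One way to get shader-like behavior is to build integer index tensors for every (x, y, c) position with `torch.meshgrid` and write the per-pixel function with elementwise tensor ops, so it runs over all pixels at once on whatever device the tensors live on. This is a minimal sketch: `run_per_pixel` and `gradient` are hypothetical names, the `indexing` argument needs PyTorch 1.10 or newer, and following this thread's naming, x is axis 0 (rows) and y is axis 1 (columns).

```python
import torch


def run_per_pixel(fn, shape, device="cpu"):
    """Evaluate fn(x, y, c) for every index of an (H, W, C) image at once.

    x, y and c are integer tensors with the same shape as the image, so fn
    is written with ordinary elementwise tensor ops, like a shader body.
    """
    h, w, ch = shape
    x, y, c = torch.meshgrid(
        torch.arange(h, device=device),
        torch.arange(w, device=device),
        torch.arange(ch, device=device),
        indexing="ij",
    )
    return fn(x, y, c)


def gradient(x, y, c):
    # Example "shader": channel 0 ramps left to right, channel 1 top to bottom
    h, w = x.shape[0], x.shape[1]
    out = torch.zeros(x.shape, dtype=torch.float32, device=x.device)
    out = torch.where(c == 0, y.float() / (w - 1), out)
    out = torch.where(c == 1, x.float() / (h - 1), out)
    return out


device = "cuda" if torch.cuda.is_available() else "cpu"
img = run_per_pixel(gradient, (600, 800, 3), device=device)
```

The function never sees a single pixel; it sees whole index tensors, and branching per pixel is expressed with masks such as `torch.where` rather than `if`.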

Ah! I found something:

The first part should help with 2D, and the second with 3D. I’m not sure whether the article’s author was using a GPU, but it should help a lot.

I got it working if anyone’s interested.

I couldn’t figure out how to use the coords though. I still need help with that.

So you have an image and want to run a function on it per pixel. I don’t think PyTorch would be much help here, but if you want to use it, you can flatten the image and treat every entry as a separate data point. For example, a 5x5x3 image gives you 75 data points.
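A small sketch of that flattening idea, assuming the image is already a tensor in height x width x channel order: reshape it to 75 points, recover each point's (x, y, c) indices with integer arithmetic, apply the function, and reshape back. The function used here is only a placeholder.

```python
import torch

img = torch.rand(5, 5, 3)
h, w, ch = img.shape

flat = img.reshape(-1)            # 75 data points
idx = torch.arange(flat.numel())  # flat index of each point

# Recover (x, y, c) from the flat index (row-major / C order):
# idx == x * (w * ch) + y * ch + c
x = idx // (w * ch)
y = (idx // ch) % w
c = idx % ch

# Apply a per-point function; this placeholder dims each value by its row
out_flat = flat * (1.0 - x.float() / h)

out = out_flat.reshape(h, w, ch)  # back to image shape
```

This is also where the coordinates come from if the function needs them: the flat index alone is enough to rebuild x, y and c.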

Correct me if I’m wrong, since I have little knowledge of computer graphics.


I figured it out. I’ll post the full solution once I have time, but the key is to use meshgrids: operate on a different axis of that mgrid for each dimension, then do: array_out[mgrid_original] = array_in[mgrid_modified]
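A minimal NumPy sketch of that recipe, with hypothetical names `img_in` and `img_out`: build the full index grid, shift one axis of a copy, and use both grids as fancy indices. The wrap-around at the edge is an assumption; without it, the shifted indices run off the end of the array.

```python
import numpy as np

img_in = np.random.rand(4, 5, 3)

# grid[d] holds every element's index along axis d; shape (3, 4, 5, 3)
grid = np.mgrid[tuple(slice(0, n) for n in img_in.shape)]

# Modify one axis of a copy: read from the next row down,
# wrapping at the bottom edge so indices stay in bounds (assumption)
shifted = grid.copy()
shifted[0] = (shifted[0] + 1) % img_in.shape[0]

img_out = np.empty_like(img_in)
# Fancy indexing needs a tuple of index arrays, one per axis
img_out[tuple(grid)] = img_in[tuple(shifted)]
```

Passing `tuple(grid)` rather than `grid` itself matters: a single stacked array would only index the first axis.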

Alright, I’ve got this fully implemented here:

Now I just need to write tests for it and upload it to PyPI.

For some reason I don’t completely understand, this setup works quite well:

import numpy as np
import torch

self.min_bounds = [0 for _ in frame.shape]
self.max_bounds = list(frame.shape)
# One slice per axis; np.mgrid expects a tuple of slices
grid_slices = tuple(slice(lo, hi) for lo, hi in zip(self.min_bounds, self.max_bounds))
# space_grid[d] holds every element's index along axis d, shape (ndim, *frame.shape)
space_grid = np.mgrid[grid_slices]
space_grid.flags.writeable = False
# Copy each index plane to the device as int64 so it can be used for indexing.
# Variable is deprecated; plain tensors work, and integer tensors can't require grad anyway.
self.x = torch.tensor(space_grid[0, ...], dtype=torch.long, device=self.device)
self.y = torch.tensor(space_grid[1, ...], dtype=torch.long, device=self.device)
self.c = torch.tensor(space_grid[2, ...], dtype=torch.long, device=self.device)

Then, passing x, y, c in as coords, I can do something like:

>>> import numpy as np
>>> import torch
>>> img = np.zeros((600, 800, 3))
>>> def fun(array, coords, finished):
...     # Assumes coords unpacks into the (x, y, c) index tensors from the setup above
...     x, y, c = coords
...     rgb = torch.rand(array.shape, dtype=torch.float64, device=array.device) / 300.0
...     # Read from the next row down, wrapping at the last row so indices stay in bounds
...     x_next = (x + 1) % array.shape[0]
...     array[x, y, c] = (array[x_next, y, c] + rgb[x, y, c]) % 1.0
...
>>> VideoHandlerThread(video_source=img, callbacks=pixel_shader(fun)).display()

Strangely, using pure NumPy for the meshgrid slows things down, and using pure PyTorch overloads my GPU.
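For this particular effect (each pixel reads the pixel one row down, plus a little noise), a sketch that skips index tensors entirely uses `torch.roll`, which shifts a whole tensor along one axis with wrap-around. It only covers fixed shifts, not arbitrary per-pixel coordinate functions, and `shift_and_add_noise` is a hypothetical name.

```python
import torch


def shift_and_add_noise(array: torch.Tensor) -> torch.Tensor:
    # rolled[i] == array[(i + 1) % H]: every pixel reads the next row down
    rolled = torch.roll(array, shifts=-1, dims=0)
    # Uniform noise in [0, 1/300), same dtype and device as the frame
    noise = torch.rand_like(array) / 300.0
    return (rolled + noise) % 1.0


device = "cuda" if torch.cuda.is_available() else "cpu"
frame = torch.zeros(600, 800, 3, dtype=torch.float64, device=device)
frame = shift_and_add_noise(frame)
```

Since no index grids are stored or gathered, this avoids keeping three full-size int64 tensors on the device just to describe a one-row shift.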