How to efficiently allocate GPU memory?

So I have a toy attractor model:

import torch

cuda = torch.device("cuda")
points = torch.rand((2**16,3), dtype=torch.float32, device=cuda, requires_grad=False)

def attractor_func(v, a = 10.0, b = 28.0, c = 8.0 / 3.0):
	# allocate dv with the same shape, dtype and device as the input v
	dv = torch.zeros_like(v)

	#lorenz
	dv[:, 0] = a * (v[:, 1] - v[:, 0])
	dv[:, 1] = v[:, 0] * (b - v[:, 2]) - v[:, 1]
	dv[:, 2] = v[:, 0] * v[:, 1] - c * v[:, 2]

	return dv

This function is called four times from a Runge-Kutta (RK4) function, which in turn is called as quickly as the CPU and GPU can manage.

def rk4(func, h, v):
	k1 = func(v)
	k2 = func(v + (h / 2.0) * k1)
	k3 = func(v + (h / 2.0) * k2)
	k4 = func(v + h * k3)
	return h * (k1 / 6.0 + k2 / 3.0 + k3 / 3.0 + k4 / 6.0)
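Note that rk4 returns the increment h * (k1 + 2*k2 + 2*k3 + k4) / 6 rather than the new state, which is why the loop adds its result to points. A quick sanity check, assuming a test ODE dv/dt = -v that is not part of the original model (the name decay is hypothetical), compares one unit of integration time against the exact solution exp(-t):

```python
import math
import torch

def rk4(func, h, v):
    k1 = func(v)
    k2 = func(v + (h / 2.0) * k1)
    k3 = func(v + (h / 2.0) * k2)
    k4 = func(v + h * k3)
    # returns the increment, not the new state
    return h * (k1 / 6.0 + k2 / 3.0 + k3 / 3.0 + k4 / 6.0)

def decay(v):
    # hypothetical test ODE dv/dt = -v, exact solution v(t) = v(0) * exp(-t)
    return -v

v = torch.ones(4, 3, dtype=torch.float64)
h = 0.01
for _ in range(100):  # integrate from t = 0 to t = 1
    v = v + rk4(decay, h, v)

error = (v - math.exp(-1.0)).abs().max().item()
print(f"max error at t=1: {error:.2e}")
```

With RK4's fourth-order accuracy the error at this step size should be far below float32 precision.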

with this loop run in between “rendering” steps:

for _ in range(32):
	points += rk4(attractor_func, 0.01, points)

But something tells me that this is not the most efficient code when it comes to managing GPU memory. I can’t even convert attractor_func to TorchScript with the JIT, if that would be of any use.

Perhaps there is a way to vectorize the dv slices across the second dimension as well?
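One way to avoid per-step temporaries entirely is to preallocate the RK4 work buffers once and fill them with in-place operations. This is a sketch under assumptions: lorenz_into and rk4_inplace are hypothetical names, and PyTorch's caching allocator already reuses freed GPU blocks, so the temporaries in the loop above are not fresh cudaMalloc calls each step. The trade-off is that every in-place op below is its own kernel launch, so whether this is faster than the original (or than a fused version) is something to measure.

```python
import torch

def lorenz_into(v, out, a=10.0, b=28.0, c=8.0 / 3.0):
    # Hypothetical helper: write the Lorenz derivative of v (N, 3) into out (N, 3).
    # out must not alias v; column slices are views, so in-place ops fill out directly.
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    dx, dy, dz = out[:, 0], out[:, 1], out[:, 2]
    dx.copy_(y).sub_(x).mul_(a)                 # a * (y - x)
    dy.copy_(z).neg_().add_(b).mul_(x).sub_(y)  # (b - z) * x - y
    dz.copy_(x).mul_(y).sub_(z, alpha=c)        # x * y - c * z
    return out

def rk4_inplace(func_into, h, v, k1, k2, k3, k4, tmp):
    # Hypothetical helper: one RK4 step that updates v in place, using
    # caller-provided buffers of v's shape so no new buffers are created per step.
    func_into(v, k1)
    torch.add(v, k1, alpha=h / 2.0, out=tmp)
    func_into(tmp, k2)
    torch.add(v, k2, alpha=h / 2.0, out=tmp)
    func_into(tmp, k3)
    torch.add(v, k3, alpha=h, out=tmp)
    func_into(tmp, k4)
    # k1/6 + k2/3 + k3/3 + k4/6 == (k1 + 2*k2 + 2*k3 + k4) / 6; k2 is reused as scratch
    k2.add_(k3).mul_(2.0).add_(k1).add_(k4)
    v.add_(k2, alpha=h / 6.0)

device = "cuda" if torch.cuda.is_available() else "cpu"
points = torch.rand((2**16, 3), dtype=torch.float32, device=device)
# allocate the five work buffers once, outside the render loop
k1, k2, k3, k4, tmp = (torch.empty_like(points) for _ in range(5))

for _ in range(32):
    rk4_inplace(lorenz_into, 0.01, points, k1, k2, k3, k4, tmp)
```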

If you write it as torch.chunk + pointwise operations + torch.cat, it should work with the JIT fuser.
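For eager mode, a close variant of that idea uses unbind to get the three columns and torch.stack to rebuild the result, which drops the zero-filled dv buffer and the three slice writes. This is a sketch: lorenz_stack is a hypothetical name, intermediate tensors are still allocated for each expression, and the JIT-fuser suggestion above is specifically chunk + cat, not this form.

```python
import torch

def lorenz_stack(v, a=10.0, b=28.0, c=8.0 / 3.0):
    # v has shape (N, 3); unbind(1) returns three (N,) views of its columns
    x, y, z = v.unbind(1)
    # stack the three derivative columns back into an (N, 3) tensor
    return torch.stack((a * (y - x), x * (b - z) - y, x * y - c * z), dim=1)

# At (1, 1, 1): dx = 0, dy = 1 * (28 - 1) - 1 = 26, dz = 1 - 8/3 = -5/3
v = torch.ones(1, 3, dtype=torch.float64)
print(lorenz_stack(v))
```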

It occurred to me that the Thomas attractor was too simple: it allowed all the point-wise operations to be combined, which is not the general case I wished to solve for. I changed the example to Lorenz.
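To illustrate why the Thomas attractor is the easier case: its equations (dx/dt = sin(y) - b*x, dy/dt = sin(z) - b*y, dz/dt = sin(x) - b*z) are cyclically symmetric, so the whole right-hand side can be one point-wise expression over the full (N, 3) tensor, whereas Lorenz mixes the columns asymmetrically. The thread does not show the original Thomas code; thomas_func and b = 0.208186 (a commonly used value) are assumptions.

```python
import torch

def thomas_func(v, b=0.208186):
    # Thomas attractor: dx = sin(y) - b*x, dy = sin(z) - b*y, dz = sin(x) - b*z
    # roll(-1, dims=1) reorders the columns (x, y, z) -> (y, z, x), so a single
    # point-wise expression over the whole tensor covers all three equations
    return torch.sin(v.roll(-1, dims=1)) - b * v

v = torch.rand(5, 3, dtype=torch.float64)
print(thomas_func(v))
```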

@tom

TorchScript fails:

  File "./main.py", line 37, in rk4
    k2 = func(v + (h / 2.0) * k1)
RuntimeError: default_program(102): error: identifier "__ldg" is undefined

1 error detected in the compilation of "default_program".

with:

import torch

@torch.jit.script
def attractor_func(v):
	x, y, z = torch.chunk(v, 3, dim=1)

	#lorenz
	a = 10.0
	b = 28.0
	c = 8.0 / 3.0
	dx = a * (y - x)
	dy = x * (b - z) - y
	dz = x * y - c * z

	return torch.cat([dx, dy, dz], 1)

Without JIT it works.

My card has compute capability 3.0 (PyTorch compiled from source with CUDA 10.2).
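Until the fuser problem is resolved, one workaround is to try the scripted version and switch to the eager one if it raises. This is a sketch: with_eager_fallback and lorenz_chunk are hypothetical names, and the error in the traceback surfaced on a later call rather than at scripting time, so the check has to wrap every call, not just torch.jit.script itself.

```python
import torch

def lorenz_chunk(v):
    # same chunk + pointwise + cat formulation as the scripted version, left unscripted
    x, y, z = torch.chunk(v, 3, dim=1)
    a = 10.0
    b = 28.0
    c = 8.0 / 3.0
    return torch.cat([a * (y - x), x * (b - z) - y, x * y - c * z], 1)

def with_eager_fallback(scripted, eager):
    # Hypothetical helper: call the scripted function; if it raises a RuntimeError
    # (e.g. the fused kernel fails to compile), use the eager function from then on.
    use_script = True

    def call(v):
        nonlocal use_script
        if use_script:
            try:
                return scripted(v)
            except RuntimeError:
                use_script = False
        return eager(v)

    return call

attractor_func = with_eager_fallback(torch.jit.script(lorenz_chunk), lorenz_chunk)
```

A genuine bug (for example a shape mismatch) would raise again from the eager path, so it is not silently swallowed.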

Oh oh. I think that is an incompatibility with your card, which isn’t supported anymore. You could delete the uses of __ldg in the source code and recompile from scratch. It will cost you some performance when fusing broadcasting operations.

Best regards

Thomas
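For reference, the __ldg intrinsic (the read-only data cache load) is only available on devices of compute capability 3.5 and higher, which would explain why a 3.0 card fails to compile the fused kernel. A minimal check, assuming a CUDA-enabled PyTorch build; has_ldg is a hypothetical helper name:

```python
import torch

def has_ldg(capability):
    # capability is a (major, minor) tuple; __ldg needs compute capability >= 3.5
    return tuple(capability) >= (3, 5)

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability {major}.{minor}, __ldg available: {has_ldg((major, minor))}")
else:
    print("no CUDA device visible")
```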