Hello, I am developing a program that uses auto-differentiation to iteratively reconstruct a multi-dimensional input tensor. The issue I am facing is that one step of the forward model requires a very large intermediate tensor when fully vectorized. As a simplified example: the input tensor T1 has shape (dz, dy, dx) and is reshaped to (dz*dy*dx, 1, 1). A second tensor T2 has shape (1, dy, dx), and I need to perform the following operation:
out = torch.exp(T1 * T2 * -1j).sum(dim=0)
So my output only has shape (dy, dx), but because the exponential must be evaluated for every pair of input voxel and output pixel, the largest intermediate tensor has shape (dz*dy*dx, dy, dx). This very quickly becomes larger than my available GPU memory, so I have two questions.
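To put a rough number on this, here is a back-of-the-envelope estimate. The sizes are hypothetical placeholders rather than my real problem, and I am assuming the intermediate is stored as complex64 (8 bytes per element).

```python
# Hypothetical sizes, chosen only for illustration.
dz, dy, dx = 32, 128, 128

# Assumption: complex64 intermediate, i.e. two float32 values per element.
bytes_per_elem = 8

n_voxels = dz * dy * dx
# Shape of the broadcast intermediate is (dz*dy*dx, dy, dx).
intermediate_bytes = n_voxels * dy * dx * bytes_per_elem
# Shape of the final output is (dy, dx).
output_bytes = dy * dx * bytes_per_elem

print(f"intermediate: {intermediate_bytes / 2**30:.1f} GiB")
print(f"output:       {output_bytes / 2**10:.1f} KiB")
```

The intermediate grows with dz * dy^2 * dx^2, while the output only grows with dy * dx.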
First, is there a clever way of avoiding the large tensor that I’m not seeing? Because the exponential is applied after the multiplication but before the sum, this cannot be written as a simple matmul. The ideal solution would be a vectorized way of performing this calculation whose memory use doesn’t scale so badly.
Second, assuming I am stuck with the large tensor, what is the fastest way of performing this calculation? At one extreme I could loop over the input voxels one at a time, accumulating into the output tensor as I go. This is of course very slow, but it needs almost no additional memory. I’m currently slicing the reshaped input tensor into chunks and performing the vectorized calculation one chunk at a time, but this is still slow and it feels like there should be a better way.
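In rough form, the chunked version looks something like the sketch below. The function name `chunked_exp_sum`, the chunk size, and the tensor sizes are placeholders for illustration rather than my actual code, and I am assuming real float32 inputs.

```python
import torch

def chunked_exp_sum(t1_flat, t2, chunk_size=1024):
    """Sum exp(-1j * T1 * T2) over dim 0, one chunk of voxels at a time.

    t1_flat: real tensor of shape (N, 1, 1), with N = dz*dy*dx
    t2:      real tensor of shape (1, dy, dx)
    Returns a complex tensor of shape (dy, dx).
    """
    # Assumption: float32 inputs, so the result is complex64.
    out = torch.zeros(t2.shape[1:], dtype=torch.complex64, device=t2.device)
    for chunk in torch.split(t1_flat, chunk_size, dim=0):
        # Each step materializes a (chunk_size, dy, dx) intermediate
        # instead of the full (N, dy, dx) one.
        out = out + torch.exp(chunk * t2 * -1j).sum(dim=0)
    return out

# Small hypothetical sizes so the full version fits in memory for comparison.
torch.manual_seed(0)
dz, dy, dx = 4, 8, 8
t1 = torch.rand(dz, dy, dx, requires_grad=True)
t2 = torch.rand(1, dy, dx)

out = chunked_exp_sum(t1.reshape(-1, 1, 1), t2, chunk_size=64)

# Fully vectorized reference, computed without tracking gradients.
with torch.no_grad():
    reference = torch.exp(t1.reshape(-1, 1, 1) * t2 * -1j).sum(dim=0)

print(torch.allclose(out, reference, atol=1e-3))
```

The chunk size sets the trade-off: a chunk size of 1 is the per-voxel loop, and a chunk size of N is the fully vectorized version.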
Any help is appreciated, thanks!