How to perform efficient chunking for large tensors?

I have a large tensor that is an input to my model, and the first layer is a non-learned embedding which expands the tensor from (40960000, 3) to (40960000, 63). In practice, this data doesn’t fit through my model with my 24 GB of VRAM, because PyTorch allocates almost 20 GB for this embedding function alone. My assumption is that the intermediate steps used to build the embedding require additional VRAM. I would like to chunk the input data to reduce the amount of memory I need at any one time.

However, in practice, I have noticed that PyTorch still allocates 20 GB for this process even with chunking. My expectation was that the program would enter the embedding function and create a stack frame, inside which GPU memory must be allocated to perform the requested operations. I assumed that this memory would be reused across chunks, so the total memory allocated would be (input tensor) + (memory for the stack frame, which scales with chunk size) + (output tensor), which should be around 10-11 GB. Clearly it doesn’t work this way, and it seems to allocate fresh memory every time it runs the embedding function.
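To make that expectation concrete, here is the back-of-the-envelope accounting I had in mind (float32 throughout; the per-chunk term is only a rough guess at the intermediate list of sin/cos pieces plus the concatenated chunk):

N, D_IN, L = 40960000, 3, 10
D_OUT = D_IN * (1 + 2 * L)  # 63 columns: x plus sin/cos at L frequencies
BYTES = 4                   # float32
chunk_size = 850 * 32

input_gb = N * D_IN * BYTES / 1024**3    # ~0.46 GB
output_gb = N * D_OUT * BYTES / 1024**3  # ~9.6 GB
# rough guess: the list of per-frequency pieces plus the concatenated result for one chunk
chunk_gb = 2 * chunk_size * D_OUT * BYTES / 1024**3  # ~0.01 GB
print(f"expected total = {input_gb + output_gb + chunk_gb:.2f} GB")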

I’ve created a minimal example that highlights my issue. Any feedback/wisdom here would be appreciated.

import torch
import torch.nn as nn 
import time

class FFEmbedding(nn.Module):
    """
    Map vectors into a higher dimensional space to learn higher-frequency
    functions.

    length: the number of frequencies (L) in the embedding

    returns the input x concatenated with sin/cos features at L frequencies
    along the last dimension, i.e.:
    [x, sin(2**0 * x), cos(2**0 * x), ..., sin(2**(L-1) * x), cos(2**(L-1) * x)]
    """
    def __init__(self, length):
        super(FFEmbedding, self).__init__()
        self.length = length

    def forward(self, x):
        with torch.no_grad():
            vector = [x]  # keep the raw input as the first block of features
            for i in range(self.length):
                for fn in [torch.sin, torch.cos]:
                    vector += [fn(2.**i * x)]  # sin/cos at frequency 2**i
            return torch.cat(vector, dim=-1)  # (N, 3) -> (N, 3 * (1 + 2 * length))

cuda = False
dev = "cuda:0" if cuda else "cpu"

start_time = time.time()
test = torch.rand((40960000, 3)).to(device=dev)
emb = FFEmbedding(10).to(device=dev)

# Size (GB) = 40960000 * 3 * 4 bytes / 1024**3 ≈ 0.46 GB
# Observed memory is around 2 GB
print("Size of initial tensor:")
print(test.nelement() * test.element_size() / 1024**3)

chunk_size=850*32  # 1506-chunks
# chunk_size=test.shape[0]  # 1-chunk
chunks = []
iters = 0
for i in range(0, test.shape[0], chunk_size):
    iters += 1
    chunks.append(emb(test[i:i+chunk_size]))

print(f"processed {iters=}")
test2 = torch.cat(chunks, dim=0)

# Size (GB) = 40960000 * 63 * 4 bytes / 1024**3 ≈ 9.6 GB
# Observed size: ~20 GB
print(f"Final tensor size: {test2.nelement() * test2.element_size() / 1024**3}")

# 1-Chunk GPU/CPU Time: 1.64/9.91 seconds
# 1506-Chunk GPU/CPU: 1.8/5.45 seconds
print(f"Completed operation in {time.time() - start_time}")

input("Check VRAM allocation")

I kept exploring this and found that pre-allocating the output tensor and then assigning into it chunk-by-chunk fixed my issue. I would love more information about what happens in append that causes the memory to explode. I think there might be a way to leverage torch.cuda.empty_cache(), but I had trouble employing it in a way that made sense for my design (my rough idea is sketched after the snippet below).

# Pre-allocate the full output once, then write each embedded chunk into it in place
test2 = torch.empty((40960000, 63), device=dev)
for i in range(0, test.shape[0], chunk_size):
    test2[i:i+chunk_size] = emb(test[i:i+chunk_size])
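
For reference, what I had in mind for torch.cuda.empty_cache() is roughly the following (just a sketch; I assume calling it on every iteration would hurt throughput, which is part of why it never felt like a clean fit for my design):

test2 = torch.empty((40960000, 63), device=dev)
for i in range(0, test.shape[0], chunk_size):
    out = emb(test[i:i+chunk_size])
    test2[i:i+chunk_size] = out
    del out  # drop the Python reference to the chunk's output
    if cuda:
        torch.cuda.empty_cache()  # ask the caching allocator to release unused blocks back to the driver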