I have a large tensor that is the input to my model, and the first layer is a non-learned embedding that expands it from shape (40960000, 3) to (40960000, 63): the 3 input columns plus 60 sin/cos features. In practice, this data doesn’t fit through my model with my 24 GB of VRAM, because PyTorch allocates almost 20 GB for this embedding function alone. My assumption is that the intermediate steps that build the embedding require additional VRAM, so I would like to chunk the input data to reduce the amount of memory needed at any one time.
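A back-of-envelope accounting of a single unchunked call is sketched below. It assumes float32 (4 bytes per element) and counts only the tensors created inside `forward()`. It suggests the intermediate steps are indeed large: the 20 sin/cos pieces held in the list and the concatenated output are both alive while `torch.cat` runs. This is an estimate, not a profiler measurement.

```python
# Rough peak-memory estimate for one unchunked call to FFEmbedding(10).
# Assumptions: float32 (4 bytes/element); ignores the short-lived (N, 3)
# temporary created by `2**i * x` and any CUDA context / allocator overhead.
N, D, L = 40_960_000, 3, 10
BYTES_PER_ELEM = 4
GIB = 1024**3

piece_gib = N * D * BYTES_PER_ELEM / GIB   # one (N, 3) tensor: the input or one sin/cos result
input_gib = piece_gib                      # `vector = [x]` reuses the input, no new memory
features_gib = 2 * L * piece_gib           # 20 sin/cos tensors kept in `vector` until the concat
output_gib = (1 + 2 * L) * piece_gib       # the concatenated (N, 63) result

# While torch.cat runs, the list of pieces and the output exist at the same time.
peak_gib = input_gib + features_gib + output_gib
print(f"input {input_gib:.2f} GiB, output {output_gib:.2f} GiB, peak ~{peak_gib:.2f} GiB")
```

Under these assumptions the peak lands a little above 19 GiB, in the same range as the "almost 20 GB" observed.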

However, I have noticed that PyTorch still allocates 20 GB for this process even with chunking. My expectation was that each call to the embedding function would allocate the GPU memory its operations need (much like a stack frame), and that this memory would be reused by the next call, so the total allocation would be `input tensor + per-call workspace (scales with chunk size) + output tensor`, which should be around 10-11 GB. Clearly it doesn’t work this way: it seems to allocate new memory every time it runs the embedding function.
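The expected budget can be put next to what the chunked loop in the minimal example actually keeps alive. The sketch below assumes float32 and the example's `chunk_size = 850 * 32`. It models the per-chunk workspace as the 20 sin/cos pieces plus that chunk's concatenated output. One likely contributor: every chunk output is appended to `chunks`, so the list adds up to a full copy of the result, and `torch.cat(chunks)` then allocates a second full copy while the list is still referenced.

```python
# Compare the expected budget with what the chunked loop holds at torch.cat time.
# Assumptions: float32; sizes and chunk_size taken from the minimal example.
import math

N, D, L = 40_960_000, 3, 10
BYTES_PER_ELEM = 4
GIB = 1024**3
chunk_size = 850 * 32
width = D * (1 + 2 * L)                                 # 63 output columns

input_gib = N * D * BYTES_PER_ELEM / GIB
output_gib = N * width * BYTES_PER_ELEM / GIB
# Per-chunk workspace: 20 sin/cos pieces plus that chunk's concatenated output.
workspace_gib = chunk_size * (2 * L * D + width) * BYTES_PER_ELEM / GIB

# The budget described in the question: input + per-call workspace + output.
expected_gib = input_gib + workspace_gib + output_gib
# At torch.cat(chunks), the list of chunk outputs (together a full copy of the
# result) is still alive while the final tensor is allocated.
loop_peak_gib = input_gib + output_gib + output_gib

n_chunks = math.ceil(N / chunk_size)
print(f"{n_chunks} chunks, expected ~{expected_gib:.2f} GiB, loop peak ~{loop_peak_gib:.2f} GiB")
```

If this accounting is right, the chunking itself works as intended, and the doubling comes from holding both the chunk list and the concatenated tensor.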

I’ve created a minimal example that highlights my issue. Any feedback or wisdom here would be appreciated.

```python
import time

import torch
import torch.nn as nn


class FFEmbedding(nn.Module):
    """
    Map vectors into a higher-dimensional space to learn higher-frequency
    functions.
    length: the number of frequencies (L) in the embedding
    Returns a flat tensor that keeps the input and appends 2 * L features, e.g.:
    [x, sin(2**0 * x), cos(2**0 * x), ..., sin(2**(L-1) * x), cos(2**(L-1) * x)]
    """

    def __init__(self, length):
        super().__init__()
        self.length = length

    def forward(self, x):
        with torch.no_grad():
            vector = [x]
            for i in range(self.length):
                for fn in [torch.sin, torch.cos]:
                    vector.append(fn(2.0**i * x))
            return torch.cat(vector, dim=-1)


cuda = False  # set to True to run on the GPU
dev = "cuda:0" if cuda else "cpu"

start_time = time.time()
test = torch.rand((40960000, 3)).to(device=dev)
emb = FFEmbedding(10).to(device=dev)

# Size (GB) = 40960000 * 3 * 4 bytes / 1024**3 = ~0.46 GB
# Observed VRAM usage at this point: around 2 GB
print("Size of initial tensor:")
print(test.nelement() * test.element_size() / 1024**3)

chunk_size = 850 * 32  # 1506 chunks
# chunk_size = test.shape[0]  # 1 chunk
chunks = []
iters = 0
for i in range(0, test.shape[0], chunk_size):
    iters += 1
    chunks.append(emb(test[i:i + chunk_size]))
print(f"processed {iters=}")

test2 = torch.cat(chunks, dim=0)
# Size (GB) = 40960000 * 63 * 4 bytes / 1024**3 = ~9.6 GB
# Observed VRAM usage: ~20 GB
print(f"Final tensor size: {test2.nelement() * test2.element_size() / 1024**3}")

# 1-chunk GPU/CPU time: 1.64/9.91 seconds
# 1506-chunk GPU/CPU time: 1.8/5.45 seconds
print(f"Completed operation in {time.time() - start_time}")
input("Check VRAM allocation")
```
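One variation that could be worth trying is sketched below as an assumption, not a confirmed fix: allocate the (N, 63) output once and write each chunk's embedding into its slice. That way no list of chunk outputs ever coexists with a concatenated copy. The helper name `embed_in_chunks` is hypothetical. The demo uses a small tensor so it runs on CPU as well. In principle the peak should then be close to input + output + one chunk's workspace. When checking VRAM, `torch.cuda.max_memory_allocated()` reports tensor allocations directly. `nvidia-smi` also counts memory that PyTorch's caching allocator keeps reserved, plus the CUDA context.

```python
import torch
import torch.nn as nn


class FFEmbedding(nn.Module):
    """Same embedding as in the example: [x, sin(2**i * x), cos(2**i * x), ...]."""

    def __init__(self, length):
        super().__init__()
        self.length = length

    def forward(self, x):
        vector = [x]
        for i in range(self.length):
            for fn in (torch.sin, torch.cos):
                vector.append(fn(2.0**i * x))
        return torch.cat(vector, dim=-1)


@torch.no_grad()
def embed_in_chunks(emb, x, chunk_size):
    """Hypothetical helper: fill a preallocated output tensor chunk by chunk."""
    width = x.shape[1] * (1 + 2 * emb.length)  # 3 * 21 = 63 columns for length 10
    out = torch.empty((x.shape[0], width), dtype=x.dtype, device=x.device)
    for start in range(0, x.shape[0], chunk_size):
        stop = start + chunk_size  # slicing clamps the last chunk to the tensor end
        # This chunk's intermediates are freed once the assignment finishes;
        # only `out` persists across iterations.
        out[start:stop] = emb(x[start:stop])
    return out


# Small demo; the real case would use the (40960000, 3) input and 850 * 32 chunks.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
x = torch.rand((10_000, 3), device=dev)
emb = FFEmbedding(10)
out = embed_in_chunks(emb, x, chunk_size=1_000)
print(out.shape)
```

The design choice here is only about lifetimes: the same `FFEmbedding` runs on each chunk, but its result is copied into a tensor that already exists instead of being kept in a list for a later `torch.cat`.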