I am currently working with the task vector of LLaMA's parameters, which is huge (over 13015864320 elements). I have a tensor of shape (3, 13015864320) in float16. Every time I call kthvalue or do element-wise multiplication on this tensor, RAM usage climbs unreasonably high (over 300 GB). Is there any way to process this tensor in batches and do the calculations sequentially?
Moreover, I found that NumPy uses much less memory for exactly the same computation (np.partition in place of kthvalue, and np.multiply instead of *). Can someone explain why this happens?
Batch Processing: Instead of processing the entire tensor at once, you can process it in smaller batches. This reduces the peak memory footprint, because each operation only works on (and allocates temporaries for) a portion of the tensor at any given time.
import torch

def process_in_batches(tensor, batch_size):
    # Walk the second dimension in chunks of batch_size columns.
    num_batches = tensor.size(1) // batch_size
    results = []
    for i in range(num_batches):
        # Slicing returns a view, so no data is copied here.
        batch = tensor[:, i*batch_size:(i+1)*batch_size]
        # Perform your operations here
        # Example: kth_value = torch.kthvalue(batch, k).values
        result = batch  # Placeholder for actual operation
        results.append(result)
    # Handle the remainder
    if tensor.size(1) % batch_size != 0:
        batch = tensor[:, num_batches*batch_size:]
        result = batch  # Placeholder for actual operation
        results.append(result)
    # Note: torch.cat allocates a new tensor the size of all results combined.
    return torch.cat(results, dim=1)

# Example usage
# Note: a (3, 13015864320) float16 tensor needs about 78 GB of RAM by itself.
tensor = torch.randn(3, 13015864320, dtype=torch.float16)
batch_size = 1000000  # Adjust this based on your memory constraints
processed_tensor = process_in_batches(tensor, batch_size)
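For the two operations in the question, a generic batch loop needs some care. Element-wise multiplication can be done chunk by chunk in place, since slices are views into the original storage. The k-th value is trickier: the k-th value of each chunk does not give you the k-th value of the whole row. One possible workaround, sketched below, relies on float16 having only 65536 distinct bit patterns, so you can count how often each pattern occurs across chunks and read the k-th smallest off the cumulative counts. This is a swapped-in counting technique, not what torch.kthvalue does internally. The function names and chunk sizes are hypothetical, and the sketch assumes the data contains no NaNs, that both inputs to the multiplication have the same shape, and that your PyTorch version supports Tensor.view(dtype).

```python
import torch


def multiply_in_place_chunked(a, b, chunk_size=1 << 24):
    """Compute a *= b chunk by chunk along dim 1, writing into a's own storage."""
    for start in range(0, a.size(1), chunk_size):
        # Both slices are views; mul_ only needs chunk-sized work at a time.
        a[:, start:start + chunk_size].mul_(b[:, start:start + chunk_size])
    return a


def kth_value_fp16_chunked(row, k, chunk_size=1 << 22):
    """Exact k-th smallest value (1-indexed) of a 1-D float16 tensor.

    Assumes the tensor contains no NaNs.
    """
    assert row.dtype == torch.float16 and row.dim() == 1
    assert 1 <= k <= row.numel()
    counts = torch.zeros(65536, dtype=torch.int64)
    for start in range(0, row.numel(), chunk_size):
        chunk = row[start:start + chunk_size]
        # Reinterpret the float16 bits as integers in the range 0..65535.
        bits = chunk.view(torch.int16).to(torch.int64) & 0xFFFF
        # Map bit patterns to keys whose integer order matches float order:
        # negatives (sign bit set) are reversed, non-negatives are shifted up.
        keys = torch.where(bits >= 0x8000, 0xFFFF - bits, bits + 0x8000)
        counts += torch.bincount(keys, minlength=65536)
    cumulative = torch.cumsum(counts, dim=0)
    # First key whose cumulative count reaches k.
    key = int((cumulative < k).sum())
    # Invert the key mapping to recover the original bit pattern.
    bits = key - 0x8000 if key >= 0x8000 else 0xFFFF - key
    signed = bits - 0x10000 if bits >= 0x8000 else bits
    return torch.tensor([signed], dtype=torch.int16).view(torch.float16)[0]


if __name__ == "__main__":
    # Small demo; for the real tensor, call the function once per row.
    t = torch.randn(3, 100_000).to(torch.float16)
    print([float(kth_value_fp16_chunked(t[i], 50_000)) for i in range(t.size(0))])
```

The only persistent extra memory for the k-th value is the 65536-entry count table; the per-chunk temporaries are int64, so a chunk of 1 << 22 elements costs on the order of a hundred megabytes while it is being processed. Lower chunk_size if that is still too much.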