Why index_points() use so much memory?

code snippet is below.

import torch
from torch.profiler import profile, ProfilerActivity
import logging

def index_points(points, indices):
    """
    points.shape = (b, n, c)
    indices.shape = (b, m, k)
    return res.shape = (b, m, k, c)
    """
    _, _, c = points.shape
    _, m, _ = indices.shape

    indices = indices.unsqueeze(dim=3).expand(-1, -1, -1, c)
    points = points.unsqueeze(dim=1).expand(-1, m, -1, -1)
    res = points.gather(dim=2, index=indices)

    return res

def index_points_2(points, indices):
    """
    points.shape = (b, n, c)
    indices.shape = (b, m, k)
    return res.shape = (b, m, k, c)
    """
    device = points.device
    b, m, k = indices.shape

    batch_indices = torch.arange(b, device=device).view(b, 1, 1).expand(-1, m, k)
    res = points[batch_indices, indices, :]
    return res

if __name__ == '__main__':
    b, n, m, k, c = 16, 512, 128, 64, 128
    features = torch.randn((b, n, c), device='cuda:0', requires_grad=True)
    indices = torch.randint(0, n, (b, m, k), device='cuda:0')
    
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True) as prof:
        res = index_points_2(features, indices)
        res.backward(torch.ones_like(res))
    
    logging.basicConfig(filename='test.log', format='%(message)s', level=logging.INFO)
    logging.info(prof.key_averages().table(sort_by='self_cuda_memory_usage', row_limit=20))

index_points() and index_points_2() do the same thing, so they’re interchangeable.
The results are as follows. Only the relevant columns are showed here
index_points() result:


index_points_2() result:

From above image, we can know index_points() use more memory. But I do not konw why? Can anyone answer?