KPConv custom gather operation creating large memory requirements on backward?

I am working on implementing my own use case for a KPConv KPFCNN network (original KPConv article by Hugues Thomas et al.).

After some debugging, I have found that during some backward passes of a KPFCNN model (defined here, ~line 189), the custom gather operation (defined here, ~line 35) is always the culprit of an OOM exception, depending on the model's input size.

The gather operation

def gather(x, idx, method=2):
    """
    Implementation of a custom gather operation for a faster backward pass.
    :param x: input with shape [N, D_1, ..., D_d]
    :param idx: indexing with shape [n_1, ..., n_m]
    :param method: choice of the method
    :return: x[idx] with shape [n_1, ..., n_m, D_1, ..., D_d]
    """

    if method == 0:
        # plain advanced indexing
        return x[idx]
    elif method == 1:
        # 2-D case: broadcast x and idx to a common [N, n, D] shape,
        # then gather along dim 0
        x = x.unsqueeze(1)
        x = x.expand((-1, idx.shape[-1], -1))
        idx = idx.unsqueeze(2)
        idx = idx.expand((-1, -1, x.shape[-1]))
        return x.gather(0, idx)
    elif method == 2:
        # general case: broadcast x over the index dimensions n_2, ..., n_m ...
        for i, ni in enumerate(idx.size()[1:]):
            x = x.unsqueeze(i + 1)
            new_s = list(x.size())
            new_s[i + 1] = ni
            x = x.expand(new_s)
        # ... and idx over the feature dimensions D_1, ..., D_d,
        # so both tensors share one shape before gathering along dim 0
        n = len(idx.size())
        for i, di in enumerate(x.size()[n:]):
            idx = idx.unsqueeze(i + n)
            new_s = list(idx.size())
            new_s[i + n] = di
            idx = idx.expand(new_s)
        return x.gather(0, idx)
    else:
        raise ValueError('Unknown method')
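
For reference, a minimal usage example of this operation (arbitrary sizes, chosen only to illustrate what method 2 produces):

import torch

# hypothetical sizes: N points with D features, K neighbor indices per point
N, D, K = 1000, 64, 16
x = torch.randn(N, D)              # point features, shape [N, D]
idx = torch.randint(0, N, (N, K))  # neighbor indices, shape [N, K]

out = gather(x, idx, method=2)     # gathered features, shape [N, K, D]
print(out.shape)                   # torch.Size([1000, 16, 64])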

Based on the code comments, this is a custom implementation of torch.gather that provides some execution speed-up. I am trying to determine whether this custom operation is truly the culprit of the OOM I am experiencing, but I am not yet well versed enough to fully understand its memory requirements during a backward pass.

I am currently attempting to replace this operation with torch.gather to determine whether any decrease in memory usage is possible. But until then, or in case that does not provide a clear answer: is it possible that this custom gather operation has a larger, or much larger, memory requirement than torch.gather?
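
Note that methods 1 and 2 above already call torch.gather internally, after expanding x and idx to a common shape, so a direct torch.gather replacement ends up looking much like method 1. Two simpler alternatives would be plain indexing and torch.index_select (a sketch, assuming x is 2-D with shape [N, D] and idx is a LongTensor):

import torch

def gather_plain(x, idx):
    # built-in advanced indexing; PyTorch handles the backward for this
    return x[idx]

def gather_index_select(x, idx):
    # flatten the index, select whole rows of x, then restore the index shape
    flat = idx.reshape(-1)                       # [n_1 * ... * n_m]
    out = torch.index_select(x, 0, flat)         # [n_1 * ... * n_m, D]
    return out.reshape(*idx.shape, x.shape[-1])  # [n_1, ..., n_m, D]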

This could be the case. You could profile both methods by running the forward and backward operations inside a small fake model and checking the memory usage via torch.cuda.memory_summary().
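
A minimal sketch of such a comparison (it isolates the op rather than embedding it in a full model; the sizes are arbitrary placeholders and the gather function from the question is assumed to be in scope):

import torch

def custom_gather(x, idx):
    return gather(x, idx, method=2)   # the custom op from the question

def plain_gather(x, idx):
    return x[idx]                     # built-in advanced indexing

def profile_gather(fn, N=100_000, D=64, K=40, device='cuda'):
    # run one forward + backward of fn(x, idx) and report peak CUDA memory
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)

    x = torch.randn(N, D, device=device, requires_grad=True)
    idx = torch.randint(0, N, (N, K), device=device)

    out = fn(x, idx)        # forward pass
    out.sum().backward()    # backward pass

    peak = torch.cuda.max_memory_allocated(device) / 1024 ** 2
    print(f'{fn.__name__}: peak CUDA memory {peak:.1f} MiB')
    # print(torch.cuda.memory_summary(device))  # full per-allocation breakdown

profile_gather(custom_gather)
profile_gather(plain_gather)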