Optimizing Memory Allocation of Chunk/Cat Operators

I’m trying to compute a fast Walsh-Hadamard transform. In theory this operation should need zero extra memory beyond the input, but the implementation below does not manage that.
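For reference, the textbook iterative FWHT only needs in-place butterfly updates, which is why I don’t expect any extra allocations to be necessary. A rough scalar sketch (fwht_inplace_sketch is just an illustrative helper, not part of my actual code):

def fwht_inplace_sketch(a):
    # a is a flat sequence whose length is a power of two; each stage
    # overwrites the pair (a[j], a[j + h]) with its sum and difference.
    h = 1
    while h < len(a):
        for i in range(0, len(a), h * 2):
            for j in range(i, i + h):
                x, y = a[j], a[j + h]
                a[j], a[j + h] = x + y, x - y
        h *= 2
    return a

My PyTorch version of the same idea is below.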

import numpy as np
import torch


def fast_walsh_hadamard_torched(x, axis: int = 0, normalize: bool = False):
    orig_shape = x.size()
    assert axis >= 0 and axis < len(orig_shape), (
        "For a vector of shape %s, axis must be in [0, %d] but it is %d"
        % (orig_shape, len(orig_shape) - 1, axis)
    )
    h_dim = orig_shape[axis]
    h_dim_exp = int(round(np.log(h_dim) / np.log(2)))
    assert h_dim == 2 ** h_dim_exp, (
        "hadamard can only be computed over axis with size that is a power of two, but"
        " chosen axis %d has size %d" % (axis, h_dim)
    )

    # Flatten everything before/after the chosen axis into single dimensions and
    # split the transform axis into h_dim_exp dimensions of size 2, one per stage.
    working_shape_pre = [int(torch.prod(torch.tensor(orig_shape[:axis])))]
    working_shape_post = [int(torch.prod(torch.tensor(orig_shape[axis + 1:])))]
    working_shape_mid = [2] * h_dim_exp
    working_shape = working_shape_pre + working_shape_mid + working_shape_post

    ret = x.view(working_shape)

    print(ret.size())  # debug print
    # One butterfly stage per size-2 dimension: split that dimension in half,
    # then concatenate the elementwise sum and difference of the two halves.
    for dim in range(1, h_dim_exp + 1):
        left, right = torch.chunk(ret, 2, dim=dim)
        ret = torch.cat([left + right, left - right], dim=dim)
        del left, right
    if normalize:
        ret = ret / np.sqrt(h_dim)  # torch.sqrt() expects a tensor, not a Python float

    ret = ret.view(orig_shape)

    return ret
Output from a test run, followed by the profiler summary:

torch.Size([67108864])
-----------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                     Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CPU Mem          Self CPU Mem     CUDA Mem         Self CUDA Mem    Number of Calls
-----------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
empty                    0.09%            3.494ms          0.09%            3.494ms          20.315us         13.50 Gb         13.50 Gb         0 b              0 b              172
resize_                  0.04%            1.461ms          0.04%            1.461ms          27.565us         13.00 Gb         13.00 Gb         0 b              0 b              53
empty_strided            0.00%            76.222us         0.00%            76.222us         38.111us         512.00 Mb        512.00 Mb        0 b              0 b              2
zeros                    0.00%            24.376us         0.00%            66.824us         66.824us         400 b            0 b              0 b              0 b              1
zero_                    0.00%            9.836us          0.00%            30.566us         30.566us         0 b              0 b              0 b              0 b              1
fill_                    0.38%            15.447ms         0.38%            15.450ms         2.575ms          0 b              0 b              0 b              0 b              6
to                       0.00%            48.789us         1.14%            46.184ms         5.132ms          512.00 Mb        0 b              0 b              0 b              9
size                     0.07%            2.735ms          0.07%            2.735ms          0.453us          0 b              0 b              0 b              0 b              6036
constant_pad_nd          0.00%            120.489us        0.38%            15.590ms         15.590ms         256.00 Mb        0 b              0 b              0 b              1
narrow                   0.02%            664.211us        0.06%            2.371ms          15.099us         0 b              0 b              0 b              0 b              157
-----------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------

As you can see, the implementation allocates far more memory than it needs. I narrowed this down to the cat call in the for loop: cat allocates a fresh output tensor every time it is called (which shows up as the empty op in the profiler).
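For completeness, the table above is from the autograd profiler with memory profiling enabled; something along these lines should produce a similar profile (the 2 ** 26 input size is an assumption to match the printed size, and the exact flags may differ between PyTorch versions):

import torch
from torch.autograd import profiler

x = torch.randn(2 ** 26, requires_grad=True)  # assumed size, matching the printed 67108864
with profiler.profile(profile_memory=True) as prof:
    y = fast_walsh_hadamard_torched(x)
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))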

I also need the gradient history to be recorded, so I can’t just preallocate a buffer and use the out= argument of cat.
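Concretely, something like the following fails for me: as far as I can tell, cat rejects an out= buffer as soon as any input requires grad (the names here are just for illustration):

import torch

left = torch.randn(4, requires_grad=True)
right = torch.randn(4, requires_grad=True)
buf = torch.empty(8)  # preallocated output buffer
# In the PyTorch versions I have tried, this raises a RuntimeError saying that
# functions with out= arguments don't support automatic differentiation.
torch.cat([left + right, left - right], out=buf)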

Any ideas?