Best way to split, process, and merge

Hi,

I frequently encounter situations where I have to split a tensor into irregular sub-parts, apply a different operation to each part, and then concatenate all the results.

I found two approaches and I wonder which one is better.
The following snippet of code presents the two approaches.

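```python
import torch


def with_cat(input, dimensions, functions):
    # Collect each processed block in a list, then concatenate once at the end
    outputs = []
    i = 0
    for dim, fn in zip(dimensions, functions):
        x = fn(input[:, i:i + dim])  # split and process
        outputs.append(x)
        i += dim

    return torch.cat(outputs, dim=1)  # merge


def without_cat(input, dimensions, functions):
    # Preallocate the result and write each processed block into its columns;
    # this assumes every fn returns a tensor with the same shape as its block
    output = torch.empty_like(input)
    i = 0
    for dim, fn in zip(dimensions, functions):
        x = fn(input[:, i:i + dim])  # split and process
        output[:, i:i + dim] = x  # merge
        i += dim

    return output
```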

Note: both work the same way for the forward and backward passes.
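One way to check this note is to run both functions on identical inputs and compare the outputs and the gradients. This is a sketch: the block sizes and the functions (`torch.sin`, a scaling lambda, `torch.exp`) are arbitrary choices, and each function is shape-preserving, which `without_cat` needs anyway.

```python
import torch


def with_cat(input, dimensions, functions):
    outputs = []
    i = 0
    for dim, fn in zip(dimensions, functions):
        outputs.append(fn(input[:, i:i + dim]))  # split and process
        i += dim
    return torch.cat(outputs, dim=1)  # merge


def without_cat(input, dimensions, functions):
    output = torch.empty_like(input)
    i = 0
    for dim, fn in zip(dimensions, functions):
        output[:, i:i + dim] = fn(input[:, i:i + dim])  # process and merge
        i += dim
    return output


torch.manual_seed(0)
dimensions = [2, 3, 1]  # irregular column blocks, summing to 6
functions = [torch.sin, lambda t: t * 2, torch.exp]  # shape-preserving ops

# Two independent leaf tensors with identical values, so gradients can be compared
a = torch.randn(4, 6, requires_grad=True)
b = a.detach().clone().requires_grad_(True)

out_cat = with_cat(a, dimensions, functions)
out_assign = without_cat(b, dimensions, functions)

out_cat.sum().backward()
out_assign.sum().backward()

print(torch.allclose(out_cat, out_assign))  # True
print(torch.allclose(a.grad, b.grad))  # True
```

The slice assignment in `without_cat` is recorded by autograd, so gradients reach `b` even though `torch.empty_like` itself returns a tensor that does not require grad.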

Which one is better in terms of memory, execution time, idiomatic style, …?
Is there a better way to split, process, and merge data?

Thanks


Which one is better in terms of memory, execution time, idiomatic style, …?

I would call the second approach much more idiomatic, and it should be more efficient in terms of both computation (“compute”) and memory consumption. The reason is that it wouldn’t make sense to use a Python list if you want something to be in a fixed-size array later on. A list would be super inefficient because it allocates more memory on the fly when you append and it runs out of allocated space, and in the worst case it can allocate approximately twice as much memory as needed. By the way, in your case, I would probably also specify the dtype in torch.empty_like.
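For comparison, the split step of the first approach can also be written with `torch.split`, which accepts a list of section sizes along a dimension and returns views of consecutive blocks. This is a sketch of a variant that still merges with `torch.cat`; the name `with_split` and the example functions are hypothetical.

```python
import torch


def with_split(input, dimensions, functions):
    # torch.split with a list of sizes returns views of consecutive column blocks
    chunks = torch.split(input, dimensions, dim=1)
    # Process each block with its own function, then merge along the same dim
    return torch.cat([fn(chunk) for chunk, fn in zip(chunks, functions)], dim=1)


x = torch.arange(12.0).reshape(2, 6)
out = with_split(x, [2, 3, 1], [torch.neg, lambda t: t * 10, torch.square])
print(out)
```

Unlike the manual slicing loop, `torch.split` raises an error if the sizes do not add up to the size of dimension 1, which can catch a wrong `dimensions` list early.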

If you can, the ideal way regarding computing and memory resources would be to modify/overwrite the values in the input tensor directly instead of creating a second tensor (output), but that’s only possible if the input entries don’t overlap between the different computations you perform via fn.
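Under that condition (non-overlapping blocks), the in-place idea might look like the sketch below. `process_in_place` is a hypothetical name, and it assumes each fn returns a tensor with the same shape as its block. Each fn still allocates a temporary for its result, but no full-size output tensor is created. With autograd, in-place writes are restricted: PyTorch refuses in-place operations on a leaf tensor that requires grad, and overwriting values that some fn saved for its backward pass raises an error when backward() runs.

```python
import torch


def process_in_place(input, dimensions, functions):
    # Assumes the column blocks do not overlap and each fn preserves shape
    i = 0
    for dim, fn in zip(dimensions, functions):
        block = input[:, i:i + dim]  # basic slicing returns a view, no copy
        block.copy_(fn(block))  # write fn's temporary result back into input
        i += dim
    return input


x = torch.ones(2, 5)  # does not require grad, so in-place writes are allowed
out = process_in_place(x, [2, 3], [lambda t: t + 1, lambda t: t * 3])
print(x)  # columns 0-1 become 2, columns 2-4 become 3
```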

Thanks for the answer.

The reason is that it wouldn’t make sense to use a Python list if you want something to be in a fixed-size array later on.

Good point

A list would be super inefficient because it would allocate more memory on the fly when you append and it’s running out of allocated memory, and in the worst case, it can allocate approximately twice as much memory as needed.

A Python list contains only references, so the problem you are worried about will never happen: any over-allocation concerns a few pointer slots, not the tensor data itself.
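A small check of this point (a sketch; the tensor sizes are arbitrary) shows that appending tensors to a list copies no tensor data and that the list itself is tiny compared with what it refers to.

```python
import sys

import torch

# Three independent ~4 MB float32 tensors standing in for processed blocks
tensors = [torch.full((1000, 1000), float(k)) for k in range(3)]

outputs = []
for t in tensors:
    outputs.append(t)  # appends a reference; no tensor data is copied

# Every list element is the very same object that was appended
print(all(o is t for o, t in zip(outputs, tensors)))  # True

# sys.getsizeof counts only the list object itself (header + pointer slots),
# not the tensors it refers to
list_bytes = sys.getsizeof(outputs)
data_bytes = sum(t.element_size() * t.nelement() for t in tensors)
print(list_bytes, data_bytes)
```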

I would probably also specify the dtype in torch.empty_like.

torch.empty_like copies not only the size but also the dtype and device.
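To illustrate (a sketch; the shapes and dtypes are arbitrary): `torch.empty_like` takes size, dtype, and device from its argument, while `torch.empty` only receives a size and falls back to the default dtype. The dtype of `torch.empty_like` can still be overridden explicitly with the `dtype` keyword.

```python
import torch

src = torch.zeros(2, 3, dtype=torch.float64)

like = torch.empty_like(src)  # size, dtype and device taken from src
half = torch.empty_like(src, dtype=torch.float16)  # explicit dtype override
plain = torch.empty(2, 3)  # only a size is given; the default dtype is used

print(like.shape, like.dtype, like.device)
print(half.dtype, plain.dtype)
```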

the ideal way regarding computing and memory resources would be to modify/overwrite the values in the input tensor directly

Good point, I will try to do that!

Oh yeah, I agree, so that wouldn’t really blow up memory, but it would still be less elegant, I guess 😛

torch.empty_like copies not only the size but also the dtype and device.

Oh yeah, true. For some reason, I confused it with torch.empty and assumed it was initialized from size information, not from an existing tensor.