Concatenate tensors without memory copying

Hi,

I’m wondering if there is an alternative concatenation method that concatenates two tensors without copying memory?

Currently, I use t = torch.cat([t1, t2], dim=0) in my data pre-processing. However, I get an out-of-memory error because there are many big tensors that need to be concatenated.

I have searched around and read some threads, like "tensor appending" and "Torch.cat() blows up memory required", but I still cannot find a good solution to the memory consumption problem.


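One way is to allocate a tensor with the target shape and then assign into it using the slice operator.

Here is an example:

import torch

t1_shape = t1.shape[0]
t2_shape = t2.shape[0]
second_dim_shape = t1.shape[1]
# allocate the output once, matching the inputs' dtype and device, then copy each input into its slice
t = torch.zeros((t1_shape + t2_shape, second_dim_shape), dtype=t1.dtype, device=t1.device)
t[:t1_shape, :] = t1
t[t1_shape:, :] = t2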

I know that this is a very simple case and it gets more complicated for n-d tensors.
I am not aware of an efficient internal API for this.
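The slice-assignment idea extends to n-d tensors by using `narrow` along an arbitrary dimension. This is a minimal sketch, assuming all inputs share dtype, device, and every size except `dim`; `cat_into` is a hypothetical helper name, not a PyTorch API. It still allocates the full output, exactly like `torch.cat`, so it only shows what the copy looks like and does not save memory.

```python
import torch

def cat_into(tensors, dim=0):
    """Hypothetical helper: preallocate the output, then copy each input into its slice."""
    ref = tensors[0]
    dim = dim % ref.dim()  # allow negative dims
    out_shape = list(ref.shape)
    out_shape[dim] = sum(t.size(dim) for t in tensors)
    # torch.empty skips the zero-fill that torch.zeros would do
    out = torch.empty(out_shape, dtype=ref.dtype, device=ref.device)
    offset = 0
    for t in tensors:
        n = t.size(dim)
        # narrow() returns a view into `out`; copy_() writes through it in place
        out.narrow(dim, offset, n).copy_(t)
        offset += n
    return out

t1 = torch.randn(2, 3, 4)
t2 = torch.randn(2, 5, 4)
t = cat_into([t1, t2], dim=1)
print(t.shape)  # torch.Size([2, 8, 4])
```

If you already own a buffer of the right shape, `torch.cat(tensors, dim, out=buf)` writes into it directly instead of allocating a new result.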


Hi,

Since the two original tensors t1 and t2 are at different places in memory, it’s not possible to make a single Tensor out of them without creating a new tensor that can contain both of them.
The cat implementation does pretty much what the code sample above from @bhushans23 does: create a Tensor that can hold everything, then copy each part into it.
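A quick way to see that `torch.cat` always produces new storage, using small tensors chosen only for illustration:

```python
import torch

t1 = torch.ones(2, 3)
t2 = torch.zeros(2, 3)
t = torch.cat([t1, t2], dim=0)

# t lives in its own, newly allocated storage...
print(t.data_ptr() == t1.data_ptr())  # False
# ...so later writes to an input are not reflected in the result
t1.fill_(5.0)
print(t[0, 0].item())  # 1.0
```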


Ok, so it is currently inevitable to allocate new memory when concatenating two tensors in PyTorch, right?

Would it be possible to implement a new concatenation operation in PyTorch like the one in this post? It records references instead of copying memory.

I think this feature would be pretty useful.

Hi,

The post you link does not implement a new concatenation op. It just has a helper function that carries the pair of tensors around and applies each op to both. You can already do that by keeping a list of your two Tensors and adapting the functions you call next to work on it. Did I miss something in that issue?

Another solution is to pre-allocate the full tensor and compute t1 and t2 directly into it doing inplace operations. That way you don’t need the cat operation at all.
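A minimal sketch of that suggestion, assuming the two parts can be produced by ops that accept an `out=` argument (here `torch.mm` and `torch.sigmoid`, picked only for illustration; the shapes are made up). Each part is written straight into a view of the preallocated buffer, so no `torch.cat` is needed afterwards. Note that `out=` variants do not support autograd when an input requires grad, so this fits data preprocessing better than a training graph.

```python
import torch

a = torch.randn(2, 4)
w = torch.randn(4, 3)
x = torch.randn(3, 3)

# Preallocate the final "concatenated" buffer once
buf = torch.empty(5, 3)
top, bottom = buf[:2], buf[2:]  # views into buf, no copy

# Compute each part directly into its slot
torch.mm(a, w, out=top)
torch.sigmoid(x, out=bottom)

print(torch.allclose(buf, torch.cat([a @ w, torch.sigmoid(x)], dim=0)))
```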


Hi, AlbanD

In that post, the concatenation op doesn’t allocate new memory. It maintains a pointer table which points to the shared memory storage.

The simple solution you suggest below won’t work in general (e.g., when the required input type is a tensor rather than a list, or when I want to concatenate two tensors along different dimensions).

You can already do that by just using the list with your two Tensors

To clarify, let me use a simple example to explain what I want.

Suppose we concatenate two tensors with the code below:

t1 = torch.randn(512, 256, 100, 100)
t2 = torch.randn(512, 256, 100, 100)
t = torch.cat(t1, t2, dim=1) 

The total memory consumption here will be 512x256x100x100x4 float32 values (t1 and t2, plus t, which is twice as large). Also, simply using a list t = [t1, t2] is not a correct replacement.
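The arithmetic can be checked without allocating anything by using tensors on the `meta` device, which carry shape and dtype but no data. This assumes a PyTorch version with meta-device support for `torch.empty` and `torch.cat`, and the default float32 dtype (4 bytes per element); `nbytes` is a small illustrative helper.

```python
import torch

# "meta" tensors have shape/dtype but no storage, so nothing large is allocated
t1 = torch.empty(512, 256, 100, 100, device="meta")
t2 = torch.empty(512, 256, 100, 100, device="meta")
t = torch.cat([t1, t2], dim=1)

def nbytes(x):
    return x.numel() * x.element_size()

print(t.shape)      # torch.Size([512, 512, 100, 100])
print(nbytes(t1))   # 5242880000 bytes per input
total = nbytes(t1) + nbytes(t2) + nbytes(t)
print(total)        # 20971520000 bytes, i.e. 4x one input
```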

Is it possible to implement a memory-efficient concatenation like this:

t1 = torch.randn(512, 256, 100, 100)
t2 = torch.randn(512, 256, 100, 100)
t_efficient = torch.cat(t1, t2, dim=1, allocation="shared")

The variable t_efficient would just record memory references to t1 and t2 rather than allocating new memory, so the total memory consumption would be 512x256x100x100x2.

Hi,

No we don’t have this feature.

What I meant is that all this feature does is keep a list of both Tensors and adapt the few ops you need to run on them.
A quick example:

import torch

class MySharedTensor(object):
    def __init__(self, tensors, dim=None):
        assert dim is not None
        self.dim = dim
        assert (isinstance(tensors, list))
        # all() is required: a bare generator expression is always truthy
        assert all(torch.is_tensor(t) for t in tensors)
        self.tensors = tensors
        self.dim_sizes = [t.size(self.dim) for t in self.tensors]

    # If you want to recover a single Tensor from it
    def to_full_tensor(self):
        return torch.cat(self.tensors, dim=self.dim)

    # Out of place addition
    def __add__(self, other):
        assert torch.is_tensor(other)
        out = other.clone()
        curr_idx = 0
        for i, t in enumerate(self.tensors):
            other_slice = other.narrow(self.dim, curr_idx, self.dim_sizes[i])
            out.narrow(self.dim, curr_idx, self.dim_sizes[i]).copy_(t).add_(other_slice)
            curr_idx += self.dim_sizes[i]
        return out

    # Inplace add
    def __iadd__(self, other):
        assert torch.is_tensor(other)
        curr_idx = 0
        for i, t in enumerate(self.tensors):
            other_slice = other.narrow(self.dim, curr_idx, self.dim_sizes[i])
            t.add_(other_slice)
            curr_idx += self.dim_sizes[i]
        return self

    # Matrix Multiplication (only 2d matrices for simplicity)
    def mm(self, other):
        assert other.ndimension() == 2
        assert all(t.ndimension() == 2 for t in self.tensors)

        if self.dim == 0:
            out_tensors = []
            for t in self.tensors:
                out_tensors.append(t.mm(other))
            return MySharedTensor(out_tensors, dim=0)
        elif self.dim == 1:
            out = 0
            curr_idx = 0
            for i, t in enumerate(self.tensors):
                other_slice = other.narrow(0, curr_idx, self.dim_sizes[i])
                out += t.mm(other_slice)
                curr_idx += self.dim_sizes[i]
            return out
        else:
            raise RuntimeError("Invalid dimension")


a = torch.rand(2, 4)
b = torch.rand(2, 4)
print("a and b")
print(a)
print(b)

c = MySharedTensor([a, b], dim=1)
print("c size:")
print(c.to_full_tensor().size())

d = torch.rand(2, 8)
print("d")
print(d)

e = c + d
print("e = c + d")
print(e)

f = c.to_full_tensor() + d
print("f = c.to_full_tensor() + d")
print(f)

c += d
print("c += d")
print(c.to_full_tensor())
print("a and b are changed:")
print(a, b)


g = torch.rand(8, 3)
print("g")
print(g)

h = c.mm(g)
print("h = c.mm(g)")
print(h)

k = c.to_full_tensor().mm(g)
print("k = c.to_full_tensor().mm(g)")
print(k)

MySharedTensor is only a workaround: it supports just a few ops and cannot seamlessly adapt to all the ops the native Tensor supports. That is why I said it won’t work in general.

I hope the PyTorch team can add this feature in the future.

By the way, I want MySharedTensor to work with conv2d without allocating new memory. What should I do?

The thing is that supporting such a thing is a very large change, as every op needs to be reimplemented for it. In particular, if you never want to allocate the full Tensor, it can be tricky to use libraries like cuDNN that do not support this, without a speed drop.
Do you know how TensorFlow handles such cases? What is the performance drop of using such a structure?

For conv2d, it depends on which dimension you concatenate along.
If it’s 0, you can use a for-loop that passes each Tensor through the conv one by one, then concatenate the outputs along dimension 0.
If it’s 1, you need 2 smaller convs (if you have 2 input Tensors), each taking its subset of the input channels and producing the same number of output channels, and then sum the outputs of the two convs.
If it’s 2 or 3, it gets quite tricky, because the conv windows that straddle the boundary between the two Tensors need values from both, so you cannot run two independent convs.
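A minimal sketch of the dim 0 and dim 1 cases, assuming a plain `nn.Conv2d` with stride 1, padding 1, no groups and no dilation (all shapes are made up). The dim 1 trick works because convolution is linear in its input channels, so splitting the weight along its input-channel axis and summing the two partial convs gives the full result; the bias must be added only once.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical conv: 3 + 5 = 8 input channels
conv = nn.Conv2d(8, 4, kernel_size=3, padding=1)

with torch.no_grad():
    # --- Concatenation along dim 1 (channels) ---
    x1 = torch.randn(2, 3, 8, 8)
    x2 = torch.randn(2, 5, 8, 8)
    ref_dim1 = conv(torch.cat([x1, x2], dim=1))  # materializes the 8-channel input

    # Split the weight by input channel; bias is passed to one conv only
    w1, w2 = conv.weight[:, :3], conv.weight[:, 3:]
    y_dim1 = (F.conv2d(x1, w1.contiguous(), conv.bias, padding=1)
              + F.conv2d(x2, w2.contiguous(), None, padding=1))

    # --- Concatenation along dim 0 (batch) ---
    xa = torch.randn(2, 8, 8, 8)
    xb = torch.randn(3, 8, 8, 8)
    ref_dim0 = conv(torch.cat([xa, xb], dim=0))

    # Samples are independent: convolve each input separately and
    # concatenate only the outputs
    y_dim0 = torch.cat([conv(xa), conv(xb)], dim=0)

print(torch.allclose(y_dim1, ref_dim1, atol=1e-5))
print(torch.allclose(y_dim0, ref_dim0, atol=1e-5))
```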

Sorry, I don’t know how TensorFlow deals with this case. But it seems the technical report Memory-Efficient Implementation of DenseNets has already implemented this feature in PyTorch.

I will concatenate only along dimensions 0 and 1, so I’ll try this trick. Thanks for your help.


For newcomers: the memory-efficient DenseNet implementation uses checkpointing. The intermediate features (the concatenation) are not kept; they are recomputed by re-running the forward pass of each segment during the backward pass.

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be advanced further than they would be without checkpointing.
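A minimal sketch of checkpointing a cat + conv segment with `torch.utils.checkpoint.checkpoint`. This is an illustration of the technique, not the DenseNet code itself; `cat_conv` and the shapes are hypothetical, and `use_reentrant=False` assumes a reasonably recent PyTorch (the keyword was added around 1.11). Inside the checkpointed segment, the concatenated tensor is not saved for backward; it is recomputed when gradients are needed.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

conv = nn.Conv2d(8, 4, kernel_size=3, padding=1)

def cat_conv(x1, x2):
    # The concatenation lives only inside the checkpointed segment:
    # it is freed after forward and rebuilt during backward
    return conv(torch.cat([x1, x2], dim=1))

x1 = torch.randn(2, 3, 8, 8, requires_grad=True)
x2 = torch.randn(2, 5, 8, 8, requires_grad=True)

y = checkpoint(cat_conv, x1, x2, use_reentrant=False)
y.sum().backward()  # reruns cat_conv's forward to compute gradients
print(x1.grad.shape, x2.grad.shape)
```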

By the way, this is coming to PyTorch ;)
See this issue: https://github.com/pytorch/pytorch/issues/22169

But what if I can guarantee that the concatenated tensors are views of the same tensor? Imagine I have a tensor

t = torch.randn(400)

and I need to calculate the mean of every 100 elements. In order to avoid any loops, I can easily do something like this:

means = t.view(4, -1).mean(dim=1)

but let’s say I need to calculate mean() of intersecting segments, in particular:

means = torch.stack((t[0:200], t[100:300], t[200:400]), dim=0).mean(dim=1)

Can I do this efficiently, without the tensor copy that torch.stack() (or torch.cat()) performs?

You can in some cases, yes, but it won’t be the easiest code to read:

import torch

a = torch.rand(40)

print(a)

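size_slices = 20   # length of each slice
size_overlap = 10  # elements shared by consecutive slices
nb_slices = 3
# view of shape (nb_slices, size_slices): row i starts at i * (size_slices - size_overlap)
b = a.as_strided((nb_slices, size_slices), (size_slices - size_overlap, 1), 0)
print(b)

print(b.mean(dim=1))  # one mean per overlapping slice, no torch.cat needed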
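For evenly spaced windows like the ones in the question, `Tensor.unfold(dimension, size, step)` builds the same kind of strided view without hand-computed strides. A minimal sketch using the question’s sizes (400 elements, windows of 200, step 100):

```python
import torch

t = torch.randn(400)

# unfold returns a view of overlapping windows:
# rows are t[0:200], t[100:300], t[200:400]
windows = t.unfold(0, 200, 100)
print(windows.shape)  # torch.Size([3, 200])
means = windows.mean(dim=1)

# No copy: the windows share t's storage
t[150] = 1000.0
print(windows[0, 150].item(), windows[1, 50].item())  # 1000.0 1000.0
```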

Hi guys,
I have been following this topic for a long time, but still couldn’t find a good solution.
So I implemented a custom PyTorch CUDA extension that avoids the memory copy entirely on the CUDA device. However, my solution is very specific to my use case.

In my case, I combined Concat + Conv2d into a single CatConv2d kernel, which can significantly reduce the latency in some cases (small batch, small in/out channels).

For example, a concat of a list of Tensors followed by a conv2d, which previously took two operators, can now be done in a single operator with a single CUDA kernel:

    x_list = [x0, x1, x2, x3,...]
    x = torch.cat(x_list, 1)
    x = conv(x)

    # equivalent to 
    x = CatConv2d(x_list)

Here is the repo:

The backward pass is still missing (forward only) in this repo, and there are many limitations, but I’m just trying to show that fusing concat with another op is very promising in terms of performance and worth working on.

I ran your code on Colab and got the following error:
cat() received an invalid combination of arguments - got (Tensor, Tensor, dim=int), but expected one of:

  • (tuple of Tensors tensors, int dim, *, Tensor out)
  • (tuple of Tensors tensors, name dim, *, Tensor out)

Pass the tensors as a tuple or list as described in the error message:

t = torch.cat((t1, t2), dim=1) 
t = torch.cat([t1, t2], dim=1) 

I am wondering whether this concatenation without copying could be done in C++?
How to concatenate tensors without copy? - C++ - PyTorch Forums