Inplace cat buffer

Hello!

We are training the following architecture:

import torch


class AbstractDenseNet(torch.nn.Module):
    def __init__(
            self,
            in_features: int,
            depth: int,
            width: int
    ):
        super(AbstractDenseNet, self).__init__()
        self._depth = depth
        self._width = width
        self._in_features = in_features
        self._out_features = in_features + width * depth
        # Each layer sees the original input plus the outputs of all previous layers.
        self.layers = torch.nn.ModuleList(
            [
                self.get_layer(in_features=in_features + i * width, out_features=width)
                for i in range(depth)
            ]
        )

    def get_layer(self, in_features: int, out_features: int) -> torch.nn.Module:
        raise NotImplementedError

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        cat_dim = tensor.dim() - 1
        for i in range(self._depth):
            cur_out = self.layers[i](tensor)
            # Append the new features along the last dimension (DenseNet-style).
            tensor = torch.cat([tensor, cur_out], dim=cat_dim)
        return tensor

This code calls cat at every layer and blows up in memory (the initial tensor is already very big), so we can't train the network with sufficient depth on our GPUs.

The obvious optimization here (from an algorithmic perspective) is to use one big preallocated buffer and fill it in place during the forward pass. However, this does not work with autograd, which forbids in-place modifications.
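
To make the idea concrete, here is a rough sketch (made-up sizes and names, not our real code) of what we mean by filling one preallocated buffer in place:

import torch

batch, in_features, width, depth = 32, 1024, 128, 8
layers = [torch.nn.Linear(in_features + i * width, width) for i in range(depth)]

x = torch.rand(batch, in_features)
buffer = torch.empty(batch, in_features + depth * width)
buffer[:, :in_features] = x
offset = in_features
for layer in layers:
    out = layer(buffer[:, :offset])          # read everything written so far
    buffer[:, offset:offset + width] = out   # append in place instead of torch.cat
    offset += width

With grad enabled, these in-place writes into buffer are exactly what autograd rejects once backward needs the values that the earlier layers saw.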

How can we avoid this restriction?

Hi,

The thing is that at each iteration you increase the input size, so your state (and layers) get bigger and bigger with depth.
The input of each Linear is needed to compute the backward pass, so you won't be able to reuse a Tensor and overwrite it with new values.
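
A minimal illustration of that point, as a hypothetical snippet (the exact error text depends on the PyTorch version):

import torch

lin = torch.nn.Linear(4, 2)
x = torch.rand(3, 4, requires_grad=True) * 1.0  # make x non-leaf so the in-place write is permitted
y = lin(x)          # Linear saves x: it is needed for the weight gradient
x.add_(1.0)         # overwrite the saved input in place
y.sum().backward()  # RuntimeError: a variable needed for gradient computation
                    # has been modified by an inplace operation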

Currently I am working on a buffer solution, which looks like this. It is written in C++, because in Python I got this error message:

Returning Variables sharing storage with other Variables that require grad is not supported in Python functions. Please submit a feature request if you hit this error

However, when calling this function through the C++ binding, I get the same error message.

#include <torch/extension.h>

using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

class BufferedCatFunction : public torch::autograd::Function<BufferedCatFunction> {
public:
    static Tensor forward(AutogradContext* ctx, Tensor buffer, std::vector<Tensor> tensors) {
        int64_t shape_1 = 0;
        std::vector<int64_t> sizes;
        for (const auto &tensor : tensors) {
            sizes.emplace_back(tensor.size(1));
            shape_1 += tensor.size(1);
        }
        int64_t shape_0 = tensors[0].size(0);
        // Carve the needed region out of the preallocated flat buffer and view it as 2-D.
        auto result = buffer.narrow(0, 0, shape_0 * shape_1).view({shape_0, shape_1});
        result.set_requires_grad(true);
        // Write all inputs side by side into the buffer view instead of allocating a new tensor.
        torch::cat_out(result, tensors, 1);
        ctx->save_for_backward(tensors);
        return result;
    }

    static variable_list backward(
            AutogradContext* ctx,
            variable_list grad_output) {
        auto &grad = grad_output[0];
        auto saved = ctx->get_saved_variables();

        // Slice the incoming gradient back into per-input chunks along dim 1.
        int64_t cur_offset = 0;
        variable_list gradients;
        for (auto &tensor : saved) {
            auto size = tensor.size(1);
            gradients.emplace_back(grad.narrow(1, cur_offset, size));
            cur_offset += size;
        }

        return gradients;
    }
};


Tensor cat_inside_cpp_function(std::vector<Tensor> tensors, Tensor buffer) {
    return BufferedCatFunction::apply(buffer, tensors);
}

Here is the Python code:

import torch
import buffered_densenet_cpp  # the compiled C++ extension exposing cat_inside_cpp

# Preallocated flat buffer; the size here is arbitrary, it just has to fit
# the largest concatenation below.
BUFFER = torch.zeros(1024)

if __name__ == '__main__':
    linears = [
        torch.nn.Linear(2, 1),
        torch.nn.Linear(3, 1),
        torch.nn.Linear(4, 1)
    ]

    with torch.autograd.detect_anomaly():

        inputs = torch.rand(3, 2)
        inputs.requires_grad = False
        concated1 = buffered_densenet_cpp.cat_inside_cpp([inputs], BUFFER)
        print(concated1)

        out1 = linears[0](concated1)
        concated2 = buffered_densenet_cpp.cat_inside_cpp([inputs, out1], BUFFER)
        print(concated2)

        concated2.sum().backward()

How can I avoid this shared-storage error message and use the buffer? All tensors referring to the same storage are available (or I think they are available) to autograd, so the backward pass could be computed.

Hi,

The error message mentions Python functions because you were using Python, but it is not specific to Python, so moving to C++ won't bring you any benefit. In general, for autograd-related tasks, you can do everything from Python.
Note though that this has been updated in master, and this error does not exist anymore.

But your custom function is still missing a call to mark_dirty(), since it modifies some of its inputs in place.
And because you modify an input in place, you need to return that input as-is, but here you return a view of it, so that won't work properly.

The problem here is that the buffer is modified in place by your function. But the buffer content that is the input to the next Linear is also saved by that Linear, because it needs its value for backward. So if you reuse the buffer again later, the previous Linear won't have access to the value it needs to compute its backward pass, and you are back to the original issue: you cannot change the buffer in place because it is needed.
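
For reference, this is roughly what that protocol looks like as a Python custom Function: a sketch with made-up names and a 2-D buffer that follows the mark_dirty()/return-as-is rules, but does not solve the reuse problem described above:

import torch

class BufferedCat(torch.autograd.Function):
    @staticmethod
    def forward(ctx, buffer, *tensors):
        ctx.split_sizes = [t.size(1) for t in tensors]
        width = sum(ctx.split_sizes)
        buffer[:, :width] = torch.cat(tensors, dim=1)  # in-place write into the buffer
        ctx.mark_dirty(buffer)                         # declare that an input was modified
        ctx.width = width
        return buffer                                  # return the modified input as-is, not a view

    @staticmethod
    def backward(ctx, grad_buffer):
        # One gradient per forward input: None for the buffer itself,
        # then one slice of the incoming gradient per concatenated tensor.
        grads = torch.split(grad_buffer[:, :ctx.width], ctx.split_sizes, dim=1)
        return (None,) + tuple(grads)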

Hi,

Each cat operation is an append-only modification. All previously concatenated tensors are already written in the buffer, and we just add the new result (the previous Linear's output) to the right of the buffer.

So, when you say "the previous Linear won't have access to the value it needs to compute the backward", do you mean that the buffer for this Linear will be changed in the future (it won't, as I said), or that it will magically "disappear" from autograd internals due to some implementation detail?

Regarding the Python error: I'll check out the master branch and try it, thank you!

Each cat operation is an append-only modification.

But autograd cannot know that. You rewrite the whole buffer every time; it just happens that the first part you write is the same as what was already there.
Unfortunately, we cannot detect that. If you write values into a tensor, we consider that its content has changed; we cannot check every value.
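
A tiny illustration of that bookkeeping, as a hypothetical snippet (`_version` is the internal counter autograd consults): writing into one region of a tensor bumps the version counter shared by every view of the same storage, even views over regions that were never touched.

import torch

buf = torch.zeros(2, 4)
left = buf[:, :2]        # a view over the "already written" region
print(left._version)     # 0
buf[:, 2:] = 1.0         # append-only write into a *different* region
print(left._version)     # 1: the shared version counter was bumped anyway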