How does the computation graph work in PyTorch? Which specific operations are detached from the computation graph?

Hi, I have some doubts about how exactly the computation graph is constructed in PyTorch. I understand how the computation graph works for predefined layers, but what happens when there are intermediate tensor manipulation operations between the layers?

For example, in the code below (the exact model is not important), what happens when I reshape tensors, slice them, run a for loop, etc.?

My “x” (the output from the layers) completely changes after the tensor manipulation operations. How is this accounted for in the computation graph? These operations are not part of a neural network; I am only doing them for convenience (for example, a custom flattening operation). Will gradients also be calculated for these manipulation operations?

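```python
import torch
import torch.nn.functional as F

# Excerpt from the model's forward(); self.convs, self.lns, self.dropout,
# self.num_layers, edge_index and b_sz are defined elsewhere in the model
for i in range(self.num_layers):
    x = self.convs[i](x, edge_index)
    emb = x
    x = F.relu(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    if i != self.num_layers - 1:
        x = self.lns[i](x)
# "x" is the output of the previous layers

# Intermediate tensor manipulation
# (allocate result on x's device/dtype so the next layers receive a matching tensor)
result = torch.empty((b_sz, 4096), device=x.device, dtype=x.dtype)
x = x.t()
for i in range(b_sz):
    tensor = x[:, i*128:(i+1)*128]
    vector = tensor.reshape(1, -1)
    result[i, :] = vector.squeeze(0)

# x is fed to other predefined layers
```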
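The custom flattening loop above can also be written with views and reshapes only, which autograd tracks the same way. A minimal sketch, assuming (from the constants 128 and 4096, not stated in the question) that `x` has shape `(b_sz * 128, hidden)` before the transpose, so `hidden == 32`; `chunk`, `hidden`, `flat` and `w` are illustrative names:

```python
import torch

# Assumed shapes, inferred from the constants in the question:
# x is (b_sz * 128, hidden) and hidden * 128 == 4096, so hidden == 32.
b_sz, chunk, hidden = 3, 128, 32
x = torch.randn(b_sz * chunk, hidden, requires_grad=True)

# Loop version, as in the question
result = torch.empty((b_sz, hidden * chunk), device=x.device, dtype=x.dtype)
xt = x.t()
for i in range(b_sz):
    result[i, :] = xt[:, i * chunk:(i + 1) * chunk].reshape(-1)

# Vectorized version: split the rows into b_sz chunks of 128, swap the row and
# feature axes inside each chunk, then flatten each chunk into one row.
flat = x.reshape(b_sz, chunk, hidden).transpose(1, 2).reshape(b_sz, -1)

# Same values, and the same gradient w.r.t. x for an arbitrary upstream weight w
w = torch.randn(b_sz, hidden * chunk)
g_loop, = torch.autograd.grad((result * w).sum(), x)
g_vec, = torch.autograd.grad((flat * w).sum(), x)
print(torch.equal(result, flat), torch.allclose(g_loop, g_vec))  # True True
```

Both versions only move values around, so the gradient is just the upstream gradient routed back to the original positions; the vectorized form avoids the Python loop and the chain of in-place copies.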

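Autograd tracks every differentiable operation applied to tensors that require gradients (trainable parameters and any tensor computed from them), regardless of whether it is performed inside an nn.Module or as a plain PyTorch operation. Reshaping, slicing, transposing, and copying into a preallocated tensor are all differentiable, so gradients flow through them as well.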
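A minimal CPU-only sketch of both halves of the question, using toy shapes that are illustrative only (not the model above): the manipulation steps are recorded and gradients reach `x`, while `detach()`, `torch.no_grad()` and non-differentiable ops such as `argmax` produce tensors with no connection to the graph. The printed node name is what recent PyTorch versions show and may differ in others.

```python
import torch

# Toy stand-in for the layer output "x" (shapes are illustrative only)
x = torch.randn(8, 3, requires_grad=True)

# Same kinds of manipulation as in the question: transpose, slice, reshape,
# and in-place copy into a preallocated tensor that does not require grad.
b_sz, chunk = 2, 4
result = torch.empty((b_sz, 3 * chunk))
xt = x.t()
for i in range(b_sz):
    result[i, :] = xt[:, i * chunk:(i + 1) * chunk].reshape(-1)

# The copies were recorded: result now requires grad and has a grad_fn
print(result.requires_grad, result.grad_fn)  # True, plus a CopySlices node
result.sum().backward()
print(x.grad)  # all ones: each element of x was copied into result exactly once

# Ways the graph is cut: none of these outputs is connected to x
y = x.detach() * 2       # detach() shares data but carries no history
with torch.no_grad():
    z = x * 2            # operations inside no_grad() are not recorded
idx = x.argmax(dim=1)    # integer-valued result, not differentiable
print(y.requires_grad, z.requires_grad, idx.requires_grad)  # False False False
```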

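Because the new x is part of the computation graph, PyTorch keeps the previous x tensor alive instead of freeing it; in general, tensors that an operation needs for its backward computation are stored as intermediate forward activations until the backward call.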
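Which tensors get stored depends on the operation's derivative. A minimal sketch using `torch.autograd.graph.saved_tensors_hooks` (available in recent PyTorch releases) to count the tensors autograd saves for backward; `a`, `w`, `pack` and `unpack` are illustrative names:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

saved = []

def pack(t):
    saved.append(t.shape)  # called once for every tensor saved for backward
    return t

def unpack(t):
    return t

a = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)

with saved_tensors_hooks(pack, unpack):
    saved.clear()
    b = a + 1
    n_add = len(saved)   # d(a + 1)/da = 1, so no tensor is needed
    saved.clear()
    c = a * w
    n_mul = len(saved)   # d(a*w)/da = w and d(a*w)/dw = a, so both are saved

print(n_add, n_mul)  # 0 2
```

So for `x = x + 1` nothing is saved as an activation; what keeps the old tensor in memory in that case is the graph's reference to the leaf it will accumulate gradients into.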

Here is a small code snippet that shows the change in memory usage; note that the intermediate is not freed in the second case:

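```python
import torch

print(torch.cuda.memory_allocated())
# 0

# plain tensor
x = torch.randn(1024, device="cuda")
print(torch.cuda.memory_allocated())
# 4096

x = x + 1
print(torch.cuda.memory_allocated())
# 4096

# leaf tensor creating a computation graph
x = torch.randn(1024, device="cuda", requires_grad=True)
print(torch.cuda.memory_allocated())
# 4096

x = x + 1
print(torch.cuda.memory_allocated())
# 8192
```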
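The snippet needs a GPU. A CPU-only sketch of where the extra 4096 bytes come from: the new tensor's `grad_fn` points to an `AccumulateGrad` node that still references the original leaf, so rebinding the name `x` does not free it. Reading the node through `next_functions` and its `variable` attribute relies on autograd internals, so treat the exact names as version-dependent.

```python
import torch

x = torch.randn(1024, requires_grad=True)
leaf_ptr = x.data_ptr()
x = x + 1  # rebinding the name does not free the original leaf

# The graph of the new tensor ends in an AccumulateGrad node holding the leaf,
# so the original 1024-element tensor (4096 bytes of float32) stays alive.
node = x.grad_fn.next_functions[0][0]
print(type(node).__name__)                   # AccumulateGrad
print(node.variable.data_ptr() == leaf_ptr)  # True

# Under no_grad() nothing is recorded, so the old tensor can be freed as usual.
x = torch.randn(1024, requires_grad=True)
with torch.no_grad():
    x = x + 1
print(x.grad_fn)  # None
```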