Excessive Memory Consumption in Forward Pass with Autograd

Hi. I am working with code from this paper, where the authors compute covariance matrices that are used as kernels in kernel regression. The functionality of the Conv2d and ReLU layers is changed so that a forward pass evaluates the covariance matrix instead of simply applying a convolution or the ReLU function.

I am adapting this code to make it trainable with respect to the variances, so self.var_weight and self.var_bias are trainable parameters in each Conv2d layer, while the ReLU layers have no trainable parameters. We only use convolution and ReLU layers wrapped in a Sequential, so that calling the model on a training dataset executes the propagate method of each class.
For the convolution:

class Conv2d(NNGPKernel):
    def propagate(self, kp):
        kp = ConvKP(kp)
        if self.kernel_has_row_of_zeros:
            kernel = t.ones(1, 1, self.kernel_size+1, self.kernel_size+1)
            kernel[:, :, 0, :] = 0.
            kernel[:, :, :, 0] = 0.
        else:
            kernel = t.ones(1, 1, self.kernel_size, self.kernel_size)
        kernel = kernel * (self.var_weight / self.kernel_size**2)
        def f(patch):
            return (F.conv2d(patch, kernel, stride=self.stride, # CHANGE self.kernel to kernel
                             padding=self.padding, dilation=self.dilation)
                    + self.var_bias)
        return ConvKP(kp.same, kp.diag, f(kp.xy), f(kp.xx), f(kp.yy))
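
For completeness, the variances are registered as trainable parameters in the constructor, roughly like this (a simplified sketch assuming import torch as t and from torch import nn; the exact signature in my code differs slightly):

class Conv2d(NNGPKernel):
    def __init__(self, kernel_size, padding=0, var_weight=1.0, var_bias=0.0,
                 stride=1, dilation=1, kernel_has_row_of_zeros=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.kernel_has_row_of_zeros = kernel_has_row_of_zeros
        # the two quantities I want gradients with respect to
        self.var_weight = nn.Parameter(t.tensor(float(var_weight)))
        self.var_bias = nn.Parameter(t.tensor(float(var_bias)))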

and for ReLU:

class ReLU(NNGPKernel):
    """
    A ReLU nonlinearity, the covariance is numerically stabilised by clamping
    values.
    """
    f32_tiny = np.finfo(np.float32).tiny
    def propagate(self, kp):
        kp = NonlinKP(kp)
        """
        We need to calculate (xy, xx, yy == c, v₁, v₂):
                      ⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤
        √(v₁v₂) / 2π ⎷1 - c²/v₁v₂ + (π - θ)c / √(v₁v₂)

        which is equivalent to:
        1/2π ( √(v₁v₂ - c²) + (π - θ)c )

        # NOTE we divide by 2 to avoid multiplying the ReLU by sqrt(2)
        """
        xx_yy = kp.xx * kp.yy + self.f32_tiny

        # Clamp these so the outputs are not NaN
        # Use small eps to avoid NaN during backpropagation
        eps = 1e-6

        inverse_sqrt_xx_yy = 1 / (t.sqrt(xx_yy) + eps)
        cos_theta = (kp.xy * inverse_sqrt_xx_yy).clamp(-1+eps, 1-eps)

        sin_theta = t.sqrt((xx_yy - kp.xy**2).clamp(min=eps))

        theta = t.acos(cos_theta)
        xy = (sin_theta + (math.pi - theta)*kp.xy) / (2*math.pi)

        xx = kp.xx/2.
        if kp.same:
            yy = xx
            if kp.diag:
                xy = xx
            else:
                # Make sure the diagonal agrees with `xx`
                eye = t.eye(xy.size()[0]).unsqueeze(-1).unsqueeze(-1).to(kp.xy.device)
                xy = (1-eye)*xy + eye*xx
        else:
            yy = kp.yy/2.
        return NonlinKP(kp.same, kp.diag, xy, xx, yy)

This can all be wrapped in a custom Sequential class as below:

class Sequential(NNGPKernel):
    def __init__(self, *mods):
        super().__init__()
        self.mods = mods
        for idx, mod in enumerate(mods):
            self.add_module(str(idx), mod)
    def propagate(self, kp):
        for mod in self.mods:
            kp = mod.propagate(kp)
        return kp

Note: NNGPKernel is a subclass of nn.Module that just adds a forward method which calls the propagate methods and manages the tensor sizes and shapes.
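
Conceptually, its forward builds the initial per-pixel covariance "images" for every pair of inputs and then pushes them through propagate; something along the lines of the sketch below (this is only my rough understanding, not the actual implementation; normalisation and the same/diag handling are simplified):

class NNGPKernel(nn.Module):
    # Rough sketch only: build the initial kernel patches, then call propagate.
    def forward(self, x, y):
        N, M, (H, W) = x.shape[0], y.shape[0], x.shape[2:]
        # per-pixel products for every pair of images -> (N*M, 1, H, W)
        xy = (x[:, None] * y[None, :]).sum(2).reshape(N * M, 1, H, W)
        xx = (x * x).sum(1, keepdim=True)[:, None].expand(N, M, 1, H, W).reshape(N * M, 1, H, W)
        yy = (y * y).sum(1, keepdim=True)[None, :].expand(N, M, 1, H, W).reshape(N * M, 1, H, W)
        kp = self.propagate(NonlinKP(x is y, False, xy, xx, yy))
        # the final 28x28 convolution reduces each "image" to 1x1
        return kp.xy.reshape(N, M)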

Now I can instantiate my model, for example:

layers = []
for _ in range(7):  # n_layers
    layers += [
        Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2,
               var_bias=var_bias),
        ReLU(),
    ]

model = Sequential(
    *layers,
    Conv2d(kernel_size=28, padding=0, var_weight=var_weight,
           var_bias=var_bias),
)

and call model(X_train, X_train), which gives me a tensor of size (X_train.shape[0], X_train.shape[0]), where X_train is a tensor of size (#samples, 1, ImageHeight, ImageWidth).
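
The training step that triggers the backward pass looks roughly like this (the actual kernel-regression loss is omitted; kernel_regression_loss and y_train are just placeholder names):

K = model(X_train, X_train)                # (N, N) covariance matrix, tracked by autograd
loss = kernel_regression_loss(K, y_train)  # placeholder for my actual loss
loss.backward()                            # gradients w.r.t. every var_weight / var_bias
optimizer.step()
optimizer.zero_grad()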

The issue:

When I run this model for large sample sizes during training, I run into serious memory issues. These are my memory measurements; note that I am currently working on a CPU only, with no GPUs.

Using the above architecture on the MNIST dataset, I consume about 2 GB of RAM for a sample size of 100 images, 3.5 GB for 125 images and 5 GB for 150 images, and this grows linearly. However, when I compute the kernel inside a torch.no_grad() context, all my memory issues go away and I can compute kernels of very large sizes. This tells me that the memory is being consumed by the autograd engine maintaining references to the intermediate tensors it needs for the backward pass during training.
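
Concretely, the comparison is simply:

# training path: several GB of RAM already at ~150 samples
K = model(X_train, X_train)

# inference path: memory stays low even for much larger sample sizes
with t.no_grad():
    K = model(X_train, X_train)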

My question is:

Given the computation in each Conv2d and ReLU propagate method, is this memory usage expected for the given network size and sample size? Is there some way I can reduce the memory consumption? As it stands, I simply run out of RAM on a PC with 16 GB if I try to train with a sample size greater than about 175 images. Or is the excessive memory usage coming from the elementwise operations in the ReLU layer, which store a lot of intermediate tensors and cannot be avoided?