Gradient backpropagation through sliced tensors

I’m trying to build a model that trains Conv2d layers on the center crop of a larger image while using the same layers, without calculating gradients, to produce a feature map from the full-size image. I can do this as follows:

x_crop = x[..., offset:-offset, offset:-offset]

x_crop = self.conv_layers(x_crop)

with torch.no_grad():
    x = self.conv_layers(x)

But I’d like to share batchnorm stats and/or reduce the redundant computation of the center patch.

To do this, I believe I should be able to slice a center crop from the output feature map of the full image and then .detach() the full-image feature map. This works through a single conv layer, but with more than one layer the memory use jumps to about the same as backpropagating through the full feature map.
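Part of the memory jump can be checked without a GPU: autograd keeps every intermediate activation that a later backward step needs, and those activations are full-size even when only a crop of the final output is used. The sketch below assumes PyTorch 1.10+ for `torch.autograd.graph.saved_tensors_hooks`; the layer stack, the 64x64 input and the `saved_bytes` helper are hypothetical stand-ins, not the model above. It sums the bytes autograd saves for backward when slicing after versus before the convs.

```python
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks

torch.manual_seed(0)

# Hypothetical stand-in for conv_layer1 + conv_layer2 (Dropout2d omitted)
convs = nn.Sequential(
    nn.Conv2d(1, 32, 3, 1, 1), nn.LeakyReLU(),
    nn.Conv2d(32, 32, 3, 1, 1), nn.LeakyReLU(),
)

def saved_bytes(fn):
    """Run fn() and sum the bytes of every tensor autograd saves for backward."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        out = fn()
    return out, total

x = torch.randn(2, 1, 64, 64)
offset = 16

# Slice AFTER the convs: every layer's full-size activations are kept for backward.
out_a, bytes_after = saved_bytes(lambda: convs(x)[..., offset:-offset, offset:-offset])
# Slice BEFORE the convs: only crop-sized activations are kept.
out_b, bytes_before = saved_bytes(lambda: convs(x[..., offset:-offset, offset:-offset]))

print(f"slice after convs:  {bytes_after / 1024**2:.2f} MiB saved for backward")
print(f"slice before convs: {bytes_before / 1024**2:.2f} MiB saved for backward")
```

The two outputs have the same shape but differ near the border, because slicing the input changes where zero padding is applied.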

import torch
from torch import nn

class DualModel(nn.Module):
    def __init__(self, in_size, out_size, double_conv, full_input):
        super().__init__()
        self.double_conv = double_conv
        self.full_input = full_input
        self.in_size = in_size
        self.out_size = out_size
        self.conv_layer1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1, bias=True),
            nn.LeakyReLU(inplace=False),
            nn.Dropout2d(p=0.1),
        )
        self.conv_layer2 = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1, bias=True),
            nn.LeakyReLU(inplace=False),
            nn.Dropout2d(p=0.1),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.final = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1, stride=1)
        
    def forward(self, x_c):
        offset = (self.in_size - self.out_size) // 2
        x_c = self.conv_layer1(x_c)
        if self.double_conv:
            x_c = self.conv_layer2(x_c)
        if not self.full_input:
            x_c = x_c[..., offset:-offset, offset:-offset]
        return self.final(x_c)


def to_MiB(x):
    return x/1024**2

def estimate_memory_training(model, sample_input, optimizer_type=torch.optim.Adam, use_amp=False, device=0):
    """Estimate the peak memory usage of training the model.

    Args:
        model (nn.Module): the neural network model
        sample_input (torch.Tensor): a sample input to the network, already
            batched (it is passed to the model as-is)
        optimizer_type (Type): the class of the optimizer to instantiate
        use_amp (bool): whether to estimate based on using mixed precision
        device (int or torch.device): the CUDA device to use
    """
    # Reset model and optimizer
    model.cpu()
    optimizer = optimizer_type(model.parameters(), lr=.001)
    a = torch.cuda.memory_allocated(device)
    model.to(device)
    b = torch.cuda.memory_allocated(device)
    model_memory = b - a
    model_input = sample_input
    output = model(model_input.to(device)).sum()
    c = torch.cuda.memory_allocated(device)
    if use_amp:
        amp_multiplier = .5
    else:
        amp_multiplier = 1
    forward_pass_memory = (c - b)*amp_multiplier
    gradient_memory = model_memory
    if isinstance(optimizer, torch.optim.Adam):
        o = 2
    elif isinstance(optimizer, torch.optim.RMSprop):
        o = 1
    elif isinstance(optimizer, torch.optim.SGD):
        o = 0
    elif isinstance(optimizer, torch.optim.Adagrad):
        o = 1
    else:
        raise ValueError("Unsupported optimizer. Look up how many moments are" +
            "stored by your optimizer and add a case to the optimizer checker.")
    gradient_moment_memory = o*gradient_memory
    total_memory = model_memory + forward_pass_memory + gradient_memory + gradient_moment_memory
    return total_memory
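The estimate above adds model, gradient and optimizer-moment memory, all of which scale with the parameter count, to a measured forward-pass term. The parameter-side part can be cross-checked analytically. This is a minimal sketch, assuming float32 parameters and the same moment counts used in `estimate_memory_training`; the `MOMENTS` table and `parameter_side_bytes` helper are hypothetical, and the `nn.Sequential` stand-in only mirrors DualModel's three parameterized conv layers (all of which exist in DualModel even when `double_conv=False`).

```python
import torch
from torch import nn

# Hypothetical stand-in with the same conv shapes as DualModel
# (conv_layer1, conv_layer2, final); activations, dropout and pooling have no parameters.
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, 1, 1),
    nn.Conv2d(32, 32, 3, 1, 1),
    nn.Conv2d(32, 1, 1, 1),
)

# Per-parameter optimizer state tensors, matching the `o` values in estimate_memory_training
# (plain SGD without momentum, Adam without amsgrad).
MOMENTS = {torch.optim.SGD: 0, torch.optim.RMSprop: 1, torch.optim.Adagrad: 1, torch.optim.Adam: 2}

def parameter_side_bytes(model, optimizer_type=torch.optim.Adam):
    """Return (weight bytes, weights + gradients + optimizer moments in bytes)."""
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    o = MOMENTS[optimizer_type]
    return param_bytes, param_bytes * (2 + o)

param_bytes, total = parameter_side_bytes(model)
print(param_bytes, total)
```

SGD with momentum, centered RMSprop or Adam with `amsgrad=True` keep extra buffers, and the CUDA caching allocator rounds allocations up, so GPU-measured numbers can come out somewhat higher than this count.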

def test_memory_training(model, in_size, optimizer_type=torch.optim.SGD, use_amp=False, device=0):
    sample_input = torch.randn(*in_size, dtype=torch.float32)
    # Note: the estimate runs on device + 1, so this needs at least two GPUs
    max_mem_est = estimate_memory_training(model, sample_input, optimizer_type=optimizer_type, use_amp=use_amp, device=device+1)
    print(f"Maximum Memory Estimate: {to_MiB(max_mem_est)} MiB")
    optimizer = optimizer_type(model.parameters(), lr=.001)
    print(f"Beginning mem: {to_MiB(torch.cuda.memory_allocated(device))}", "Note - this may be higher than 0, which is due to PyTorch caching. Don't worry too much about this number")
    model.to(device)
    print(f"After model to device: {to_MiB(torch.cuda.memory_allocated(device))}")
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        print("Iteration", i)
        print(f"0 - After optimizer zero_grad: {to_MiB(torch.cuda.memory_allocated(device))}")
        with torch.cuda.amp.autocast(enabled=use_amp):
            a = torch.cuda.memory_allocated(device)
            out = model(sample_input.to(device)).sum() # Taking the sum here just to get a scalar output
            b = torch.cuda.memory_allocated(device)
        print(f"1 - After forward pass: {to_MiB(torch.cuda.memory_allocated(device))}")
        print(f"2 - Memory consumed by forward pass {to_MiB(b-a)}")
        out.backward()
        print(f"3 - After backward pass: {to_MiB(torch.cuda.memory_allocated(device))}")
        optimizer.step()
        print(f"4 - After optimizer step: {to_MiB(torch.cuda.memory_allocated(device))}")
        print(f"5 - Running max memory: {to_MiB(torch.cuda.max_memory_allocated(device))}")
        del out
        del sample_input
        sample_input = torch.randn(*in_size, dtype=torch.float32)


if __name__ == '__main__':
    model = DualModel(1024, 512, double_conv=False, full_input=False)
    test_memory_training(model, in_size=(4,1,1024,1024), use_amp=False, device=0)

Is there something I’m missing here? Any way to make this work?
Thanks

I’m not sure where exactly you are detaching the output in your code snippet, but I think slicing the output should work, as autograd would only compute non-zero gradients for the used region, as seen here:

import torch
from torch import nn

conv_layer = nn.Conv2d(1, 1, 3, 1, 1)
x = torch.randn(1, 1, 7, 7, requires_grad=True)
offset = 3

out = conv_layer(x)
out_crop = out[:, :, offset:-offset, offset:-offset]

out_crop.mean().backward()
print(x.grad)
# tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
#           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
#           [ 0.0000,  0.0000,  0.1436,  0.1572,  0.2688,  0.0000,  0.0000],
#           [ 0.0000,  0.0000, -0.0043, -0.1903, -0.0190,  0.0000,  0.0000],
#           [ 0.0000,  0.0000, -0.1737,  0.2930,  0.0941,  0.0000,  0.0000],
#           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
#           [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]])

Sorry for the confusion: the .detach() is there so I can keep using the full feature map in a larger model, but it isn’t necessary in the code snippet.

I’ve figured out what was confusing me, thanks! It is only calculating gradients in the relevant areas, as you suggested, but I was expecting the memory use to match that of a sliced input tensor. Slicing the output layer does lower the memory use, but the gradient tensors are still full size (albeit mostly zeros)!
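The full-size gradient can be observed directly with a tensor hook. A small sketch with arbitrary layer sizes (8 channels, 32x32 input; not the model from the question): the gradient reaching the first feature map has that map's full shape, and only the crop plus a one-pixel halo (the reach of the second 3x3 conv) is non-zero.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv1 = nn.Conv2d(1, 8, 3, 1, 1)
conv2 = nn.Conv2d(8, 8, 3, 1, 1)

x = torch.randn(1, 1, 32, 32)
offset = 8

captured = []  # will hold the gradient w.r.t. the first feature map

h = conv1(x)
h.register_hook(captured.append)  # a hook that returns None leaves the gradient unchanged
out_crop = conv2(h)[..., offset:-offset, offset:-offset]  # 16x16 center crop
out_crop.sum().backward()

g = captured[0]
print(g.shape)                               # same shape as h, i.e. full size
print(g.count_nonzero().item(), g.numel())   # non-zeros only in the crop plus a 1-pixel halo
```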


I was hoping for memory numbers closer to those of the sliced input tensor (slicing before the conv).
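One workaround that the thread does not discuss, offered only as a sketch: a stack of stride-1, padding-1 3x3 convs has a bounded receptive field, so you can crop the input with a halo of one pixel per conv layer, run the differentiable pass on that slightly larger crop, and trim the halo afterwards. The center then matches the full-image result up to float rounding, while autograd only stores crop-sized activations. Assumptions: no pooling, dropout or BatchNorm in the stack (pooling would need the halo aligned to the stride, and dropout or BatchNorm make the two passes differ); the layer stack, sizes and `halo` value are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in: two 3x3, padding=1, stride=1 convs, so each output pixel
# depends on input pixels at most 2 away (1 per layer).
convs = nn.Sequential(
    nn.Conv2d(1, 32, 3, 1, 1), nn.LeakyReLU(),
    nn.Conv2d(32, 32, 3, 1, 1), nn.LeakyReLU(),
)
halo = 2  # one pixel per stacked 3x3 conv (assumes stride 1, dilation 1)

x = torch.randn(2, 1, 64, 64)
offset = 16

# Full-image pass without gradients, then crop its output (the reference).
with torch.no_grad():
    ref = convs(x)[..., offset:-offset, offset:-offset]

# Differentiable pass on the crop plus halo, then trim the halo.
m = offset - halo
out = convs(x[..., m:-m, m:-m])[..., halo:-halo, halo:-halo]

print(out.shape, torch.allclose(out, ref, atol=1e-4))
```

The center is still computed twice (once in the no_grad pass), so this trades a little extra compute for backward memory that scales with the crop rather than the full image.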

Another quick question, if you don’t mind (or shall I make a new topic?): if I have a layer containing BatchNorm2d that I run twice, once under torch.no_grad() and once without, can I use the batch stats from one of the calls for the other, i.e. freeze them temporarily?

You could call .eval() on the batchnorm layer, which will use the running stats to normalize the input activations and will not update them using the batch stats as would be the case in .train() mode.
Would this fit your use case, or do you still want to normalize the input activations using the input stats while not updating the running stats in one of these forward passes?
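A quick CPU check of the .eval()/.train() behavior described above (a sketch; the channel count and input shape are arbitrary): in eval mode the layer normalizes with its running stats and leaves them untouched, while in train mode it normalizes with the batch stats and updates the running stats.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 16, 16) * 3 + 1  # per-channel mean around 1, std around 3

before = bn.running_mean.clone()

bn.eval()   # normalize with running stats; running stats are not updated
y_eval = bn(x)
unchanged = torch.equal(bn.running_mean, before)

bn.train()  # normalize with batch stats; running stats are updated
y_train = bn(x)
changed = not torch.equal(bn.running_mean, before)

print(unchanged, changed)
```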

Since one of the passes is on a larger patch (wrapped in torch.no_grad()) and the other is on a smaller crop of that patch, I want to use the input stats from the larger patch for both passes (mimicking slicing the output of the larger-patch pass, but with a lower backprop memory footprint). I suspect this will require a custom implementation of the norm layer, though. Running the smaller-patch forward pass with the running stats might suffice, if that’s as simple as calling .eval().

Yes, you might need to use a custom batchnorm implementation if you want to provide the mean and std of the larger input to normalize the smaller input activations.

Note that the running stats are updated in each forward pass in .train() mode and might not be “converged” yet, so let me know if this works for you.
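A minimal sketch of such a custom normalization, assuming the large-patch activations are available (here a random stand-in tensor) and that treating the large-patch mean and variance as constants is acceptable. `torch.nn.functional.batch_norm` with `training=False` normalizes with whatever mean and variance you pass and still applies the layer's learnable `weight` and `bias`. The helper name `bn_with_external_stats` is hypothetical. Unlike a real train-mode BatchNorm pass, no gradient flows through the statistics, and the running stats are not updated by this call (a train-mode pass over the large patch would update them, even under torch.no_grad()).

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(32)
feat_full = torch.randn(4, 32, 64, 64)  # stand-in for the large-patch activations
offset = 16
feat_crop = feat_full[..., offset:-offset, offset:-offset].clone().requires_grad_()

def bn_with_external_stats(x, bn, mean, var):
    """Normalize x with caller-supplied per-channel mean/var and bn's affine parameters.

    training=False makes F.batch_norm use mean/var as given and leave them untouched.
    """
    return F.batch_norm(x, mean, var, bn.weight, bn.bias, training=False, eps=bn.eps)

# Stats of the large patch, computed without building a graph.
with torch.no_grad():
    mean = feat_full.mean(dim=(0, 2, 3))
    var = feat_full.var(dim=(0, 2, 3), unbiased=False)  # BatchNorm normalizes with the biased variance

y_crop = bn_with_external_stats(feat_crop, bn, mean, var)
y_crop.sum().backward()  # gradients reach bn.weight, bn.bias and the crop only

# Check: same values as a train-mode bn over the full patch, cropped afterwards.
ref = bn(feat_full)[..., offset:-offset, offset:-offset]
print(torch.allclose(y_crop, ref, atol=1e-4))
```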