Why is grad_sample rescaled by the batch size?

The PyTorch .grad attribute stores the sum of the gradients over a batch.
To my understanding, the Opacus .grad_sample attribute is supposed to store each individual gradient of a batch. However, it stores the individual gradients rescaled by the batch size. This is known and is stated in 2a of an official Opacus tutorial.

On the other hand, the tutorial makes it sound as if .grad_sample stores the actual individual gradients:

The above grad_sampler takes in the activations and backpropagated gradients, computes the per-sample-gradients with respect to the module parameters, and maps them to the corresponding parameters.
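If I sketch what that sentence describes for a plain nn.Linear layer (this is just my reading, not the actual Opacus grad sampler, and the function name is made up), I picture something like:

import torch

# Rough sketch of a per-sample gradient computation for nn.Linear, based on my
# reading of the tutorial sentence above; not the actual Opacus implementation.
def linear_grad_sample_sketch(activations, backprops):
    # activations: (batch, in_features), backprops: (batch, out_features)
    grad_sample_weight = torch.einsum("ni,nj->nij", backprops, activations)  # one (out, in) gradient per sample
    grad_sample_bias = backprops                                             # one bias gradient per sample
    return grad_sample_weight, grad_sample_bias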

I realize that this is a very basic question, but I am just confused about what .grad_sample actually stores versus what I expected it to store according to the Opacus description.

Is there a reason why grad_sample stores rescaled individual gradients?
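To make my mental model explicit, here is a minimal sketch without Opacus (layer sizes and batch size are arbitrary, just for illustration):

import torch
from torch import nn

# Arbitrary single linear layer and a small batch.
layer = nn.Linear(4, 3)
X = torch.randn(5, 4)

# Gradient of a summed loss over the whole batch: .grad holds the batch sum.
layer.zero_grad()
layer(X).sum().backward()
batch_grad = layer.weight.grad.detach().clone()

# Individual gradients, computed one sample at a time.
per_sample = []
for i in range(X.shape[0]):
    layer.zero_grad()
    layer(X[i:i + 1]).sum().backward()
    per_sample.append(layer.weight.grad.detach().clone())

# The batch .grad is the sum of the individual gradients; these individual
# gradients are what I expected .grad_sample to store, without any rescaling.
assert torch.allclose(batch_grad, torch.stack(per_sample).sum(0), atol=1e-6)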

Hey homunkulus, thanks for your interest!

You’re right to state that the grad_sample attribute represents the per-sample gradient. It is not rescaled.

However, after clipping and noising, these per-sample gradients (successively stored in summed_grad and then in the grad attribute of the parameter) are aggregated (often averaged, see here). This may be the reason for your confusion.
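Schematically, the step from grad_sample to grad looks roughly like the sketch below. This is a simplification for a single parameter tensor with made-up defaults, not the actual Opacus code; in particular, Opacus computes the clipping norm over all parameters of a sample jointly.

import torch

def aggregate(grad_sample, max_grad_norm=1.0, noise_multiplier=1.0, loss_reduction="mean"):
    # grad_sample: per-sample gradients of one parameter, shape (batch_size, *param_shape)
    batch_size = grad_sample.shape[0]

    # 1. Clip each per-sample gradient to norm <= max_grad_norm.
    per_sample_norms = grad_sample.flatten(1).norm(dim=1)
    clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
    clipped = grad_sample * clip_factor.view(-1, *([1] * (grad_sample.dim() - 1)))

    # 2. Sum the clipped per-sample gradients (this corresponds to summed_grad).
    summed_grad = clipped.sum(dim=0)

    # 3. Add Gaussian noise and aggregate into what ends up in .grad
    #    (averaged over the batch when the loss reduction is a mean).
    noisy = summed_grad + noise_multiplier * max_grad_norm * torch.randn_like(summed_grad)
    return noisy / batch_size if loss_reduction == "mean" else noisy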

Do not hesitate to reach out for further help.

Thanks,
Pierre

Hi Pierre,

thanks a lot for your answer.
In your answer you state that the gradients are not rescaled. However, I obtain rescaled gradients. I have attached some code in which I divide every element of per_sample_grad by the batch size and compare the result to the gradients obtained from the grad attribute when processing each sample individually without Opacus. They coincide.

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from copy import deepcopy
from opacus import GradSampleModule

import warnings
warnings.simplefilter("ignore")

class SampleNet(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 8)
        self.fc2 = nn.Linear(8, 2)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

DATASET = TensorDataset(torch.randn(10, 16), torch.randint(0,2, (10,)))

model = SampleNet()
model_opacus = GradSampleModule(deepcopy(model))

# calculate gradients with opacus
dl_opacus = DataLoader(DATASET, 7)
per_sample_grads = []
grads_sum        = []
grads_opacus     = []
for X, Y in dl_opacus:
    batch_size = len(Y)
    model_opacus.train()
    model_opacus.zero_grad()
    loss = model_opacus(X).sum()
    loss.backward()
    per_sample_grad = [p.grad_sample.detach().clone() for p in model_opacus.parameters()]
    per_sample_grads.append(per_sample_grad)
    grad_sum = [p.grad.detach().clone() for p in model_opacus.parameters()]
    grads_sum.append(grad_sum)
    assert torch.allclose(per_sample_grad[0].mean(0), grad_sum[0]), "Mean of grad_sample for fc1.weight differs from .grad"
    assert torch.allclose(per_sample_grad[1].mean(0), grad_sum[1]), "Mean of grad_sample for fc1.bias differs from .grad"
    
    # recover the individual gradients by dividing the (rescaled) grad_sample by batch_size
    for i in range(batch_size):
        v_i =[]
        for p in model_opacus.parameters():
            v_i.append(p.grad_sample[i] / batch_size) # Here I rescale all gradients from grad_sample
        grads_opacus.append(v_i)

# calculate gradients without opacus
dl = DataLoader(DATASET, batch_size = 1)
grads = []
for x, y in dl:
    loss = model(x).sum()
    model.zero_grad()
    loss.backward()
    grad = [p.grad.detach().clone() for p in model.parameters()]
    grads.append(grad)
    
# test if (rescaled) gradients from opacus and gradients (without opacus) are equal
for i in range(len(grads)):
    grads_layer, grads_opacus_layer = grads[i], grads_opacus[i]
    # iterate through the components of the gradient, layer by layer
    for j in range(len(grads_layer)):
        grad        = grads_layer[j]
        grad_opacus = grads_opacus_layer[j]

        assert torch.allclose(grad_opacus, grad, atol=1e-6), "Per-sample grads and grads are different"

Hi @homunkulus,
Note that GradSampleModule takes a loss_reduction argument, which is "mean" by default. In your case you take a sum over the loss, but the GradSampleModule assumes you are taking a mean and therefore multiplies the backprops back by the batch size.

If you pass loss_reduction="sum" to GradSampleModule(deepcopy(model)), you should obtain the correct per_sample_grad without the need to rescale.
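For example, with the model from your snippet (keeping your summed loss):

from copy import deepcopy
from opacus import GradSampleModule

# Tell the GradSampleModule that the loss is a sum over the batch,
# so the captured backprops are not multiplied back by the batch size.
model_opacus = GradSampleModule(deepcopy(model), loss_reduction="sum")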

Hi Alexandre,
thanks a lot for your answer! This resolves my confusion.
Thanks to your answer I looked at the right spot in the source code of the grad_sample_module and found

n = module.max_batch_len
if loss_reduction == "mean":
    backprops = backprops * n
elif loss_reduction == "sum":
    backprops = backprops

which shows that, with a summed loss, I have to choose loss_reduction = "sum" in the GradSampleModule to obtain the unscaled gradients in grad_sample.
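For completeness, the other consistent combination should also give unscaled per-sample gradients: keep the default loss_reduction = "mean" and make the loss an actual mean over the batch of per-sample losses. A quick sketch, reusing model, dl_opacus and the imports from my snippet above:

# Keep the default loss_reduction="mean" and match it on the loss side.
model_opacus = GradSampleModule(deepcopy(model))        # default loss_reduction="mean"

for X, Y in dl_opacus:
    model_opacus.zero_grad()
    per_sample_loss = model_opacus(X).sum(dim=1)        # one scalar loss per sample
    loss = per_sample_loss.mean()                       # mean over the batch
    loss.backward()
    # p.grad_sample[i] should now hold the unscaled gradient of per_sample_loss[i]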