Backward Hooks with AMP

Hi, I tried looking in the documentation but I couldn’t find what I’m looking for.
My question is: how are the intermediate gradients collected by backward hooks affected by AMP?

I understand that during training we should use GradScaler on the loss, as in GradScaler.scale(loss).backward(), but what about the gradients collected by a backward hook? Should those also be scaled?

EDIT:
I did some testing with a simple setup and found that the gradients I obtain when AMP is NOT used differ from the ones I get when AMP IS used, e.g.

tensor([[-0.6644,  0.6644]], device='cuda:0') # normal
tensor([[-43552.,  43552.]], device='cuda:0', dtype=torch.float16) # amp

but applying GradScaler.scale to the gradient did not recover the non-AMP values. Is there a way to use AMP and hooks together to gather intermediate gradients?

Here are the main parts of my testing setup.
Model:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5, bias=True)
        self.r1 = nn.ReLU()
        self.fc2 = nn.Linear(5, 2, bias=True)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.r1(x)
        x = self.fc2(x)
        return x

Hook:

def backward_hook(module, grad_input, grad_output):
    print(grad_output[0])
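
The hook registration isn’t shown in the post; a minimal sketch of how it could be attached (the choice of fc2 and the use of register_full_backward_hook are assumptions; the older register_backward_hook also works but is deprecated):

# assumed setup: attach the hook to fc2, whose grad_output matches the
# 2-element tensors printed above
model = Net().to(device)
model.fc2.register_full_backward_hook(backward_hook)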

Training:

from torch.utils.data import TensorDataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, shuffle=False, batch_size=1, num_workers=0)

for data, target in dataloader:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    if amp:
        with autocast():
            out = model(data)
            loss = loss_fn(out, target)
        # scale the loss before backward(), then step through the scaler
        scaled_loss = scaler.scale(loss)
        scaled_loss.backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        out = model(data)
        loss = loss_fn(out, target)
        loss.backward()
        optimizer.step()

If you want to work with unscaled gradients, you could follow the AMP - Working with unscaled gradients tutorial, which explains how to unscale the gradients before using them in e.g. gradient clipping.
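
For reference, the pattern from that tutorial looks roughly like this (a sketch; max_norm=1.0 is a placeholder value):

# unscale_() divides the .grad attributes of the optimizer's parameters
# in place; scaler.step() then knows not to unscale them a second time
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()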

Hi, thanks for the reply, but unless I misunderstood something this isn’t what I’m looking for: the link you provided suggests unscaling the optimizer so that we can unscale the gradients held by the optimizer’s assigned parameters, but those are not the gradients I’m interested in. I’ll try to explain it better.

In my project I need to use the intermediate gradients of the neurons (NOT the gradients of the loss w.r.t. the weights), and to fetch those I employ some backward hooks.
When I run the backward() pass WITHOUT AMP I obtain the gradient tensor tensor([[-0.6644, 0.6644]], device='cuda:0').
When I run the backward() pass WITH AMP, instead, I get tensor([[-43552., 43552.]], device='cuda:0', dtype=torch.float16). Now I’d like to unscale these float16 gradients back to the values I get without AMP, so that I can still use them in my application. Is this possible?

Yes, the next section of the tutorial shows how to unscale gradients manually using inv_scale = 1./scaler.get_scale(), which you could also use in your hooks.
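
Applied to the hook above, that could look like this (a sketch; it assumes the GradScaler instance scaler is visible inside the hook, e.g. as a global or via a closure):

def backward_hook(module, grad_input, grad_output):
    # undo the loss scaling applied by GradScaler before using the gradient
    inv_scale = 1. / scaler.get_scale()
    print(grad_output[0] * inv_scale)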


Oh wow, I really missed that. Thank you very much for helping me.


If performance is a concern, prefer 1./scaler._get_scale_async(), which will be faster because it retrieves the scale factor tensor directly. get_scale() calls .item() on the scale factor tensor, which incurs a GPU->CPU memcopy and a sync. I should probably give _get_scale_async() a documented, public exposure.
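
In the hook, that could look like the following sketch (_get_scale_async() is a private method, so this may change between releases):

def backward_hook(module, grad_input, grad_output):
    # _get_scale_async() returns the scale factor as a tensor that stays on
    # the GPU, so the division avoids the sync that .item() would trigger
    inv_scale = 1. / scaler._get_scale_async()
    unscaled_grad = grad_output[0] * inv_scale  # promoted to float32
    print(unscaled_grad)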
