In meta-learning you want to differentiate through the (inner) gradient updates themselves, for example to get the (outer) gradient of the validation loss w.r.t. some hyperparameter.
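A minimal fp32 sketch of that pattern (variable names are mine, just for illustration):

```python
import torch

# Outer parameter `reg` enters the inner update; create_graph=True lets
# the outer gradient flow back through that update.
torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
reg = torch.tensor(0.1, requires_grad=True)

inner_loss = (w ** 2).sum()
g = torch.autograd.grad(inner_loss, w, create_graph=True)[0]  # keep graph
w_new = w - 0.1 * g - reg * w  # inner SGD step, differentiable w.r.t. reg

outer_loss = (w_new ** 2).sum()
outer_loss.backward()
print(reg.grad)  # non-None: the gradient flowed through the inner update
```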
My outer gradients were coming out NaN in mixed precision (regardless of the loss-scaler value), so I made a toy example. I can't reproduce the NaN outer gradient with it, but it exposes another issue: memory consumption is much larger in AMP mode:
import torch

# torch.backends.cudnn.benchmark = True
DEVICE = 'cuda'
N, D = 80000, 800  # may need multiple of 8
AMP = True

def get_data(N, D):
    x_train = torch.randn((N, D), device=DEVICE)
    x_val = torch.randn((N, D), device=DEVICE)
    true_weights = torch.randn(D, device=DEVICE)
    y_train = torch.matmul(x_train, true_weights) + torch.randn(N, device=DEVICE) * 0.05
    y_val = torch.matmul(x_val, true_weights) + torch.randn(N, device=DEVICE) * 0.05
    return x_train, y_train, x_val, y_val

torch.manual_seed(0)
torch.cuda.manual_seed(0)
# with torch.cuda.amp.autocast(enabled=AMP):
x_train, y_train, x_val, y_val = get_data(N, D)

weights = torch.randn(D, requires_grad=True, device=DEVICE)  # inner param to learn
regularizer = torch.full((D,), 5e-4, requires_grad=True, device=DEVICE)  # outer param to learn
SCALER = 2**8 if AMP else 1

## Inner loop: learn weights
for i in range(10):
    with torch.cuda.amp.autocast(enabled=AMP):
        y_pred_train = torch.matmul(x_train, weights)
        train_loss = torch.mean((y_train - y_pred_train) ** 2)
    # with torch.cuda.amp.autocast(enabled=False):
    inner_grads = torch.autograd.grad(SCALER * train_loss, weights, create_graph=True)[0]
    inner_grads = (1 / SCALER) * inner_grads
    weights = weights - 0.1 * inner_grads - regularizer * weights
    # print(f'train loss {train_loss:.3g} -- inner grads min {torch.min(inner_grads):.3f} max {torch.max(inner_grads):.3f}')
    print(f'memory allocated {float(torch.cuda.memory_allocated()) / (1024**3):.3g} GB')

## Outer step: learn regularizer
with torch.cuda.amp.autocast(enabled=AMP):
    y_pred_val = torch.matmul(x_val, weights)
    val_loss = SCALER * torch.mean((y_val - y_pred_val) ** 2)
val_loss.backward()
# print(f'---> outer grads {regularizer.grad * (1 / SCALER)}')
When I run this with AMP=False I get:
memory allocated 0.478 GB
memory allocated 0.479 GB
memory allocated 0.48 GB
memory allocated 0.48 GB
memory allocated 0.481 GB
memory allocated 0.481 GB
memory allocated 0.482 GB
memory allocated 0.483 GB
memory allocated 0.483 GB
memory allocated 0.484 GB
but when I run with AMP=True I get:
memory allocated 0.597 GB
memory allocated 0.717 GB
memory allocated 0.837 GB
memory allocated 0.957 GB
memory allocated 1.08 GB
memory allocated 1.2 GB
memory allocated 1.32 GB
memory allocated 1.44 GB
memory allocated 1.56 GB
memory allocated 1.68 GB
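A back-of-the-envelope check, based on my guess (not verified) that autocast re-casts x_train to float16 for every matmul, and that create_graph=True keeps each iteration's graph — and thus each fp16 copy — alive until the outer backward:

```python
# One float16 copy of x_train per inner step would account for the growth.
N, D = 80000, 800
fp16_copy_bytes = N * D * 2  # float16 = 2 bytes per element
growth_gb = fp16_copy_bytes / 1024**3
print(f'{growth_gb:.3g} GB per inner step')  # ~0.119 GB, close to the ~0.12 GB/step above
```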
What is happening here?
System:
Windows 10
PyTorch 1.7.1
CUDA 11.0
cuDNN 8.0
Python 3.8
RTX 3070 (laptop version)