I am building a custom optimizer that samples learning rates from a Dirichlet distribution whose parameters (alpha) need to be updated on each backpropagation. I’ve already figured out how to get the gradient of the sampled learning rate with respect to these alpha parameters, effectively ∂η/∂α, where η is the learning rate.
However, I need to “connect,” for lack of a better word, this gradient with the gradient of the loss with respect to the learning rate, effectively ∂L/∂η, so that I can “chain” the two together, forming the expression:
∂L/∂η * ∂η/∂α = ∂L/∂α
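To make the chaining concrete, here is a standalone toy example of the computation I have in mind (not my actual code; the names and shapes are made up, and a softmax stands in for the Dirichlet sampling just to keep the graph connected):

    import torch

    # Toy stand-ins: eta (the learning rate) is produced from alpha, and the loss depends on eta.
    alpha = torch.tensor([2.0, 3.0], requires_grad=True)
    eta = alpha.softmax(dim=0) * 1e-3      # differentiable function of alpha, like my sampled LR
    loss = (eta ** 2).sum()                # placeholder for the model loss

    # ∂L/∂η, then chained back to alpha as a vector-Jacobian product: ∂L/∂α = ∂L/∂η · ∂η/∂α
    dL_deta = torch.autograd.grad(loss, eta, retain_graph=True)[0]
    dL_dalpha = torch.autograd.grad(eta, alpha, grad_outputs=dL_deta)[0]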
I can then use this gradient to update the alphas and thereby improve the sampling of the distribution. The problem is that I cannot figure out how to get ∂L/∂η. I’ve tried the following line:
grad_learning_rate = torch.autograd.grad(loss, self.learning_rate, grad_outputs=torch.tensor(1.0, device=loss.device), retain_graph=True)[0]
where loss is passed into the optimizer after each forward pass. But the following error message is returned:
One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
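If it helps, I can reproduce what I believe is the same error with a minimal case where the tensor I differentiate with respect to never takes part in the graph that produced the loss:

    import torch

    x = torch.randn(3, requires_grad=True)
    unused = torch.randn(3, requires_grad=True)  # never used to compute y
    y = (x ** 2).sum()

    # Raises: "One of the differentiated Tensors appears to not have been used in the graph..."
    torch.autograd.grad(y, unused)

So my guess is that loss never actually depends on self.learning_rate through the autograd graph, but I don’t see how to wire it in.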
I’ve attached the model:
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_size, output_size, device: torch.device = None):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, 10, dtype=torch.float64)
        self.relu = nn.ReLU()
        # Fit model on cpu or inputted processing unit (xpu).
        self.device = device if device is not None else torch.device('cpu')
        self.to(self.device)

    def forward(self, x):
        x = self.fc1(x)
        return x
And the Optimizer:
from torch.optim import Optimizer

class Dart(Optimizer):
    '''
    Optimizer must receive losses throughout training.
    '''
    def __init__(self, params, betas=(0.9, 0.999),
                 alpha_init=1.0, alpha_lr=0.0001, eps=1e-8, weight_decay=0):
        defaults = dict(betas=betas, eps=eps, weight_decay=weight_decay)
        super(Dart, self).__init__(params, defaults)
        self.alpha_scaler = alpha_init
        self.alpha_lr = alpha_lr
        self.learning_rate = None
        self.alpha_grads = None

    def sample_lr_candidates(self, mean=1e-3, std=1e-4, num_samples=(10, 1), min_lr=1e-6, max_lr=1e-1):
        # Sample candidate learning rates from a Gaussian distribution.
        lr_samples = torch.normal(mean=mean, std=std, size=num_samples)
        # Clip the values so they stay within the [min_lr, max_lr] range.
        lr_samples = torch.clamp(lr_samples, min=min_lr, max=max_lr)
        return lr_samples.to(torch.float64)

    def step(self, loss):
        for group in self.param_groups:  # only one group
            for p in group['params']:
                if p.grad is None:
                    continue
                dim = (10, 784) if p.shape == torch.Size([10, 784]) else (1, 10)
                state = self.state[p]  # optimizer class opens 'history' for this param
                input = torch.empty(dim, device='cpu', dtype=torch.float64)
                if len(state) == 0:  # Initialize state if not already done.
                    state['step'] = 0
                    state['lr_candidates'] = self.sample_lr_candidates(num_samples=p.shape)  # .to('xpu')
                    state['alphas'] = torch.ones_like(input, memory_format=torch.preserve_format) * self.alpha_scaler
                state['step'] += 1
                # Enable autograd for alpha updates.
                state['alphas'].requires_grad_(True)
                # Sample from the Dirichlet (WARNING: may not support autograd);
                # rsample() should be the differentiable, reparameterized sample.
                samples = torch.distributions.Dirichlet(state['alphas']).rsample()  # .to('xpu')
                total = state['alphas'].sum(-1, True).expand_as(state['alphas'])
                grad_samples = torch._dirichlet_grad(samples, state['alphas'], total)  # ∂(samples)/∂(alphas)
                # Compute the learning rate from the Dirichlet samples.
                # print(samples.shape, state['lr_candidates'].shape)
                self.learning_rate = samples * state['lr_candidates']  # .to('xpu')
                self.learning_rate.retain_grad()
                # Compute gradient of the loss w.r.t. the learning rate (∂L/∂η) --
                # this is the line that raises the error above.
                grad_learning_rate = torch.autograd.grad(loss, self.learning_rate,
                                                         grad_outputs=torch.tensor(1.0, device=loss.device),
                                                         retain_graph=True)[0]
                # Update alphas with gradient descent (still missing the ∂L/∂η factor).
                state['alphas'] = state['alphas'] - self.alpha_lr * grad_samples * state['lr_candidates']
                self.alpha_grads = state['alphas']
                # Apply the weight update.
                # print(self.learning_rate.shape, p.grad.shape)
                p.data.sub_(self.learning_rate.squeeze() * p.grad)  # .to('xpu')
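For completeness, this is roughly how I drive the optimizer in my training loop (simplified; the batch size, random data, and MSE loss here are placeholders, not my real setup):

    import torch
    import torch.nn as nn

    model = MLP(input_size=784, output_size=10)
    criterion = nn.MSELoss()
    optimizer = Dart(model.parameters())

    for _ in range(10):
        x = torch.randn(32, 784, dtype=torch.float64)
        target = torch.randn(32, 10, dtype=torch.float64)

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, target)
        loss.backward(retain_graph=True)  # keep the graph so step() can call autograd.grad on the loss
        optimizer.step(loss)              # loss is passed in so the optimizer can try to get ∂L/∂η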