Hi, I am trying to learn a scalar parameter that defines a matrix, so I have written a custom autograd function that lets me do the type cast, since the cast itself is non-differentiable. Autograd calls the backward function, so there is a gradient, but the parameter itself doesn’t update (I have already tried different learning rates).
class ExampleModel(nn.Module):
    """Module whose only learnable parameter is the scalar ``prob``.

    ``forward`` turns ``prob`` into a weight matrix through
    ``quantize_length``; the values of the input ``x`` are not used.
    """

    def __init__(self):
        super().__init__()
        # Scalar parameter, initialised to 0.5.
        self.prob = torch.nn.Parameter(
            torch.tensor(0.5, dtype=torch.float32), requires_grad=True
        )

    def forward(self, x):
        """Return the quantized weight matrix on the same device as ``x``.

        Args:
            x: Input tensor. Only its device is consulted, so the result
                can be multiplied with ``x`` without a device mismatch.

        Returns:
            The tensor produced by ``quantize_length(self.prob)``.
        """
        # Follow the input's device instead of hard-coding `.cuda()`, which
        # raises on machines without CUDA.
        output = quantize_length(self.prob).to(x.device)
        return output
def linear_length(prob, num_lengths=9):
    """Map a probability-like tensor to an integer length.

    With the default ``num_lengths=9`` the result lies in [0, 8] (length in
    frames, 9 possible options).

    Args:
        prob: Tensor expected to lie within [0, 1].
        num_lengths: Number of distinct lengths; results lie in
            ``[0, num_lengths - 1]``.

    Returns:
        A ``torch.uint8`` tensor with the same shape as ``prob``.
    """
    scaled = torch.floor(num_lengths * prob)
    # Clamp before the cast: prob == 1.0 would otherwise give `num_lengths`
    # (one past the last valid length), and a negative value cannot be
    # represented as uint8.
    return torch.clamp(scaled, min=0, max=num_lengths - 1).type(torch.uint8)
class QuantizeLengthFunction(torch.autograd.Function):
    """Build a 10x10 mask from a scalar, with a hand-written gradient.

    The forward pass quantizes ``length_prob`` to an integer length (a
    non-differentiable step) and returns a matrix whose first ``length``
    columns are one. The backward pass supplies a surrogate gradient.
    """

    @staticmethod
    def forward(ctx, length_prob):
        """Return a (10, 10) float32 matrix with the leading columns set to 1.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            length_prob: Single-element tensor passed to ``linear_length``.

        Returns:
            A (10, 10) float32 tensor on the same device as ``length_prob``.
        """
        ctx.save_for_backward(length_prob)
        length = int(linear_length(length_prob))  # quantize length
        # Allocate directly on the input's device; no CUDA round trip is
        # needed to fill a slice with ones.
        weight = torch.zeros(
            (10, 10), dtype=torch.float32, device=length_prob.device
        )
        weight[:, :length] = 1.0
        return weight

    @staticmethod
    def backward(ctx, grad_output):
        """Return the surrogate gradient with respect to ``length_prob``.

        The gradient is the sum of all entries of ``grad_output`` scaled
        by 9.0, shaped and typed like the saved input.
        """
        # `saved_tensors` is a tuple, so unpack its single entry.
        (length_prob,) = ctx.saved_tensors
        grad_input = grad_output.sum() * 9.0
        return grad_input.reshape(length_prob.shape).to(length_prob.dtype)


quantize_length = QuantizeLengthFunction.apply
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 15, 10, 10

# Create random Tensors to hold input and outputs.
# NOTE(review): `device` and `dtype` are not defined in this excerpt;
# presumably they are set earlier in the script.
x = torch.rand(N, D_in, device=device, dtype=dtype) - 0.5
y = torch.rand(N, H, device=device, dtype=dtype) - 0.5

# Create random Tensors for weights.
w1 = torch.rand(D_in, H, device=device, dtype=dtype, requires_grad=True)

model = ExampleModel()
for name, _ in model.named_parameters():
    print(name)

# `w1` lives outside the model, so it has to be handed to the optimizer
# explicitly; otherwise it is never updated and `optim.zero_grad()` never
# clears its accumulated gradient.
# NOTE(review): Adam moves each parameter by roughly `lr` per step, so with
# lr=1e-4 and 500 steps `prob` shifts by about 0.05 at most. floor(9 * prob)
# only changes once `prob` leaves [4/9, 5/9), so the quantized output may look
# frozen even though `prob` is updating - confirm by printing `model.prob`.
optim = torch.optim.Adam([*model.parameters(), w1], lr=1e-4, eps=1e-8)

for t in range(500):
    output = model(x)
    y_pred = (x.mm(w1)).mm(output)
    loss = (y_pred - y).pow(2).sum()

    # Use autograd to compute the backward pass.
    optim.zero_grad()
    loss.backward()
    optim.step()
If, instead of calling `optim.step()`, I apply the update manually as follows (the gradients still come from `loss.backward()`), then the update works:
# Plain gradient-descent step applied by hand. It runs under no_grad so the
# in-place updates are not recorded by autograd.
with torch.no_grad():
    for tensor in (w1, prob):
        # Update weights using gradient descent
        tensor -= learning_rate * tensor.grad
        # Manually zero the gradients after updating weights
        tensor.grad.zero_()
But I’m trying to put this in a much larger network so I’d like to use Autograd to update this scalar along with the rest of my network. What am I doing wrong here?