I have discovered that I am severing the computational graph when constructing a parameterized tensor:

```
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.A_params = nn.Parameter(torch.tensor([0.3, 0.1]), requires_grad=True)
        self.update_ss()

    def update_ss(self):
        k_a, k_b = torch.sigmoid(self.A_params)  # constrain to (0, 1)
        self.A = torch.tensor([[-k_a, 0], [k_a, -k_b]])  # does not preserve the gradient??
```

How do I construct the tensor `A` so that the gradient still flows from `A_params` to `A`?
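
One approach that should preserve the graph is to assemble `A` with `torch.stack` from the tensors that already carry gradients, rather than `torch.tensor`, which copies the values into a fresh leaf tensor and detaches them from `A_params`. A minimal sketch of that idea:

```
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.A_params = nn.Parameter(torch.tensor([0.3, 0.1]))
        self.update_ss()

    def update_ss(self):
        k_a, k_b = torch.sigmoid(self.A_params)  # constrain to (0, 1)
        zero = torch.zeros((), dtype=k_a.dtype)  # scalar 0 matching dtype
        # torch.stack builds A from the existing graph nodes k_a and k_b,
        # so autograd can trace A back to A_params.
        self.A = torch.stack([
            torch.stack([-k_a, zero]),
            torch.stack([k_a, -k_b]),
        ])

m = Model()
m.A.sum().backward()
print(m.A_params.grad is not None)  # True: gradient reaches A_params
```

Note that `update_ss()` must be re-run after each optimizer step, since `A` is a derived tensor, not a parameter, and goes stale once `A_params` changes.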