Hi everyone,
I’m trying to implement a matrix with shared values, where there’s an underlying tensor of trainable_params
that get embedded into the matrix at shared locations:
# I remember finding this implementation in one of
# these threads, but I lost the reference.
# This implementation is not entirely my original work.
from typing import Tuple
import torch
import torch.nn as nn
class WeightSharedMatrix(nn.Module):
def __init__(
self,
positions: torch.Tensor,
values: torch.Tensor,
size: Tuple[int, int],
):
super().__init__()
self.positions = positions
self.values = values
self.size = size
self.filler_constant = 0.0
self.trainable_params = nn.Parameter(values)
self.matrix = self._construct_matrix()
def _construct_matrix(self):
"""
Assuming that positions is a dictionary
{
(i, j): index_of_trainable_params
}
we construct a matrix with shared weights.
"""
matrix = self.filler_constant * torch.ones(self.size)
for (i, j), index in self.positions.items():
matrix[i, j] = self.trainable_params[index]
return matrix
def forward(self, x: torch.Tensor) -> torch.Tensor:
# A custom implementation of the multiplication between
# x and the sparse matrix defined by positions and values
return torch.matmul(self.matrix, x)
When I compute a forward and backward pass using it, it works well:
"""
A simple example: the weight-shared matrix is a 3x3 matrix
with only 2 shared parameters.
The matrix is defined as
[
[a, 0, 0],
[0, b, 0],
[0, a, b],
]
"""
# We start by specifying the actual parameters we want to optimize.
a = 1.0
b = 2.0
values = torch.Tensor([a, b])
# These values will be shared at the following positions.
positions = {
(0, 0): 0,
(1, 1): 1,
(2, 2): 1,
(2, 1): 0,
}
# We can create a weight-shared matrix.
weight_shared_matrix = WeightSharedMatrix(positions, values, (3, 3))
# Let's test whether the gradients w.r.t. a and b are being computed correctly.
# We start with an input x = [1; 2; 3]
x = torch.Tensor([1, 2, 3]).reshape(3, 1)
# Let's check if the forward pass is correct,
# and the gradients are computed correctly.
# (so far, the forward pass is just W @ x,
# but we could modify it later if we want to)
y = weight_shared_matrix(x)
# The forward pass should be
# [a, 2b+3a, 3b].
assert torch.allclose(y.flatten(), torch.Tensor([a, 2 * b, 2 * a + 3 * b]))
# The gradients for e.g. sum(y) should be
# [3, 5]
y.sum().backward()
print(weight_shared_matrix.trainable_params.grad)
assert torch.allclose(
weight_shared_matrix.trainable_params.grad, torch.Tensor([3, 5])
)
But, unfortunately, I’m running into a weird error when computing the loss a second time in a training setting. If I run
weight_shared_matrix = WeightSharedMatrix(positions, values, (3, 3))
x = torch.Tensor([1, 2, 3]).reshape(3, 1)
optimizer = torch.optim.AdamW(weight_shared_matrix.parameters(), lr=0.1)
for i in range(10):
optimizer.zero_grad()
y = weight_shared_matrix(x)
loss = torch.sum(y)
loss.backward()
print(weight_shared_matrix.trainable_params)
optimizer.step()
I get an error saying that we’re Trying to backward through the graph a second time
in the second loop. How come? How come the optimizer.zero_grad()
is not cleaning the grads of self.trainable_params
?