Implementing a matrix with shared values and grads

Hi everyone,

I’m trying to implement a matrix with shared values: there’s an underlying tensor of trainable_params, and each of its entries gets embedded into the matrix at one or more (shared) positions:

# I remember finding this implementation in one of
# these threads, but I lost the reference.
# This implementation is not entirely my original work.
from typing import Dict, Tuple

import torch
import torch.nn as nn


class WeightSharedMatrix(nn.Module):
    def __init__(
        self,
        positions: Dict[Tuple[int, int], int],
        values: torch.Tensor,
        size: Tuple[int, int],
    ):
        super().__init__()
        self.positions = positions
        self.values = values
        self.size = size
        self.filler_constant = 0.0

        self.trainable_params = nn.Parameter(values)

        self.matrix = self._construct_matrix()

    def _construct_matrix(self):
        """
        Assuming that positions is a dictionary
        {
            (i, j): index_of_trainable_params
        }
        we construct a matrix with shared weights.
        """
        matrix = self.filler_constant * torch.ones(self.size)
        for (i, j), index in self.positions.items():
            matrix[i, j] = self.trainable_params[index]

        return matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # For now this is just W @ x; later it could be replaced by a
        # custom multiplication between x and the sparse matrix defined
        # by positions and values.
        return torch.matmul(self.matrix, x)

When I compute a forward and backward pass using it, it works well:

"""
A simple example: the weight-shared matrix is a 3x3 matrix
with only 2 shared parameters.

The matrix is defined as
[
    [a, 0, 0],
    [0, b, 0],
    [0, a, b],
]
"""
# We start by specifying the actual parameters we want to optimize.
a = 1.0
b = 2.0
values = torch.Tensor([a, b])

# These values will be shared at the following positions.
positions = {
    (0, 0): 0,
    (1, 1): 1,
    (2, 2): 1,
    (2, 1): 0,
}

# We can create a weight-shared matrix.
weight_shared_matrix = WeightSharedMatrix(positions, values, (3, 3))

# Let's test whether the gradients w.r.t. a and b are being computed correctly.
# We start with an input x = [1; 2; 3]
x = torch.Tensor([1, 2, 3]).reshape(3, 1)

# Let's check if the forward pass is correct,
# and the gradients are computed correctly.
# (so far, the forward pass is just W @ x,
# but we could modify it later if we want to)
y = weight_shared_matrix(x)

# The forward pass should be
# [a, 2b, 2a + 3b].
assert torch.allclose(y.flatten(), torch.Tensor([a, 2 * b, 2 * a + 3 * b]))

# Since sum(y) = 3a + 5b, the gradients of sum(y) w.r.t. [a, b] should be
# [3, 5].
y.sum().backward()
print(weight_shared_matrix.trainable_params.grad)
assert torch.allclose(
    weight_shared_matrix.trainable_params.grad, torch.Tensor([3, 5])
)

But, unfortunately, I’m running into a weird error on the second backward pass when I use it in a training loop. If I run

weight_shared_matrix = WeightSharedMatrix(positions, values, (3, 3))
x = torch.Tensor([1, 2, 3]).reshape(3, 1)

optimizer = torch.optim.AdamW(weight_shared_matrix.parameters(), lr=0.1)

for i in range(10):
    optimizer.zero_grad()
    y = weight_shared_matrix(x)
    loss = torch.sum(y)
    loss.backward()
    print(weight_shared_matrix.trainable_params)
    optimizer.step()

On the second iteration of the loop, I get an error saying that we’re “Trying to backward through the graph a second time”. How come? Why isn’t optimizer.zero_grad() clearing the grads of self.trainable_params?


I realized I had a bug in my implementation. I thought that by assigning trainable_params into the matrix in __init__, the matrix would keep references to the parameters and automatically reflect their updated values after each backward pass and optimizer step, but that’s not the case: those assignments build the autograd graph from trainable_params to self.matrix exactly once, and that graph is freed after the first loss.backward(). The second backward pass then has no graph left to traverse, which is exactly what the error is complaining about (it has nothing to do with optimizer.zero_grad()).
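Here’s a minimal sketch of the same failure mode, independent of the class (the tensor name derived is just for illustration): a tensor built from a parameter only once can be backpropagated through only once, while rebuilding it before every backward pass works fine.

import torch

param = torch.nn.Parameter(torch.tensor([1.0, 2.0]))

# Build the derived tensor ONCE, like the old __init__ did.
derived = torch.zeros(3)
derived[0] = param[0]
derived[1] = param[1]

derived.sum().backward()    # works
# derived.sum().backward()  # RuntimeError: Trying to backward through the graph a second time

# Rebuilding the derived tensor before every backward pass works:
for _ in range(3):
    derived = torch.zeros(3)
    derived[0] = param[0]
    derived[1] = param[1]
    param.grad = None
    derived.sum().backward()  # fine on every iteration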

A way to avoid the error I was getting is to reconstruct the matrix at each forward pass:

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn


class WeightSharedMatrix(nn.Module):
    def __init__(
        self,
        positions: Dict[Tuple[int, int], int],
        values: torch.Tensor,
        size: Tuple[int, int],
    ):
        super().__init__()
        self.positions = positions
        self.values = values
        self.size = size
        self.filler_constant = 0.0

        self.trainable_params = nn.Parameter(values)

    def _construct_matrix(self, device: Optional[torch.device] = None):
        """
        Assuming that positions is a dictionary
        {
            (i, j): index_of_trainable_params
        }
        we construct a matrix with shared weights.
        """
        matrix = self.filler_constant * torch.ones(self.size, device=device)
        for (i, j), index in self.positions.items():
            matrix[i, j] = self.trainable_params[index]

        return matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        matrix = self._construct_matrix(x.device)
        return torch.matmul(matrix, x)
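
With this version, a fresh graph from trainable_params to the matrix is built on every call to forward, so repeated backward passes work and the matrix always reflects the current parameter values. As a quick sanity check (a sketch reusing positions, values, and x from the question), the training loop now runs and the shared parameters actually get updated:

weight_shared_matrix = WeightSharedMatrix(positions, values, (3, 3))
x = torch.Tensor([1, 2, 3]).reshape(3, 1)

optimizer = torch.optim.AdamW(weight_shared_matrix.parameters(), lr=0.1)
initial_params = weight_shared_matrix.trainable_params.detach().clone()

for i in range(10):
    optimizer.zero_grad()
    y = weight_shared_matrix(x)
    loss = torch.sum(y)
    loss.backward()
    optimizer.step()

# No "backward through the graph a second time" error, and the
# shared parameters have moved away from their initial values.
assert not torch.allclose(weight_shared_matrix.trainable_params, initial_params)
print(weight_shared_matrix.trainable_params)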