Don't register parameter

I want to train an autoencoder, where the encoder and decoder get created with some complex rules.
I was wondering why I cannot calculate the encoder matrix from the stacked parameters in __init__, but instead need to do it in forward() (see the commented lines that don’t work). Is the combined tensor somehow not registered properly, or did I misunderstand something about how the computational graph is created?

Replacing the lines

encoder_weights = torch.stack([param for param in self.encoder_params], dim=-1)
decoder_weights = torch.stack([param for param in self.decoder_params], dim=-1).T

with

encoder_weights = object.__getattribute__(self, "encoder_weights")
decoder_weights = object.__getattribute__(self, "decoder_weights")

breaks training, and I don’t understand why.

Please tell me if I should create a reduced sample which reproduces my problem.

import torch
import torch.nn as nn
from scipy.spatial.transform import Rotation as R
from torch.utils.data import DataLoader, IterableDataset


def create_mapping(output_joint_names,
                   num_latent_points):
    output_dim = len(output_joint_names)

    all_parameters = [None] * output_dim

    for i, output_joint_name in enumerate(output_joint_names):
        param_free_latent_i = torch.nn.Parameter(torch.rand(size=(num_latent_points,)) / 10 + 0.5)
        all_parameters[i] = param_free_latent_i

    all_parameters = torch.nn.ParameterList(all_parameters)

    return all_parameters


class Autoencoder(torch.nn.Module):
    def __init__(self,
                 output_joint_names,
                 num_latent_points
                 ):
        super().__init__()

        encoder_params = create_mapping(output_joint_names,
                                        num_latent_points)

        decoder_params = create_mapping(output_joint_names,
                                        num_latent_points)

        self.encoder_params: torch.nn.ParameterList = encoder_params
        self.decoder_params: torch.nn.ParameterList = decoder_params

        # Precompute the stacked weights once here (the variant that does not train) and read them back in forward()
        object.__setattr__(self, "encoder_weights", torch.stack([param for param in self.encoder_params], dim=-1))
        object.__setattr__(self, "decoder_weights", torch.stack([param for param in self.decoder_params], dim=-1).T)

    def forward(self, points):
        # d: euclidean dim
        # b: batch
        # i: input
        # o: output
        # l: latent

        # Reading the attributes precomputed in __init__ does not train properly:
        # encoder_weights = object.__getattribute__(self, "encoder_weights")
        # decoder_weights = object.__getattribute__(self, "decoder_weights")

        # Rebuilding the stacked weights from the registered parameters on every forward pass works:
        encoder_weights = torch.stack([param for param in self.encoder_params], dim=-1)
        decoder_weights = torch.stack([param for param in self.decoder_params], dim=-1).T

        # Normalize each row so its mixing weights sum to one
        encoder = encoder_weights / encoder_weights.sum(1, keepdim=True)
        decoder = decoder_weights / decoder_weights.sum(1, keepdim=True)

        latent_points = torch.einsum("b i d, l i -> b l d", points, encoder)

        output_points = torch.einsum("b l d, o l -> b o d", latent_points, decoder)

        return output_points, latent_points


class EndlessDataset(IterableDataset):
    def __iter__(self):
        while True:
            vertices = torch.tensor([
                [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],  # Bottom vertices (z = 0)
                [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]  # Top vertices (z = 1)
            ], dtype=torch.float32)

            offset_vertices = vertices  # + 100  # torch.rand(size=(1, 3), dtype=torch.float32) * 10
            rotation_matrix = torch.tensor(R.random().as_matrix(), dtype=torch.float32)
            rotated_vertices = torch.einsum('ij,kj->ki', rotation_matrix, offset_vertices)
            yield rotated_vertices  # Dummy value


if __name__ == '__main__':

    output_joint_names = ["l_1", "l_2", "l_3", "l_4", "r_1", "r_2", "r_3", "r_4"]

    dataloader = DataLoader(dataset=EndlessDataset(), batch_size=2)
    autoencoder = Autoencoder(output_joint_names=output_joint_names, num_latent_points=5)
    optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=0.001)
    criterion = nn.MSELoss()  # create the loss module once instead of every iteration

    for step, inputs in enumerate(dataloader):
        outputs, latents = autoencoder(inputs)

        optimizer.zero_grad()
        loss = criterion(inputs, outputs)
        loss.backward()
        optimizer.step()

        if step % 1000 == 0:
            print(f"step: {step}, loss: {loss.item()}")

        if step >= 5000:
            break

torch.stack is a differentiable operation and will thus create a non-leaf tensor. You would need to recreate a leaf tensor, e.g. by wrapping the result into nn.Parameter again, to register it properly in the __init__.
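A minimal sketch of both points, using a hypothetical toy module (not the autoencoder above): the stacked tensor is a non-leaf, and assigning it as a plain attribute does not register it as a parameter.

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.rand(3)) for _ in range(4)])
        # torch.stack is differentiable, so its result is a non-leaf tensor
        self.stacked = torch.stack(list(self.params), dim=-1)

demo = Demo()
print(demo.stacked.is_leaf)                           # False
print([name for name, _ in demo.named_parameters()])
# ['params.0', 'params.1', 'params.2', 'params.3'] -- 'stacked' is not among them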

My issue is that I don’t want to create a leaf tensor in the stack operation; I want the error to be backpropagated to the parameters in the ParameterList self.encoder_params. But when I set the tensor created by the stack operation as an attribute of the module, it somehow doesn’t work properly anymore.

So if I have a tensor that is created as a concatenation of multiple parameters, do I have to rebuild it every iteration in order for the parameters to be learned correctly? A parameter can appear in the concatenated tensor multiple times (e.g. the same parameter column appears more than once), so it is not possible to simply make the concatenated tensor itself a parameter, as in my understanding that would not preserve the shared columns.
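A small sketch with hypothetical toy parameters showing why the sharing only survives when the stack is rebuilt from the parameters: a parameter that appears in several columns accumulates gradient contributions from every occurrence, which a single merged nn.Parameter would not do.

import torch

shared = torch.nn.Parameter(torch.ones(3))
other = torch.nn.Parameter(torch.ones(3))

# 'shared' occupies two columns of the stacked tensor
weights = torch.stack([shared, other, shared], dim=-1)
weights.sum().backward()

print(shared.grad)  # tensor([2., 2., 2.]) -- both occurrences contribute
print(other.grad)   # tensor([1., 1., 1.])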

Yes, you would need to rebuild the tensor in the forward to allow the gradients to properly backpropagate to the registered parameters.
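To make that concrete, a tiny sketch with a hypothetical toy parameter and optimizer (not the model above): the tensor stacked once in __init__ keeps the old values after an optimizer step, whereas a stack rebuilt in forward reflects the update.

import torch

p = torch.nn.Parameter(torch.zeros(2))
frozen = torch.stack([p, p], dim=-1)  # built once, as in __init__

opt = torch.optim.SGD([p], lr=1.0)
p.sum().backward()
opt.step()                            # p is now tensor([-1., -1.])

print(frozen)                         # still all zeros: a stale snapshot
print(torch.stack([p, p], dim=-1))    # rebuilt: reflects the updated parameter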
