I want to train an autoencoder, where the encoder and decoder get created with some complex rules.
I was wondering why I cannot calculate the encoder matrix from stacked parameters in init, but instead need to do it in forward() (see commented lines that don’t work). Are the combined parameters somehow registered, or did I misunderstand something about how the computational graph is created?
Replacing the
encoder_weights = torch.stack([param for param in self.encoder_params], dim=-1)
decoder_weights = torch.stack([param for param in self.decoder_params], dim=-1).T
with
encoder_weights = object.__getattribute__(self, "encoder_weights")
decoder_weights = object.__getattribute__(self, "decoder_weights")
and I don’t understand why it breaks something.
Please tell me if I should create a reduced sample which reproduces my problem.
import torch
import torch.nn as nn
from scipy.spatial.transform import Rotation as R
from torch.utils.data import IterableDataset
def create_mapping(output_joint_names,
num_latent_points):
output_dim = len(output_joint_names)
all_parameters = [None] * output_dim
for i, output_joint_name in enumerate(output_joint_names):
param_free_latent_i = torch.nn.Parameter(torch.rand(size=(num_latent_points,)) / 10 + 0.5)
all_parameters[i] = param_free_latent_i
all_parameters = torch.nn.ParameterList(all_parameters)
return all_parameters
class Autoencoder(torch.nn.Module):
def __init__(self,
output_joint_names,
num_latent_points
):
super().__init__()
encoder_params = create_mapping(output_joint_names,
num_latent_points)
decoder_params = create_mapping(output_joint_names,
num_latent_points)
self.encoder_params: torch.nn.ParameterList = encoder_params
self.decoder_params: torch.nn.ParameterList = decoder_params
object.__setattr__(self, "encoder_weights", torch.stack([param for param in self.encoder_params], dim=-1))
object.__setattr__(self, "decoder_weights", torch.stack([param for param in self.decoder_params], dim=-1).T)
def forward(self, points):
# d: euclidian dim
# b: batch
# i: input
# o: output
# l: latent
# encoder_weights = object.__getattribute__(self, "encoder_weights")
# decoder_weights = object.__getattribute__(self, "decoder_weights")
encoder_weights = torch.stack([param for param in self.encoder_params], dim=-1)
decoder_weights = torch.stack([param for param in self.decoder_params], dim=-1).T
encoder = encoder_weights / encoder_weights.sum(1, keepdim=True)
#
decoder = decoder_weights / decoder_weights.sum(1, keepdim=True)
latent_points = torch.einsum("b i d, l i -> b l d", points, encoder,
)
output_points = torch.einsum("b l d, o l -> b o d", latent_points, decoder,
)
return output_points, latent_points
class EndlessDataset(IterableDataset):
def __iter__(self):
while True:
vertices = torch.tensor([
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], # Bottom vertices (z = 0)
[0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1] # Top vertices (z = 1)
], dtype=torch.float32)
offset_vertices = vertices # + 100 # torch.rand(size=(1, 3), dtype=torch.float32) * 10
rotation_matrix = torch.tensor(R.random().as_matrix(), dtype=torch.float32)
rotated_vertices = torch.einsum('ij,kj->ki', rotation_matrix, offset_vertices)
yield rotated_vertices # Dummy value
if __name__ == '__main__':
output_joint_names = ["l_1", "l_2", "l_3", "l_4", "r_1", "r_2", "r_3", "r_4"]
dataloader = torch.utils.data.DataLoader(dataset=EndlessDataset(), batch_size=2)
autoencoder = Autoencoder(output_joint_names=output_joint_names, num_latent_points=5)
optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=0.001)
for step, inputs in enumerate(dataloader):
outputs, latents = autoencoder(inputs)
loss = nn.MSELoss()
optimizer.zero_grad()
losses = loss(inputs, outputs)
losses.backward()
if step % 1000 == 0:
print(f"step: {step}, losses: {losses}")
optimizer.step()
if step >= 5000:
break