You’re right, `model` is an instance of another class, and `Transnet` is just one component that goes into it. Here’s the full code defining the classes, setting up the training function, and running it:

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random as rand
# Define the model classes
class Disentangler(nn.Module):
    def __init__(self, encoder, decoder, transnet):
        super(Disentangler, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # estimates translation parameters, contains exponential weights, creates matrices
        self.transnet = transnet

    def forward(self, x, x0=None):
        if x0 is None:
            # first frame: encode directly, no translation parameters yet
            y = self.encoder(x)
            s = torch.zeros(x.size(0), self.encoder.latent_dim)
        else:
            # later frames: predict the shift relative to the previous latent code
            y, s = self.transnet(x, x0)
        z = self.decoder(y)
        return z, y, s
class Encoder(nn.Module):
    def __init__(self, og_dim, latent_dim):  # for n x n images, og_dim = n^2
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        super(Encoder, self).__init__()
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(og_dim, max(latent_dim, og_dim // 16))
        self.fc2 = nn.Linear(max(latent_dim, og_dim // 16), latent_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x
class Decoder(nn.Module):
    def __init__(self, og_dim, latent_dim):
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        super(Decoder, self).__init__()
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(latent_dim, max(latent_dim, og_dim // 16))
        self.fc2 = nn.Linear(max(latent_dim, og_dim // 16), og_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x
class Transnet(nn.Module):
    def __init__(self, og_dim, latent_dim, trans_dim, k_sparse):
        super(Transnet, self).__init__()
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        assert trans_dim <= latent_dim, 'translation dimension must be a subspace'
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.trans_dim = trans_dim
        ttl_dim = og_dim + latent_dim
        self.ttl_dim = ttl_dim
        self.k_sparse = k_sparse
        self.fc1 = nn.Linear(ttl_dim, max(latent_dim, ttl_dim // 16))
        self.fc2 = nn.Linear(max(latent_dim, ttl_dim // 16), max(latent_dim, ttl_dim // 32))
        self.fc3 = nn.Linear(max(latent_dim, ttl_dim // 32), trans_dim)

    def forward(self, x, x0):
        x1 = torch.cat((x, x0), dim=1)  # create (B, og_dim + latent_dim) tensor
        x1 = self.fc1(x1)
        x1 = F.relu(x1)
        x1 = self.fc2(x1)
        x1 = F.relu(x1)
        x1 = self.fc3(x1)
        # shift the first trans_dim latent coordinates in place
        x0[:, :self.trans_dim] += x1
        return x0, x1
def make_model(og_dim, latent_dim, trans_dim, k_sparse=1):
    enc = Encoder(og_dim, latent_dim)
    dec = Decoder(og_dim, latent_dim)
    trans = Transnet(og_dim, latent_dim, trans_dim, k_sparse)
    model = Disentangler(enc, dec, trans)
    return model
# Training procedure
def train(print_interval, model, device, train_loader, optimizer, epochs, movie_len, transform_set, beta=.7):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):  # labels unused
            loss = 0
            optimizer.zero_grad()
            for i in range(movie_len):
                if i == 0:
                    transform = rand.choice(transform_set)  # one transform per "movie"
                    prev_frame = curr_frame = data
                    curr_frame = curr_frame.flatten(1).to(device)
                    output, latent_rep, trans_par = model(curr_frame)
                    latent_rep = latent_rep.detach().clone()
                    latent_rep = latent_rep.to(device)
                else:
                    curr_frame = transform(prev_frame)
                    prev_frame = curr_frame
                    curr_frame = curr_frame.flatten(1).to(device)
                    output, out_rep, trans_par = model(curr_frame, latent_rep)
                # discounted reconstruction loss plus an L1 sparsity penalty on the translation parameters
                loss += (beta ** i) * (F.mse_loss(output, curr_frame) + 5e-2 * (1 / 50) * torch.norm(trans_par, 1))
            loss.backward()
            optimizer.step()
            if batch_idx % print_interval == 0:
                print(f'epoch {epoch}, batch {batch_idx}, loss {loss.item():.4f}')
# List of simple transforms to be applied frame-to-frame
hor_trans = transforms.Compose(
    [transforms.RandomAffine(0, translate=(.1, 0)),
     transforms.Normalize(.3, .3)])
ver_trans = transforms.Compose(
    [transforms.RandomAffine(0, translate=(0, .1)),
     transforms.Normalize(.3, .3)])
transform_set = [hor_trans, ver_trans]

# Create the data loader
loader_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(.3, .3)])
batch_size = 50
data_set = datasets.MNIST(root='./data', train=True, download=True, transform=loader_transform)
data_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True)
# Generate the model and train
model_dis = make_model(28 ** 2, latent_dim=16, trans_dim=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fall back to CPU if no GPU
model_dis = model_dis.to(device)
optimizer = torch.optim.Adam(model_dis.parameters(), lr=0.001)
train(200, model_dis, device, data_loader, optimizer, epochs=10, movie_len=3, transform_set=transform_set)
```

This should reproduce the error. I’ve tried to simplify it without changing how it works fundamentally, so sorry that it’s a lot of code.
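
If it helps to poke at this without the MNIST loop, here’s a minimal sketch that runs the same forward/backward path once on random tensors. The batch size and random inputs are made up for illustration; everything else comes from the code above:

```
# Minimal sketch: exercise the first-frame and Transnet paths on random data.
# The batch size (4) and random inputs are arbitrary stand-ins for MNIST frames.
model = make_model(28 ** 2, latent_dim=16, trans_dim=2)
frame0 = torch.randn(4, 28 ** 2)            # fake batch of flattened 28x28 images
frame1 = torch.randn(4, 28 ** 2)            # stand-in for the transformed next frame
out0, latent, _ = model(frame0)             # first frame: x0 is None
latent = latent.detach().clone()            # same detach step as in train()
out1, _, trans_par = model(frame1, latent)  # later frame: goes through Transnet
loss = F.mse_loss(out1, frame1) + torch.norm(trans_par, 1)
loss.backward()
```

I’m not certain this shorter path triggers the exact same error as the full training loop, but it isolates the `Transnet` call if that’s where the problem is.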