You’re right, `model` is an instance of another class, and `Transnet` is just one component that goes into it. Here’s the full code defining the classes, setting up the training function, and running it:

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random as rand
# Define the model classes
class Disentangler(nn.Module):
    def __init__(self, encoder, decoder, transnet):
        super(Disentangler, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # estimates translation parameters, contains exponential weights, creates matrices
        self.transnet = transnet

    def forward(self, x, x0=None):
        if x0 is None:
            # first frame: encode directly, no translation parameters yet
            y = self.encoder(x)
            s = torch.zeros(x.size(0), self.encoder.latent_dim)
        else:
            # later frames: predict the shift relative to the previous latent code
            y, s = self.transnet(x, x0)
        z = self.decoder(y)
        return z, y, s
class Encoder(nn.Module):
    def __init__(self, og_dim, latent_dim):  # for n x n images, og_dim = n^2
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        super(Encoder, self).__init__()
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(og_dim, max(latent_dim, og_dim // 16))
        self.fc2 = nn.Linear(max(latent_dim, og_dim // 16), latent_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x
class Decoder(nn.Module):
    def __init__(self, og_dim, latent_dim):
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        super(Decoder, self).__init__()
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(latent_dim, max(latent_dim, og_dim // 16))
        self.fc2 = nn.Linear(max(latent_dim, og_dim // 16), og_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x
class Transnet(nn.Module):
    def __init__(self, og_dim, latent_dim, trans_dim, k_sparse):
        super(Transnet, self).__init__()
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        assert trans_dim <= latent_dim, 'translation dimension must be a subspace'
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.trans_dim = trans_dim
        ttl_dim = og_dim + latent_dim
        self.ttl_dim = ttl_dim
        self.k_sparse = k_sparse
        self.fc1 = nn.Linear(ttl_dim, max(latent_dim, ttl_dim // 16))
        self.fc2 = nn.Linear(max(latent_dim, ttl_dim // 16), max(latent_dim, ttl_dim // 32))
        self.fc3 = nn.Linear(max(latent_dim, ttl_dim // 32), trans_dim)

    def forward(self, x, x0):
        x1 = torch.cat((x, x0), dim=1)  # create (B, og_dim + latent_dim) tensor
        x1 = self.fc1(x1)
        x1 = F.relu(x1)
        x1 = self.fc2(x1)
        x1 = F.relu(x1)
        x1 = self.fc3(x1)
        # shift the first trans_dim latent coordinates in place
        x0[:, :self.trans_dim] += x1
        return x0, x1
def make_model(og_dim, latent_dim, trans_dim, k_sparse=1):
    enc = Encoder(og_dim, latent_dim)
    dec = Decoder(og_dim, latent_dim)
    trans = Transnet(og_dim, latent_dim, trans_dim, k_sparse)
    model = Disentangler(enc, dec, trans)
    return model
# Training procedure
def train(print_interval, model, device, train_loader, optimizer, epochs, movie_len, transform_set, beta=.7):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):  # labels unused
            loss = 0
            optimizer.zero_grad()
            for i in range(movie_len):
                if i == 0:
                    transform = rand.choice(transform_set)  # one transform per "movie"
                    prev_frame = curr_frame = data
                    curr_frame = curr_frame.flatten(1).to(device)
                    output, latent_rep, trans_par = model(curr_frame)
                    latent_rep = latent_rep.detach().clone()
                    latent_rep = latent_rep.to(device)
                else:
                    curr_frame = transform(prev_frame)
                    prev_frame = curr_frame
                    curr_frame = curr_frame.flatten(1).to(device)
                    output, out_rep, trans_par = model(curr_frame, latent_rep)
                # discounted reconstruction loss plus an L1 sparsity penalty on the translation parameters
                loss += (beta ** i) * (F.mse_loss(output, curr_frame) + 5e-2 * (1 / 50) * torch.norm(trans_par, 1))
            loss.backward()
            optimizer.step()
            if batch_idx % print_interval == 0:
                print(f'epoch {epoch}, batch {batch_idx}, loss {loss.item():.4f}')
# List of simple transforms to be applied frame-to-frame
hor_trans = transforms.Compose(
    [transforms.RandomAffine(0, translate=(.1, 0)),
     transforms.Normalize(.3, .3)])
ver_trans = transforms.Compose(
    [transforms.RandomAffine(0, translate=(0, .1)),
     transforms.Normalize(.3, .3)])
transform_set = [hor_trans, ver_trans]

# Create the data loader
loader_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(.3, .3)])
batch_size = 50
data_set = datasets.MNIST(root='./data', train=True, download=True, transform=loader_transform)
data_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True)
# Generate the model and train
model_dis = make_model(28 ** 2, latent_dim=16, trans_dim=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fall back to CPU if no GPU
model_dis = model_dis.to(device)
optimizer = torch.optim.Adam(model_dis.parameters(), lr=0.001)
train(200, model_dis, device, data_loader, optimizer, epochs=10, movie_len=3, transform_set=transform_set)
```

This should reproduce the error. I’ve tried to simplify it without changing how it works fundamentally, so sorry that it’s a lot of code.
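
If it helps to poke at this without the MNIST loop, here’s a minimal sketch that runs the same forward/backward path once on random tensors. The batch size and random inputs are made up for illustration; everything else comes from the code above:

```
# Minimal sketch: exercise the first-frame and Transnet paths on random data.
# The batch size (4) and random inputs are arbitrary stand-ins for MNIST frames.
model = make_model(28 ** 2, latent_dim=16, trans_dim=2)
frame0 = torch.randn(4, 28 ** 2)            # fake batch of flattened 28x28 images
frame1 = torch.randn(4, 28 ** 2)            # stand-in for the transformed next frame
out0, latent, _ = model(frame0)             # first frame: x0 is None
latent = latent.detach().clone()            # same detach step as in train()
out1, _, trans_par = model(frame1, latent)  # later frame: goes through Transnet
loss = F.mse_loss(out1, frame1) + torch.norm(trans_par, 1)
loss.backward()
```

I’m not certain this shorter path triggers the exact same error as the full training loop, but it isolates the `Transnet` call if that’s where the problem is.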