I have ran into this multiple times now.
this does not work with class attributes.
import torch
from torch import nn as nn
from torch_utils import View, Λ
class Metric(nn.Module):
latent_dim = None
embed = None
kernel = None
# regular conv, migrated from ConvLargeL2.
class Conv(Metric):
def __init__(self, input_dim, latent_dim):
super().__init__()
self.latent_dim = latent_dim
self.embed = nn.Sequential(
nn.Conv2d(input_dim, 32, kernel_size=4, stride=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=4, stride=2),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=4, stride=2),
nn.BatchNorm2d(32),
nn.ReLU(),
View(128),
nn.Linear(128, latent_dim),
)
self.kernel = Λ(lambda a, b: (a - b).norm(2, dim=-1))
def forward(self, x, x_prime):
*b, C, H, W = x.shape
*b_, C, H, W = x_prime.shape
z_1, z_2 = torch.broadcast_tensors(
self.embed(x.reshape(-1, C, H, W)).reshape(*b, self.latent_dim),
self.embed(x_prime.reshape(-1, C, H, W)).reshape(*b_, self.latent_dim))
*b, W = z_1.shape
return self.kernel(z_1, z_2).reshape(*b, 1)
main__:
net = Conv(...)
print(net.embed)
gives ‘None’.