Parent Struct Attribute Always Overrides Child Module Class

I have ran into this multiple times now.

this does not work with class attributes.

import torch
from torch import nn as nn
from torch_utils import View, Λ


class Metric(nn.Module):
    latent_dim = None
    embed = None
    kernel = None


# regular conv, migrated from ConvLargeL2.
class Conv(Metric):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.embed = nn.Sequential(
            nn.Conv2d(input_dim, 32, kernel_size=4, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=4, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            View(128),
            nn.Linear(128, latent_dim),
        )
        self.kernel = Λ(lambda a, b: (a - b).norm(2, dim=-1))

    def forward(self, x, x_prime):
        *b, C, H, W = x.shape
        *b_, C, H, W = x_prime.shape
        z_1, z_2 = torch.broadcast_tensors(
            self.embed(x.reshape(-1, C, H, W)).reshape(*b, self.latent_dim),
            self.embed(x_prime.reshape(-1, C, H, W)).reshape(*b_, self.latent_dim))
        *b, W = z_1.shape
        return self.kernel(z_1, z_2).reshape(*b, 1)

main__:
   net = Conv(...)
   print(net.embed)

gives ‘None’.

Try to initialize Metric as a module via:

class Metric(nn.Module):
    def __init__(self):
        super(Metric, self).__init__()
        latent_dim = None
        embed = None
        kernel = None

and it should work.