Hello, I’m developed the following model, but when I tried to print model(x) I get the NotImplementedError (I think the indentation is ok). Did anyone meet this problem?
class MyVit(nn.Module):
def __init__(self, chw=(1, 28, 28), n_patches=7, hidden_d=8):
super(MyVit, self).__init__()
self.chw = chw
self.n_patches = n_patches
self.hidden_d = hidden_d
assert chw[1] % n_patches == 0, "Input shape not divisible by number of patches"
assert chw[2] % n_patches == 0, "Input shape not divisible by number of patches"
self.patches_size = (chw[1] / n_patches, chw[2] / n_patches)
self.input_d = int(chw[0] * self.patches_size[0] * self.patches_size[1])
self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)
self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
self.pos_embed = nn.Parameter(torch.tensor(get_positional_embeddings(self.n_patches ** 2
+ 1, self.hidden_d)))
self.pos_embed.requires_grad = False
def forward(self, images):
n, c, h, w = images.shape
patches = patchify(images, self.n_patches)
tokens = self.linear_mapper(patches)
tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
pos_embed = self.pos_embed.repeat(n, 1, 1)
out = tokens + pos_embed
return out
class MyMSA(nn.Module):
def __init__(self, dim, n_heads=2):
super(MyMSA, self).__init__()
self.dim = dim
self.n_heads = n_heads
assert dim % n_heads == 0, f"Can't divide dimension {dim} into {n_heads} heads"
d_heads = int(dim / n_heads)
self.q_map = nn.ModuleList([nn.Linear(d_heads, d_heads) for _ in range(self.n_heads)])
self.k_map = nn.ModuleList([nn.Linear(d_heads, d_heads) for _ in range(self.n_heads)])
self.v_map = nn.ModuleList([nn.Linear(d_heads, d_heads) for _ in range(self.n_heads)])
self.d_heads = d_heads
self.softmax = nn.Softmax(dim=-1)
def forward(self, sequences):
result = []
for sequence in sequences:
seq_result = []
for head in range(self.n_heads):
q_map = self.q_map(head)
k_map = self.k_map(head)
v_map = self.v_map(head)
seq = sequence[:, head * self.d_heads: (head + 1) * self.d_heads]
q, k, v = q_map(seq), k_map(seq), v_map(seq)
attention = self.softmax(q @ k.T / (self.d_heads ** 0.5))
seq_result.append(attention @ v)
result.append(torch.hstack(seq_result))
return torch.cat([torch.unsqueeze(r, dim=0) for r in result])
class MyVitBlock(nn.Module):
def __init__(self, hidden_d, n_heads, mlp_ratio=4):
super(MyVitBlock, self).__init__()
self.hidden_d = hidden_d
self.n_heads = n_heads
self.norm1 = nn.LayerNorm(hidden_d)
self.msa = MyMSA(hidden_d, n_heads)
self.norm2 = nn.LayerNorm(hidden_d)
self.mlp = nn.Sequential(
nn.Linear(hidden_d, mlp_ratio * hidden_d),
nn.GELU(),
nn.Linear(mlp_ratio * hidden_d, hidden_d)
)
def forward(self, x):
out = x + self.msa(self.norm1(x))
out = out + self.msa(self.norm2(x))
return out