So, i have tried to implement Vision Transformers from scratch, but somehow the layer to_qvk
and the W_0
keep initialized on None, even though theres not a single detached layer on this code. Can anyone figured out why? thankyou.
def expand_to_batch(tensor, desired_size):
tile = desired_size // tensor.shape[0]
return repeat(tensor, 'b ... -> (b tile) ...', tile=tile)
def compute_mhsa(q, k, v, scale_factor=1, mask=None):
scaled_dot_prod = torch.einsum('... i d , ... j d -> ... i j', q, k)
attention = torch.softmax(scaled_dot_prod, dim=-1)
return torch.einsum('... i j , ... j d -> ... i d', attention, v)
class MultiHeadSelfAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=None):
super().__init__()
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
_dim = self.dim_head * heads
self.heads = heads
self.to_qvk = nn.Linear(dim, _dim * 3, bias=False)
self.W_0 = nn.Linear(_dim, dim, bias=False)
self.rearrange = Rearrange("b h t d -> b t (h d)")
def forward(self, x, mask=None):
logger.info("============ MULTI HEAD SELF ATTENTION ============\n")
assert x.dim() == 3
qkv = self.to_qvk(x)
logger.info(f"Hasil QKV \n{qkv}")
q, k, v = nn.Parameter(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.heads))
logger.info(f"Hasil : Q \n{q}")
logger.info(f"Hasil : K \n{k}")
logger.info(f"Hasil : V \n{v}")
out = nn.Parameter(compute_mhsa(q, k, v, mask=mask))
logger.info(f"Hasil Scaled dot Product : \n{out}")
out = self.rearrange(out)
logger.info(f"Hasil Concat: \n{out}")
out = self.W_0(out)
logger.info(f"Hasil Linear : \n{out}")
return out
class TransformerBlock(nn.Module):
def __init__(self, dim, heads=8, dim_head=None,
dim_linear_block=1024, activation=nn.GELU,
mhsa=None, prenorm=False):
super().__init__()
self.mhsa = MultiHeadSelfAttention(dim=dim, heads=heads, dim_head=dim_head)
self.prenorm = False
self.norm_1 = nn.LayerNorm(dim)
self.gelu = nn.GELU()
def forward(self, x, mask=None):
logger.info(f"======================== Transformer Block ========================")
logger.info("Before MHSA")
y = self.mhsa(x)
logger.info(f"After MHSA \n{y}")
y = self.norm_1(x)
logger.info(f"After Normalization \n{y}")
out = self.gelu(y)
logger.info(f"After GeLU : \n{out}")
logger.info(f"Hasil Normalization \n {out}")
return out
class TransformerEncoder(nn.Module):
def __init__(self, dim, blocks=6, heads=8, dim_head=None, dim_linear_block=1024, prenorm=False):
super().__init__()
self.transformer_block = TransformerBlock(dim,heads,dim_head,dim_linear_block,prenorm=prenorm)
def forward(self, x, mask=None):
logger.info("======================== Transformer Encoder ========================")
x = self.transformer_block(x,mask)
logger.info(f"Hasil Forward Transformer Encoder {x}")
return x
class ViT(nn.Module):
def __init__(self, *,
img_dim = 4,
in_channels=3,
patch_dim=2,
num_classes=3,
dim=3,
blocks=1,
heads=3,
dim_linear_block=3,
dim_head=None,
transformer=None,
classification=None):
super().__init__()
assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible by img dim {img_dim}'
self.num_classes = num_classes
self.p = patch_dim
self.classification = classification
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(self.dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, self.dim,bias=False)
self.cls_token = nn.Parameter(torch.randn(1, 1, self.dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, self.dim))
if self.classification:
self.mlp_head = nn.Linear(self.dim, num_classes)
if transformer is None:
self.transformer = TransformerEncoder(self.dim, blocks=blocks, heads=heads,
dim_head=self.dim_head,
dim_linear_block=dim_linear_block)
else:
self.transformer = transformer
self.softmax = nn.Softmax()
def forward(self, img, mask=None):
logger.info("======================== Vision Transformer ========================")
logger.info(f"Image Shape : \n{img.shape}")
logger.info(f"Image Input : \n{img}")
img_patches = rearrange(img,
'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
patch_x=self.p, patch_y=self.p)
logger.info(f"Image Patches Shape : \n{img_patches.shape}")
logger.info(f"Image Patches : \n{img_patches}")
batch_size, tokens, _ = img_patches.shape
logger.info(f"Batch Size : \n {batch_size}")
logger.info(f"Tokens : \n{tokens}")
img_patches = self.project_patches(img_patches)
logger.info(f"project patches size : \n{img_patches.shape}")
logger.info(f"Project Patches : \n{img_patches}")
img_patches = torch.cat((expand_to_batch(self.cls_token, desired_size=batch_size), img_patches), dim=1)
logger.info(f"project patches after expand to token : \n{img_patches.shape}")
logger.info(f"Images Patches after matrix multiply (cat) : \n{img_patches}")
img_patches = img_patches + self.pos_emb1D[:tokens + 1, :]
logger.info(f"Image patches after Pos embed with tokens : \n{img_patches}")
y = self.transformer(img_patches, mask)
logger.info(f"Image after Transformer : \n{y}")
out = nn.Linear(self.dim, self.num_classes)
out = out(y[:,0,:])
logger.info(f"MLP Head : \n{out}")
return out
and heres the log info
for layer, weight in vit.named_parameters():
print(f"==================Layer {layer}==================================")
if "bias" in layer:
print(f"Bias {weight}")
print(f"Bias Gradient {weight.grad}")
elif "weight" in layer:
print(f" Weight {weight}")
print(f"Gradient {weight.grad}")
==================Layer cls_token==================================
Gradient tensor([[[-4.8729e-05, 1.3597e-05, 3.5132e-05]]])
==================Layer pos_emb1D==================================
Gradient tensor([[-4.8729e-05, 1.3597e-05, 3.5132e-05],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00]])
==================Layer project_patches.weight==================================
Weight Parameter containing:
tensor([[ 0.0490, 0.0490, 0.0490, -0.0480, -0.0480, -0.0480, 0.0490, -0.0490,
0.0450, -0.0490, 0.0480, -0.0490],
[ 0.0490, -0.0480, -0.0480, 0.0480, -0.0480, 0.0470, -0.0490, -0.0490,
-0.0480, 0.0490, -0.0450, -0.0470],
[ 0.0490, -0.0490, 0.0480, -0.0480, -0.0490, 0.0450, -0.0490, -0.0470,
0.0490, -0.0450, -0.0490, -0.0490]], requires_grad=True)
Gradient tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
==================Layer transformer.transformer_block.mhsa.to_qvk.weight==================================
Weight Parameter containing:
tensor([[ 0.0030, 0.0050, 0.0010],
[ 0.0010, -0.0010, -0.0010],
[-0.0030, -0.0030, -0.0030],
[ 0.0030, -0.0030, 0.0030],
[ 0.0040, 0.0040, 0.0040],
[-0.0010, 0.0010, 0.0040],
[-0.0020, -0.0020, -0.0020],
[ 0.0020, -0.0020, 0.0020],
[-0.0020, -0.0020, 0.0010]], requires_grad=True)
Gradient None
==================Layer transformer.transformer_block.mhsa.W_0.weight==================================
Weight Parameter containing:
tensor([[-0.0010, 0.0010, 0.0040],
[-0.0020, 0.0020, -0.0020],
[ 0.0020, -0.0010, -0.0010]], requires_grad=True)
Gradient None
==================Layer transformer.transformer_block.norm_1.weight==================================
Weight Parameter containing:
tensor([-0.0510, 0.0520, 0.0520], requires_grad=True)
Gradient tensor([ 0.1077, -0.1650, -0.2108])
==================Layer transformer.transformer_block.norm_1.bias==================================
Bias Parameter containing:
tensor([-0.0480, -0.0490, 0.0540], requires_grad=True)
Bias Gradient tensor([ 0.3079, 0.1212, -0.2083])
Gradient tensor([ 0.3079, 0.1212, -0.2083])