class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The encoding table is precomputed once for ``max_len`` positions and
    stored as the non-trainable buffer ``pe`` with shape
    ``[max_len, 1, d_model]``. Even feature columns hold sines and odd
    feature columns hold cosines of ``position * div_term``.

    Arguments:
        d_model: size of the last (feature) dimension of the input.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions to precompute; inputs must not be
            longer than this along the sequence dimension.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the cosine term to fit instead of failing on a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns:
            Tensor of the same shape as ``x`` with the position encoding
            added and dropout applied.
        """
        # ``pe`` is stored sequence-first, so move the sequence axis to the
        # front, add the first ``seq_len`` rows, then restore batch-first.
        x = x.permute(1, 0, 2)
        x = x + self.pe[:x.size(0)]
        return self.dropout(x).permute(1, 0, 2)
class Model(nn.Module):
    """Transformer encoder over batch-first input with optional per-sample masks.

    The encoder uses ``d_model=256``, ``nhead=8``, six layers and
    ``batch_first=True``. ``cls1`` and ``cls2`` are registered as parameters
    but are not used by ``forward`` in this block.
    """

    def __init__(self):
        super().__init__()
        d_model = 256
        # Kept on the instance because 3D attention masks must be expanded
        # to one copy per head before they reach the encoder.
        self.nhead = 8
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=self.nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.pos = PositionalEncoding(d_model, max_len=1000)  # sequence length
        self.cls1 = nn.Parameter(torch.rand(1, d_model))  # 1, enc_dim
        self.cls2 = nn.Parameter(torch.rand(1, d_model))

    @staticmethod
    def _expand_mask_per_head(blocked, batch_size, num_heads):
        """Repeat a ``(batch, L, L)`` mask so each head gets its sample's mask."""
        if blocked.dim() == 3 and blocked.size(0) == batch_size:
            # Attention flattens (batch, head) as batch * num_heads + head, so
            # every sample's mask must appear num_heads times consecutively:
            # repeat_interleave, not repeat.
            return blocked.repeat_interleave(num_heads, dim=0)
        return blocked

    def forward(self, x, mask=None):
        """Encode ``x``, optionally restricting attention with ``mask``.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, 256]``.
            mask: optional Tensor where nonzero/True means "may attend".
                Accepted shapes are ``[seq_len, seq_len]`` (shared by all
                samples), ``[batch_size, seq_len, seq_len]`` (one mask per
                sample, expanded to all heads here) or
                ``[batch_size * nhead, seq_len, seq_len]`` (passed through).
                A row with no allowed position may produce NaN outputs for
                that query position, so every row should allow at least
                one key.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.pos(x)
        attn_mask = None
        if mask is not None:
            # The encoder expects True where attention is NOT allowed.
            blocked = mask.logical_not()
            attn_mask = self._expand_mask_per_head(blocked, x.size(0), self.nhead)
        x = self.transformer_encoder(x, mask=attn_mask, src_key_padding_mask=None)
        return x
Say the batch size is 4.
For some reason I am trying to apply a different mask to each input, i.e. 4 masks for 4 inputs.
One such mask looks like this (assume the sequence length is 9, so the mask is a 9-by-9 matrix; it is a kind of restrictive self-attention):
tensor([[1., 1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 1., 1., 0., 0.],
[0., 0., 0., 0., 1., 1., 1., 0., 0.],
[0., 0., 0., 0., 1., 1., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]])
When I try to pass 4 such different masks as a tensor of shape (4, 9, 9), I get the following error:
RuntimeError: The shape of the 3D attn_mask is torch.Size([4, 9, 9]), but should be (32, 9, 9).
What should I do to achieve what I am expecting?