When I am replacing ImageLinearAttention
with SelfAttention
in Vision Transformer
, with the code as follows, I get a RuntimeError. The code for ImageLinearAttention
is from linear-attention-transformer/images.py at master · lucidrains/linear-attention-transformer · GitHub except I removed number of channels as you see in commented code.
class ImageLinearAttention(nn.Module):
def __init__(self, chan, chan_out = None, kernel_size = 1, padding = 0, stride = 1, key_dim = 64, value_dim = 64, heads = 8, norm_queries = True):
super().__init__()
self.chan = chan
chan_out = chan if chan_out is None else chan_out
self.key_dim = key_dim
self.value_dim = value_dim
self.heads = heads
self.norm_queries = norm_queries
conv_kwargs = {'padding': padding, 'stride': stride}
self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)
print('value dim: ', value_dim)
print('chan out: ', chan_out)
print('kernel_size: ', kernel_size)
out_conv_kwargs = {'padding': padding}
print('out_conv_kwargs: ', out_conv_kwargs)
print('in_chan: ', value_dim * heads)
self.to_out = nn.Conv2d(value_dim * heads, chan_out, kernel_size, **out_conv_kwargs)
def forward(self, x, context = None):
print('x.shape: ', x.shape)
print('*x.shape is: ', *x.shape)
print('heads: ', self.heads)
#b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
b, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))
q, k = map(lambda x: x * (self.key_dim ** -0.25), (q, k))
if context is not None:
#context = context.reshape(b, c, 1, -1)
context = context.reshape(b, 1, -1)
ck, cv = self.to_k(context), self.to_v(context)
ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
k = torch.cat((k, ck), dim=3)
v = torch.cat((v, cv), dim=3)
k = k.softmax(dim=-1)
if self.norm_queries:
q = q.softmax(dim=-2)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhdn,bhde->bhen', q, context)
out = out.reshape(b, -1, h, w)
out = self.to_out(out)
return out
Error is:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [384, 512, 1, 1], but got 3-dimensional input of size [1, 1984, 512] instead
Also, my data fed to transformer is of size torch.Size([1983, 512]) and my batch size is 1.
Full log is:
$ bash scripts/train.sh
train: True test: False cam: False
preparing datasets and dataloaders......
total_train_num: 176
creating models......
n_class: 2
in_dim: 512
value dim: 64
chan out: 512
kernel_size: 1
out_conv_kwargs: {'padding': 0}
in_chan: 768
in_dim: 512
value dim: 64
chan out: 512
kernel_size: 1
out_conv_kwargs: {'padding': 0}
in_chan: 768
=>Epoches 1, learning rate = 0.0010000, previous best = 0.0000
torch.Size([1983, 512])
features size: torch.Size([1983, 512])
/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:154: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
max_feature_num: 1983
batch feature size: torch.Size([1, 1983, 512])
x.shape: torch.Size([1, 1984, 512])
*x.shape is: 1 1984 512
heads: 12
Traceback (most recent call last):
File "main.py", line 148, in <module>
preds,labels,loss = trainer.train(sample_batched, model)
File "/SeaExp/mona/research/code/cc/helper.py", line 71, in train
pred,labels,loss = model.forward(feats, labels, masks)
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
return self.module(*inputs[0], **kwargs[0])
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/SeaExp/mona/research/code/cc/models/Transformer.py", line 31, in forward
out = self.transformer(X)
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 262, in forward
feat = self.transformer(emb)
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 206, in forward
out = layer(out)
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 174, in forward
out = self.attn(out)
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 92, in forward
q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 439, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [384, 512, 1, 1], but got 3-dimensional input of size [1, 1984, 512] instead
The original SelfAttention
code is:
class SelfAttention(nn.Module):
def __init__(self, in_dim, heads=8, dropout_rate=0.1):
super(SelfAttention, self).__init__()
self.heads = heads
self.head_dim = in_dim // heads
self.scale = self.head_dim ** 0.5
self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim))
self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim))
self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim))
self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,))
if dropout_rate > 0:
self.dropout = nn.Dropout(dropout_rate)
else:
self.dropout = None
def forward(self, x):
b, n, _ = x.shape
q = self.query(x, dims=([2], [0]))
k = self.key(x, dims=([2], [0]))
v = self.value(x, dims=([2], [0]))
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
attn_weights = F.softmax(attn_weights, dim=-1)
out = torch.matmul(attn_weights, v)
out = out.permute(0, 2, 1, 3)
out = self.out(out, dims=([2, 3], [0, 1]))
return out
How can I fix this error? I am calling the ImageSelfAttention
as following in the Encoder block of the Vision Transformer:
class EncoderBlock(nn.Module):
def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1):
super(EncoderBlock, self).__init__()
self.norm1 = nn.LayerNorm(in_dim)
#self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate)
## note Mona: not sure if I am correctly passing the params
# what about attn_dropout_rate=0.1
## I don't know
print('in_dim: ', in_dim)
self.attn = ImageLinearAttention(chan=in_dim, heads=num_heads, key_dim=32)
if dropout_rate > 0:
self.dropout = nn.Dropout(dropout_rate)
else:
self.dropout = None
self.norm2 = nn.LayerNorm(in_dim)
self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate)
def forward(self, x):
residual = x
out = self.norm1(x)
out = self.attn(out)
if self.dropout:
out = self.dropout(out)
out += residual
residual = out
out = self.norm2(out)
out = self.mlp(out)
out += residual
return out
The code for SelfAttention
and how to use it in encoder is mostly from vision-transformer-pytorch/model.py at main · asyml/vision-transformer-pytorch · GitHub