Hi,
I am trying to implement a QKV attention module, which looks like this:
import torch
import torch.nn as nn

class QKVAttention(nn.Module):
    def __init__(self, n_chan_q=112, n_chan_kv=1280, mid_chan=256):
        super(QKVAttention, self).__init__()
        self.mid_chan = mid_chan
        self.conv_q = ConvBlock(n_chan_q, mid_chan, 1, 1, 0)
        self.conv_k = ConvBlock(n_chan_kv, mid_chan, 1, 1, 0)
        self.conv_v = ConvBlock(n_chan_kv, mid_chan, 1, 1, 0)

    def forward(self, feat, featq):
        n, c, h, w = featq.size()
        # 1x1 projections, then flatten the spatial dims: (n, mid_chan, h*w)
        q = self.conv_q(featq).view(n, self.mid_chan, -1)
        k = self.conv_k(feat).view(n, self.mid_chan, -1)
        v = self.conv_v(feat).view(n, self.mid_chan, -1)
        # qkv = q.permute(0, 2, 1).bmm(k).softmax(dim=2).bmm(v.permute(0, 2, 1))
        # qkv = v.bmm(k.permute(0, 2, 1).bmm(q).softmax(dim=1))
        # attention weights over the key positions (dim=1), then a weighted sum of values
        qkv = v.bmm(k.transpose(1, 2).bmm(q).softmax(dim=1))
        qkv = qkv.view(n, self.mid_chan, h, w)
        return qkv
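I call it roughly like this (the shapes are just an example consistent with the default channel counts; ConvBlock is described below):

attn = QKVAttention()
featq = torch.randn(2, 112, 28, 28)  # query features
feat = torch.randn(2, 1280, 7, 7)    # key/value features
out = attn(feat, featq)              # -> (2, 256, 28, 28)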
The ConvBlock is an nn.Conv2d followed by an nn.BatchNorm2d and an nn.ReLU.
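In code, that is essentially the following (a minimal sketch; the argument order in_channels, out_channels, kernel_size, stride, padding is inferred from the calls above):

class ConvBlock(nn.Module):
    def __init__(self, in_chan, out_chan, ks, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, ks, stride, padding)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))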
My problem is that when I run this code, I get a warning message like this:
[W TensorIterator.cpp:918] Warning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (function operator())
Could you tell me which operator triggers this warning, and whether there is something wrong with my implementation?
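For reference, I know a tensor's memory format can be inspected like this (fmt is just a small debugging helper I sketched, not part of the module), but I am not sure which tensor or operator I should be looking at:

def fmt(t):
    # channels_last is only meaningful for 4D (NCHW) tensors
    if t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last):
        return "channels_last"
    return "contiguous" if t.is_contiguous() else "neither"

# e.g. at the top of forward():
# print(fmt(feat), fmt(featq))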