I am trying to implement self attention network and here is the code
class FANet(nn.Module):
def __init__(self,ipchannels):
super(FANet,self).__init__()
self.conv1=nn.Conv2d(in_channels=ipchannels,out_channels=128,kernel_size=1,stride=1)
self.bn1=nn.BatchNorm2d(128)
self.conv2=nn.Conv2d(in_channels=128,out_channels=512,kernel_size=1,stride=1)
self.bn2=nn.BatchNorm2d(512)
self.relu=nn.ReLU()
def forward(self,x):
bs=x.size(0)
h=x.size(2)
q=self.conv1(x)
q=self.bn1(q)
k=self.bn1(self.conv1(x))
v=self.bn1(self.conv1(x))
q=F.normalize(q.view(128,-1),p=2,dim=0)
k=F.normalize(k.view(128,-1),p=2,dim=0)
v=self.relu(v.view(-1,128))
result1=torch.mm(v,k)
result2=torch.mm(q,result1)
print(result2.shape)
finalresult=self.bn2(self.conv2(result2.view(bs,128,h,-1)))
print(finalresult.size())
return finalresult
And my RAM always crashes when I try to pass an image tensor through the FANet()