class ResidAttention(nn.Module):
    """Residual-attention CNN mapping single-channel images to one scalar each.

    Pipeline: stem conv -> max-pool -> three stages of (residual unit followed
    by two attention modules) -> three more residual units -> BN/ReLU/avg-pool
    -> flatten -> single-output linear layer.

    NOTE(review): the attention-module spatial sizes (128 -> 64 -> 32) and the
    hard-coded ``fc1`` input width (61952 == 512 * 11 * 11) look tuned for one
    specific input resolution, apparently larger than 64x64. Confirm the
    expected input size against ResUnit / AttentionModule_stage* before
    feeding 64x64 images.
    """

    def __init__(self):
        super().__init__()

        # Stem: 1 -> 16 channels, 3x3 conv (stride 1, padding 1), BN, ReLU.
        stem_conv = nn.Conv2d(in_channels=1, out_channels=16,
                              kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Sequential(stem_conv,
                                   nn.BatchNorm2d(16),
                                   nn.ReLU(inplace=True))
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 1: 16 -> 64 channels, then two stage-1 attention modules.
        self.Resblock1 = ResUnit(16, 64)
        stage1_sizes = dict(size1=(128, 128), size2=(64, 64), size3=(32, 32))
        self.attention_module1 = AttentionModule_stage1(64, 64, **stage1_sizes)
        self.attention_module2 = AttentionModule_stage1(64, 64, **stage1_sizes)

        # Stage 2: 64 -> 128 channels (third ResUnit arg presumably a stride).
        self.Resblock2 = ResUnit(64, 128, 2)
        stage2_sizes = dict(size1=(64, 64), size2=(32, 32))
        self.attention_module3 = AttentionModule_stage2(128, 128, **stage2_sizes)
        self.attention_module4 = AttentionModule_stage2(128, 128, **stage2_sizes)

        # Stage 3: 128 -> 256 channels.
        self.Resblock3 = ResUnit(128, 256, 2)
        self.attention_module5 = AttentionModule_stage3(256, 256, size1=(32, 32))
        self.attention_module6 = AttentionModule_stage3(256, 256, size1=(32, 32))

        # Head: three residual units ending at 512 channels.
        self.Resblock4 = nn.Sequential(
            ResUnit(256, 512, 2),
            ResUnit(512, 512),
            ResUnit(512, 512),
        )
        # Attribute name (spelling included) is kept so state_dict keys stay stable.
        self.Avergepool = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=6, stride=1),
        )
        self.fc1 = nn.Linear(61952, 1)

    def forward(self, x):
        """Run ``x`` through every stage in order and return ``fc1``'s output.

        ``x`` must be a batched, single-channel image tensor (``conv1`` is
        built with ``in_channels=1`` and the flatten keeps dim 0 as batch).
        Pass it as ONE argument. Unpacking it with ``*`` hands one argument
        per sample to ``forward``.
        """
        stages = (
            self.conv1, self.maxpool1,
            self.Resblock1, self.attention_module1, self.attention_module2,
            self.Resblock2, self.attention_module3, self.attention_module4,
            self.Resblock3, self.attention_module5, self.attention_module6,
            self.Resblock4, self.Avergepool,
        )
        for stage in stages:
            x = stage(x)
        flat = x.view(x.size(0), -1)
        return self.fc1(flat)
# NOTE(review): snippet fragment. This assignment presumably lives in the
# enclosing model's __init__ (not shown); `encode` reads self.encoder.
self.encoder = ResidAttention()
...
def encode(self, *img):
    """Encode each image batch in ``img`` with ``self.encoder``.

    Each element of ``img`` is assumed to be a (batch, H, W) tensor, e.g.
    8 x 64 x 64 (TODO confirm against callers). A channel axis is inserted,
    so the encoder receives (batch, 1, H, W), matching a single-channel
    first conv.

    Returns:
        The encoder output for the last element of ``img``, or ``None`` if
        ``img`` is empty. Every element is encoded, but only the last result
        is kept.
    """
    z = None
    for x in img:
        # (B, H, W) -> (B, 1, H, W): add the channel axis Conv2d expects.
        batch = x[:, None, :, :]
        # Pass the 4-D tensor as a single argument. `self.encoder(*batch)`
        # would unpack it along dim 0 into B positional arguments, giving
        # "forward() takes 2 positional arguments but B+1 were given".
        z = self.encoder(batch)
    return z
Here, `img` is simply a batch of 8 grayscale images of 64×64 pixel values (one per TIFF file).
The error message is:
z = self.encoder(*img)
File "/people/kimd999/bin/Miniconda3-latest-Linux-x86_64/envs/cryodrgn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 9 were given