i am trying to run torchsummary but getting this error. I don’t understand why, please help me. This is my code:
from gc_layer import GatedConv2d, GatedDeConv2d
import torch
import torch.nn as nn
from base_model import BaseModel
class Discriminator(BaseModel):
def __init__(self, channels = 64):
super(Discriminator, self).__init__()
self.channels = channels
input_dim = 3
self.init_weights()
self.gt_conv1 = GatedConv2d(input_dim+1, channels, kernel_size=3, dilation=1, padding='same')
self.lk1 = nn.LeakyReLU()
layers = []
for i in range(1,7):
mult = (2**i) if (2**i) < 8 else 8
in_mult = (2**i) if (2**i) < 9 else 16
layers.append(nn.Sequential(GatedConv2d(channels*in_mult, channels*mult, kernel_size=3, stride=2, padding='same', dilation=1),
nn.LeakyReLU()))
self.layers = nn.Sequential(*layers)
def forward(self, inputs, mask):
x_in = torch.cat([inputs, mask], dim=1)
output = self.gt_conv1(x_in)
output = self.lk1(output)
for i in range(1,7):
output = self.layers[i](output)
return output
if __name__ == "__main__":
model = Discriminator()
print(model)
from torchsummary import summary
print(summary(model, [(3,256,256), (1,256,256)], 1))```
And this is error:
Traceback (most recent call last):
File “/home/huynth/miniconda3/envs/inpainting/lib/python3.8/site-packages/torchsummary/torchsummary.py”, line 140, in summary
_ = model.to(device)(*x, *args, **kwargs) # type: ignore[misc]
File “/home/huynth/miniconda3/envs/inpainting/lib/python3.8/site-packages/torch/nn/modules/module.py”, line 1102, in _call_impl
return forward_call(*input, **kwargs)
TypeError: forward() takes 3 positional arguments but 4 were given
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File “/home/huynth/Hypergraph-Inpainting/models/discriminator.py”, line 38, in
print(summary(model, [(3,256,256), (1,256,256)], 1))
File “/home/huynth/miniconda3/envs/inpainting/lib/python3.8/site-packages/torchsummary/torchsummary.py”, line 143, in summary
raise RuntimeError(
RuntimeError: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: []