Hi all, I’m new to pytorch.
I try to implement Resnet currently, but I have a problem about nn.sequential
class PlainBlock(nn.Module):
def __init__(self, Cin, Cout, downsample=False):
super().__init__()
self.net = None
padding_size1 = 1
kernel_size1 = 3
padding_size2 = 1
kernel_size2 = 3
self.net = nn.Sequential(
('Spatial Batch normalization1', nn.BatchNorm2d(Cin)),
('ReLU1', nn.ReLU()),
('Conv1', nn.Conv2d(Cin, Cout, kernel_size = kernel_size1, stride = stride_1, padding = padding_size1, bias = False)),
('Spatial Batch normalization2', nn.BatchNorm2d(Cout)),
('ReLU2', nn.ReLU()),
('Conv2', nn.Conv2d(Cout, Cout, kernel_size = kernel_size2, stride = 1, padding = padding_size2, bias = False))
)
############################################################################
def forward(self, x):
return self.net(x)
this is the error i got
TypeError: tuple is not a Module subclass
Here is the traceback
TypeError Traceback (most recent call last)
<ipython-input-32-ccdd2c3e2e78> in <module>()
1 data = torch.zeros(2, 3, 5, 6)
2 # YOUR_TURN: Impelement PlainBlock.__init__
----> 3 model = PlainBlock(3, 10)
4 if list(model(data).shape) == [2, 10, 5, 6]:
5 print('The output of PlainBlock without downsampling has a *correct* dimension!')
2 frames
/content/drive/My Drive/HW9/pytorch_autograd_and_nn.py in __init__(self, Cin, Cout, downsample)
357 ('Spatial Batch normalization2', nn.BatchNorm2d(Cout)),
358 ('ReLU2', nn.ReLU()),
--> 359 ('Conv2', nn.Conv2d(Cout, Cout, kernel_size = kernel_size2, stride = 1, padding = padding_size2, bias = False))
360 )
361
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py in __init__(self, *args)
89 else:
90 for idx, module in enumerate(args):
---> 91 self.add_module(str(idx), module)
92
93 def _get_item_by_idx(self, iterator, idx) -> T:
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in add_module(self, name, module)
376 if not isinstance(module, Module) and module is not None:
377 raise TypeError("{} is not a Module subclass".format(
--> 378 torch.typename(module)))
379 elif not isinstance(name, torch._six.string_classes):
380 raise TypeError("module name should be a string. Got {}".format(
Hope someone can help
Sincerely!