def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.discriminator = nn.Sequential(
nn.Conv2d(in_channels=256, out_channels=128, kernel_size=5, stride=1, padding=2),
nn.ReLU(True),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2),
nn.AvgPool2d(kernel_size=2, stride=1, padding=0),
nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0),
nn.Sigmoid()
)
def forward(self, input_feature, alpha):
x = self.discriminator(input_feature)
return x
network = Network().cuda()
...
network = network.train()
network.apply(weights_init)
When I apply weights_init()
AttributeError: 'Network' object has no attribute 'weight'
error occurs.`
What is the wrong in my usage?