Hey,
So I’ve just finished re-configuring a network. I replaced nn.Upsample
with the upConv
sequential container shown in the code below. I’ve verified that everything is lined up by running summary(UNetPP, (3, 128, 128))
which runs with no issue.
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class blockUNetPP(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.relu = nn.LeakyReLU(0.2, inplace=True)
self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(middle_channels)
self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
return out
class upConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.upc = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(in_ch, out_ch*2, 3, stride=1, padding=1),
nn.BatchNorm2d(out_ch*2),
nn.ReLU(inplace=True)
)
def forward(self, x):
out = self.upc(x)
return out
My issue is that when I try to start training the model I get the following issue:
Traceback (most recent call last):
File "runTrain.py", line 90, in <module>
netG.apply(weights_init)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 289, in apply
module.apply(fn)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 290, in apply
fn(self)
File "D:\Thesis Models\Deep_learning_models\UNet\train\NetC.py", line 8, in weights_init
m.weight.data.normal_(0.0, 0.02)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 594, in __getattr__
type(self).__name__, name))
AttributeError: 'upConv' object has no attribute 'weight'
I’ve looked up solutions which suggest looping over container modules, but I’m already doing this with weights_init(m)
. Could someone explain whats wrong with my current setup?