Thanks for your reply!
I found that the main problem in my code was in the definition of the network.
I checked the gradients manually, and they are nearly 0 (around 1e-11).
My “quick” fix was to redefine the embedder and create a new joint model, `JointModel`,
and it worked perfectly:
class JointModel(nn.Module):
    """Embed two inputs with one shared network, then classify the pair.

    Both inputs are concatenated along dimension 0 so the shared embedder
    runs a single forward pass. The result is split back at the size of the
    first input, and the two halves are passed to the classifier.

    Args:
        net: Shared embedder applied to both inputs.
        classifier: Callable taking ``(emb1, emb2)`` and returning the
            model output.
    """

    def __init__(self, net, classifier):
        super().__init__()
        self.emb = net
        self.clas = classifier

    def forward(self, x1, x2):
        """Return ``classifier(net(x1), net(x2))``, computed in one embedder pass.

        ``x1`` and ``x2`` must agree in every dimension except dim 0 so they
        can be concatenated. Their dim-0 sizes may differ.
        """
        # Split point between the two inputs inside the concatenated batch.
        split = x1.size(0)
        joint = self.emb(torch.cat((x1, x2), dim=0))
        # NOTE: if the embedder normalises over the batch (e.g. BatchNorm in
        # train mode), both inputs share those statistics in this joint pass.
        emb1, emb2 = joint[:split, ...], joint[split:, ...]
        return self.clas(emb1, emb2)
Now my question is why my previous, simpler network had a problem with the gradients (it is an AlexNet with a few modifications). I have added the code below.
Note: I have already tried initial learning rates ranging from 1e-1 down to 1e-4.
The network was initialized with the default parameters.
class AN(nn.Module):
    """AlexNet-style convolutional embedder with an optional fixed Sobel front-end.

    Five conv layers (each optionally followed by ``nn.InstanceNorm2d``, then
    ReLU) with three 2x2 max-pools. The final feature map is flattened to one
    vector per sample.

    Args:
        sobel: Input representation.
            ``"RGB"`` feeds the 3-channel image directly.
            ``"Edges"`` feeds only the two Sobel responses of the grayscale image.
            ``"RGB+Edges"`` concatenates those two responses in front of the
            RGB channels.
        bn: If True, apply ``nn.InstanceNorm2d`` after every conv layer
            (despite the flag's name, this is instance norm, not batch norm).
        init_weights: If True, re-initialise the trainable layers via
            ``_initialize_weights``.
        norm: If True, L2-normalise each output vector with ``F.normalize``.
    """

    # Accepted values for ``sobel``.
    _MODES = ("RGB", "Edges", "RGB+Edges")
    # Number of channels reaching conv1 for each ``sobel`` mode.
    _IN_CHANNELS = {"RGB": 3, "Edges": 2, "RGB+Edges": 5}

    def __init__(self, sobel="RGB", bn=False, init_weights=False, norm=False):
        super().__init__()
        assert sobel in self._MODES, (
            f"sobel must be one of {self._MODES}, got {sobel!r}"
        )
        self.norm = norm
        self.bn = bn
        self.edges = sobel
        if sobel in ("Edges", "RGB+Edges"):
            self.sobel = self._make_sobel()

        self.conv1 = nn.Conv2d(
            self._IN_CHANNELS[sobel], 64, kernel_size=11, padding=5, stride=4
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.relu4 = nn.ReLU(inplace=True)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        if bn:
            self.bn1 = nn.InstanceNorm2d(64)
            self.bn2 = nn.InstanceNorm2d(192)
            self.bn3 = nn.InstanceNorm2d(384)
            self.bn4 = nn.InstanceNorm2d(256)
            self.bn5 = nn.InstanceNorm2d(256)
        if init_weights:
            self._initialize_weights()

    @staticmethod
    def _make_sobel():
        """Build the frozen grayscale -> (Sobel-x, Sobel-y) front-end."""
        grayscale = nn.Conv2d(3, 1, kernel_size=1, stride=1, padding=0)
        sobel_filter = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1)
        with torch.no_grad():
            # Grayscale = plain average of the three input channels.
            grayscale.weight.fill_(1.0 / 3.0)
            grayscale.bias.zero_()
            # Output channel 0: x-derivative kernel; channel 1: y-derivative kernel.
            sobel_filter.weight[0, 0].copy_(
                torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]])
            )
            sobel_filter.weight[1, 0].copy_(
                torch.tensor([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]])
            )
            sobel_filter.bias.zero_()
        front_end = nn.Sequential(grayscale, sobel_filter)
        # Fixed filters: excluded from gradient updates.
        for p in front_end.parameters():
            p.requires_grad = False
        return front_end

    def forward(self, x):
        """Map an image batch to one flat feature vector per sample."""
        if self.edges == "Edges":
            x = self.sobel(x)
        elif self.edges == "RGB+Edges":
            x = torch.cat((self.sobel(x), x), dim=1)
        x = self.conv1(x)
        if self.bn:
            x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        if self.bn:
            x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        if self.bn:
            x = self.bn3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        if self.bn:
            x = self.bn4(x)
        x = self.relu4(x)
        x = self.conv5(x)
        if self.bn:
            x = self.bn5(x)
        x = self.relu5(x)
        x = self.maxpool3(x)
        # Flatten everything except dim 0.
        x = x.view(x.size(0), -1)
        if self.norm:
            x = F.normalize(x)
        return x

    def _initialize_weights(self):
        """Re-initialise the trainable layers.

        Conv weights are drawn from N(0, sqrt(2 / n)) with
        n = kernel_h * kernel_w * out_channels, and conv biases are zeroed.
        Affine norm layers get weight 1 and bias 0. Linear layers get
        N(0, 0.01) weights and zero bias.

        Modules of the Sobel front-end are skipped. They hold hand-set,
        frozen (``requires_grad=False``) filters, so re-randomising them would
        permanently destroy the edge detector.
        """
        print('Initiating network weights')
        front_end = getattr(self, "sobel", None)
        frozen = {id(m) for m in front_end.modules()} if front_end is not None else set()
        for m in self.modules():
            if id(m) in frozen:
                continue
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # InstanceNorm2d defaults to affine=False, so weight/bias may be None.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
Thanks again for your reply!