I have a network in which three architectures share the same classifier.
import torch
import torch.nn as nn
import torch.nn.functional as F

class VGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels, batch_norm=False):
        super(VGGBlock, self).__init__()
        conv2_params = {'kernel_size': (3, 3),
                        'stride': (1, 1),
                        'padding': 1}
        self._batch_norm = batch_norm
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv2_params)
        # nn.Identity() instead of a bare lambda, so the no-op is a proper registered module
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        return self._batch_norm

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.max_pooling(x)
        return x

class VGG16(nn.Module):
    def __init__(self, input_size, num_classes=1, batch_norm=False):
        super(VGG16, self).__init__()
        self.in_channels, self.in_width, self.in_height = input_size
        self.block_1 = VGGBlock(self.in_channels, 64, batch_norm=batch_norm)
        self.block_2 = VGGBlock(64, 128, batch_norm=batch_norm)
        self.block_3 = VGGBlock(128, 256, batch_norm=batch_norm)
        self.block_4 = VGGBlock(256, 512, batch_norm=batch_norm)

    @property
    def input_size(self):
        return self.in_channels, self.in_width, self.in_height

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        return x

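As a shape check (my own arithmetic, assuming a 1×32×32 input): each block halves the spatial resolution, 32 → 16 → 8 → 4 → 2, so the backbone output is 512×2×2, i.e. 2048 features after flattening, which matches the classifier's first nn.Linear(2048, 2048) layer below.

backbone = VGG16((1, 32, 32), batch_norm=True)
backbone.eval()  # eval mode so BatchNorm doesn't need batch statistics
with torch.no_grad():
    features = backbone(torch.randn(1, 1, 32, 32))
print(features.shape)                    # torch.Size([1, 512, 2, 2])
print(torch.flatten(features, 1).shape)  # torch.Size([1, 2048])
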
class VGG16Classifier(nn.Module):
    def __init__(self, num_classes=1, classifier=None, batch_norm=False):
        super(VGG16Classifier, self).__init__()
        self._vgg_a = VGG16((1, 32, 32), batch_norm=True)
        self._vgg_b = VGG16((1, 32, 32), batch_norm=True)
        self._vgg_star = VGG16((1, 32, 32), batch_norm=True)
        self.classifier = classifier
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Linear(2048, 2048),
                nn.ReLU(True),
                nn.Dropout(p=0.5),
                nn.Linear(2048, 512),
                nn.ReLU(True),
                nn.Dropout(p=0.5),
                nn.Linear(512, num_classes)
            )

    def forward(self, x1, x2, x3):
        op1 = self._vgg_a(x1)
        op1 = torch.flatten(op1, 1)
        op2 = self._vgg_b(x2)
        op2 = torch.flatten(op2, 1)
        op3 = self._vgg_star(x3)
        op3 = torch.flatten(op3, 1)
        x1 = self.classifier(op1)
        x2 = self.classifier(op2)
        x3 = self.classifier(op3)
        return x1, x2, x3

model1 = VGG16((1,32,32),batch_norm=True)
model2 = VGG16((1,32,32),batch_norm=True)
model_star = VGG16((1,32,32),batch_norm=True)
model_combo = VGG16Classifier(model1,model2,model_star)
I want to train model_combo using the following loss function:
class CombinedLoss(nn.Module):
    def __init__(self, loss_a, loss_b, loss_star, _lambda=1.0):
        super().__init__()
        self.loss_a = loss_a
        self.loss_b = loss_b
        self.loss_star = loss_star
        self.register_buffer('_lambda', torch.tensor(float(_lambda), dtype=torch.float32))

    def forward(self, y_hat, y):
        # Regularizer: squared L2 distance between model_star's weights and the
        # sum of model1's and model2's weights, accumulated parameter by parameter
        # (model1, model2, model_star are the globals defined above)
        reg = sum(torch.sum(torch.pow(w_star - (w_a + w_b), 2))
                  for w_star, w_a, w_b in zip(model_star.parameters(),
                                              model1.parameters(),
                                              model2.parameters()))
        return (self.loss_a(y_hat[0], y[0]) +
                self.loss_b(y_hat[1], y[1]) +
                self.loss_star(y_hat[2], y[2]) +
                self._lambda * reg)

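For reference, a minimal sketch of how the criterion and optimizer can be set up (the nn.BCEWithLogitsLoss choice and the learning rate are assumptions on my part, consistent with the float labels and the >= 0.0 thresholding in the training loop below):

# Sketch: the three per-branch losses and the learning rate are placeholders
criterion = CombinedLoss(nn.BCEWithLogitsLoss(),
                         nn.BCEWithLogitsLoss(),
                         nn.BCEWithLogitsLoss(),
                         _lambda=1.0)
optimizer = torch.optim.SGD(model_combo.parameters(), lr=0.01)
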
In the training function I pass the loaders, which for simplicity are loaders_a, loaders_b, and loaders_a again, where loaders_a covers the first 50% of MNIST and loaders_b the second 50%.
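A minimal sketch of how loaders_a and loaders_b can be built from the two MNIST halves (illustrative only: the even/odd target mapping for the binary head, the split fractions, and the batch size are placeholder assumptions; each loader is a dict keyed by "train"/"val"/"test", as train() below expects):

from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

# Resize to 32x32 to match the model's expected input; the y % 2 target
# mapping (even/odd) is an assumed binarization, not part of the original code
mnist = datasets.MNIST(root=".", train=True, download=True,
                       transform=transforms.Compose([transforms.Resize(32),
                                                     transforms.ToTensor()]),
                       target_transform=lambda y: y % 2)

half = len(mnist) // 2
part_a = Subset(mnist, range(half))               # first 50% of MNIST
part_b = Subset(mnist, range(half, len(mnist)))   # second 50% of MNIST

def make_loaders(dataset, batch_size=128):
    # 80/10/10 train/val/test split; the fractions are placeholders
    n = len(dataset)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    train_set, val_set, test_set = random_split(
        dataset, [n_train, n_val, n - n_train - n_val])
    return {"train": DataLoader(train_set, batch_size=batch_size, shuffle=True),
            "val": DataLoader(val_set, batch_size=batch_size),
            "test": DataLoader(test_set, batch_size=batch_size)}

loaders_a = make_loaders(part_a)
loaders_b = make_loaders(part_b)
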
def train(net, loaders, optimizer, criterion, epochs=20, dev=None, save_param=False, model_name="valerio"):
    loaders_a, loaders_b, loaders_star = loaders
    net = net.to(dev)
    criterion.to(dev)
    # Initialize history
    history_loss = {"train": [], "val": [], "test": []}
    history_accuracy_a = {"train": [], "val": [], "test": []}
    history_accuracy_b = {"train": [], "val": [], "test": []}
    history_accuracy_star = {"train": [], "val": [], "test": []}
    # Store the best val accuracy
    best_val_accuracy = 0
    # Process each epoch
    for epoch in range(epochs):
        # Initialize epoch variables
        sum_loss = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_a = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_b = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_star = {"train": 0, "val": 0, "test": 0}
        # Process each split
        for split in ["train", "val", "test"]:
            if split == "train":
                net.train()
            else:
                net.eval()
            # Process each batch
            for j, ((input_a, labels_a), (input_b, labels_b), (input_s, labels_s)) in enumerate(
                    zip(loaders_a[split], loaders_b[split], loaders_star[split])):
                labels_a = labels_a.unsqueeze(1).float().to(dev)
                labels_b = labels_b.unsqueeze(1).float().to(dev)
                labels_s = labels_s.unsqueeze(1).float().to(dev)
                input_a = input_a.to(dev)
                input_b = input_b.to(dev)
                input_s = input_s.to(dev)
                # Reset gradients
                optimizer.zero_grad()
                # Compute output
                pred = net(input_a, input_b, input_s)
                loss = criterion(pred, [labels_a, labels_b, labels_s])
                # Update loss
                sum_loss[split] += loss.item()
                # Update parameters on the training split only
                if split == "train":
                    loss.backward()
                    optimizer.step()
                # Binarize the logits at 0 to get hard 0/1 predictions
                pred_labels_a = (pred[0] >= 0.0).long()
                pred_labels_b = (pred[1] >= 0.0).long()
                pred_labels_star = (pred[2] >= 0.0).long()
                batch_accuracy_a = (pred_labels_a == labels_a).sum().item() / len(labels_a)
                batch_accuracy_b = (pred_labels_b == labels_b).sum().item() / len(labels_b)
                batch_accuracy_star = (pred_labels_star == labels_s).sum().item() / len(labels_s)
                # Update accuracy
                sum_accuracy_a[split] += batch_accuracy_a
                sum_accuracy_b[split] += batch_accuracy_b
                sum_accuracy_star[split] += batch_accuracy_star
        # Compute epoch loss/accuracy for each split
        epoch_loss = {split: sum_loss[split] / (len(loaders_a[split]) + len(loaders_b[split]) + len(loaders_star[split]))
                      for split in ["train", "val", "test"]}
        epoch_accuracy_a = {split: sum_accuracy_a[split] / len(loaders_a[split])
                            for split in ["train", "val", "test"]}
        epoch_accuracy_b = {split: sum_accuracy_b[split] / len(loaders_b[split])
                            for split in ["train", "val", "test"]}
        epoch_accuracy_star = {split: sum_accuracy_star[split] / len(loaders_star[split])
                               for split in ["train", "val", "test"]}
        # Store params at the best validation accuracy (tracked on the star branch)
        if save_param and epoch_accuracy_star["val"] > best_val_accuracy:
            torch.save(net.state_dict(), f"{model_name}_best_val.pth")
            best_val_accuracy = epoch_accuracy_star["val"]
        # Update history
        for split in ["train", "val", "test"]:
            history_loss[split].append(epoch_loss[split])
            history_accuracy_a[split].append(epoch_accuracy_a[split])
            history_accuracy_b[split].append(epoch_accuracy_b[split])
            history_accuracy_star[split].append(epoch_accuracy_star[split])
        # Print info
        print(f"Epoch {epoch + 1}:")
        for split in ["train", "val", "test"]:
            print(f"  {split} loss = {epoch_loss[split]:.4f},",
                  f"accuracy A = {epoch_accuracy_a[split]:.4f},",
                  f"accuracy B = {epoch_accuracy_b[split]:.4f},",
                  f"accuracy star = {epoch_accuracy_star[split]:.4f}")
        print("\n")
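The call itself then looks like this (a sketch; dev and the epoch count are placeholders, and the loader dicts are the ones from the snippet above):

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# loaders_a is passed twice on purpose: the star branch reuses the first half of MNIST
train(model_combo, (loaders_a, loaders_b, loaders_a), optimizer, criterion,
      epochs=20, dev=dev)
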
But I got this error:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 1, 3, 3], but got 2-dimensional input of size [128, 2048] instead