Assume we have two PyTorch models, M1 and M2. Each is made of two fully connected layers, and the first layer is shared between the two models.

The models are trained in parallel. A batch of data is fed into the shared first layer, and its output is then fed into the second layer of each network to produce o1 and o2 (i.e., the outputs of the first and second networks).

We define the loss function L to be:

L1 = Criterion(o1, y1)

L2 = Criterion(o2, y2)

L = L1 + L2 + Diff_grad

y1 and y2 are the true labels for the two networks and Criterion() is binary cross entropy.

Diff_grad is the squared difference between the gradients of the two losses with respect to the shared first layer. More specifically, it is: (gradient_of_L1_wrt_layer1 - gradient_of_L2_wrt_layer1)^2.
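
Written out as a formula, the objective might look like the following. This is only a sketch: the symbol W for the first-layer parameters is introduced here for illustration, and it assumes the squared difference is summed over all entries of W (the reduction is not fixed by the description above).

```latex
L = L_1 + L_2 + \left\lVert \nabla_{W} L_1 - \nabla_{W} L_2 \right\rVert_2^2
```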

In plain language, I am trying to minimize the two losses, as well as the difference between the gradients of the two losses with respect to the first layer.
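
To make the setup concrete, here is a minimal sketch of the two models and the two losses. The layer sizes, the ReLU activation, the random data, and the names `shared`, `head1`, and `head2` are all assumptions made for illustration, and `BCEWithLogitsLoss` is used as one possible form of binary cross entropy. The Diff_grad term is deliberately left out, since that is the part in question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
in_features, hidden, batch = 4, 8, 16

shared = nn.Linear(in_features, hidden)  # first layer, shared by M1 and M2
head1 = nn.Linear(hidden, 1)             # second layer of M1
head2 = nn.Linear(hidden, 1)             # second layer of M2
criterion = nn.BCEWithLogitsLoss()       # binary cross entropy on raw logits

# Random stand-ins for a batch of inputs and the two sets of binary labels.
x = torch.randn(batch, in_features)
y1 = torch.randint(0, 2, (batch, 1)).float()
y2 = torch.randint(0, 2, (batch, 1)).float()

h = torch.relu(shared(x))  # output of the shared first layer
o1 = head1(h)              # output of the first network
o2 = head2(h)              # output of the second network

L1 = criterion(o1, y1)
L2 = criterion(o2, y2)
# Diff_grad is not included here; L would be L1 + L2 + Diff_grad.
```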

Can anyone tell me how to do this (preferably with a piece of code)? Thanks.

PS: The actual task I am working on is bigger than this: the networks are much bigger, and they have more shared modules. Also, the objective is not the “difference between the gradients”; it is a more complicated function. The example above is for my understanding.

I have written the following snippet (it is a bit different from what I explained earlier, but it still follows the same structure).
The variable x0 is shared between loss_1 and loss_2. In the end, the trained variables are x0, x1, and x2.

The variable “reg” is what I am trying to use in the final loss function, but it looks like it is not affecting the solution: even if I drop it from the final sum (loss_1 + loss_2 + reg), nothing changes.

import torch
from torch.autograd import Variable
from torch.optim import SGD

inp = torch.Tensor([10.0, 15.0])
# x0 is shared between the two losses; x1 and x2 belong to one loss each
x0 = Variable(torch.Tensor([324.0, 456.0]), requires_grad=True)
x1 = Variable(torch.Tensor([56.0, 4545.0]), requires_grad=True)
x2 = Variable(torch.Tensor([1232.0, 89.0]), requires_grad=True)
target_1 = torch.Tensor([2.0])
target_2 = torch.Tensor([5.0])
optimizer = SGD([x0, x1, x2], lr=0.1)

def zero_grad_(x0, x1, x2):
    # reset the accumulated gradients of all three variables
    for x in (x0, x1, x2):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()

for ind in range(0, 50):
    zero_grad_(x0, x1, x2)
    loss_1_temp_1 = torch.sum(torch.relu(inp * x0) * x1)
    loss_1 = torch.pow(loss_1_temp_1 - target_1, 2)
    loss_2_temp_1 = torch.sum(torch.relu(inp * x0) * x2)
    loss_2 = torch.pow(loss_2_temp_1 - target_2, 2)
    # gradients of each loss with respect to the shared variable x0
    loss_1_grad_wrt_x0 = torch.autograd.grad(loss_1, x0, retain_graph=True)[0]
    loss_2_grad_wrt_x0 = torch.autograd.grad(loss_2, x0, retain_graph=True)[0]
    # mean squared difference between the two gradients
    reg = torch.mean(torch.pow(loss_1_grad_wrt_x0 - loss_2_grad_wrt_x0, 2))
    loss = loss_1 + loss_2 + reg
    loss.backward()
    optimizer.step()
    # print(x0, x1, x2)
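
One possible explanation for why reg has no effect: by default, `torch.autograd.grad` uses `create_graph=False`, so the gradients it returns are not part of the computation graph, and a term built from them acts as a constant during `backward()`. The sketch below compares the two settings on a scaled-down version of the snippet above. The small input values and the helper names `losses` and `grad_penalty` are assumptions introduced here, chosen only to keep the toy problem numerically tame.

```python
import torch

# Small hypothetical values (not the ones from the snippet above).
inp = torch.tensor([1.0, 2.0])
x0 = torch.tensor([0.5, 0.3], requires_grad=True)
x1 = torch.tensor([0.2, 0.4], requires_grad=True)
x2 = torch.tensor([0.6, 0.1], requires_grad=True)
target_1 = torch.tensor([2.0])
target_2 = torch.tensor([5.0])

def losses():
    # Same structure as the snippet: x0 is shared by both losses.
    hidden = torch.relu(inp * x0)
    loss_1 = torch.pow(torch.sum(hidden * x1) - target_1, 2)
    loss_2 = torch.pow(torch.sum(hidden * x2) - target_2, 2)
    return loss_1, loss_2

def grad_penalty(create_graph):
    # Mean squared difference between the two gradients w.r.t. x0.
    loss_1, loss_2 = losses()
    g1 = torch.autograd.grad(loss_1, x0, retain_graph=True,
                             create_graph=create_graph)[0]
    g2 = torch.autograd.grad(loss_2, x0, retain_graph=True,
                             create_graph=create_graph)[0]
    return torch.mean(torch.pow(g1 - g2, 2))

reg_detached = grad_penalty(create_graph=False)
reg_tracked = grad_penalty(create_graph=True)

print(reg_detached.requires_grad)  # False: treated as a constant
print(reg_tracked.requires_grad)   # True: can be differentiated further

# With create_graph=True the penalty itself produces gradients for x0.
reg_tracked.backward()
print(x0.grad)
```

If this is indeed the cause, passing `create_graph=True` to both `torch.autograd.grad` calls in the training loop might be enough for reg to influence the update.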