Controlling Gradient Propagation to Shared Layers in a Multi-task Model

I have a standard multi-task DNN: two shared layers followed by a single layer for each of two separate but related regression tasks (one main task, one auxiliary task). My question is about how to selectively backpropagate gradients from the auxiliary task based on some condition. A relevant paper in this space is "Adapting Auxiliary Losses Using Gradient Similarity", in which the authors backpropagate gradients from the auxiliary task only when they point in a similar direction to the gradients of the main task.
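
As I understand that paper, the rule is roughly the following (my paraphrase), applied per shared parameter:

g_shared = g_main + g_aux   if cosine_similarity(g_main, g_aux) >= 0
g_shared = g_main           otherwise

where g_main and g_aux are the gradients of the main and auxiliary losses with respect to that shared parameter.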

Here is the implementation I have so far:

Multi-Task DNN Model

import torch
import torch.nn as nn

class Net(nn.Module):
  def __init__(self, inputsize, hiddensize, outputsize_main, outputsize_aux):
    super(Net, self).__init__()
    # Shared trunk
    self.fc_shared1 = nn.Linear(inputsize, hiddensize)
    self.fc_shared2 = nn.Linear(hiddensize, hiddensize)
    # Task-specific heads
    self.fc_main1 = nn.Linear(hiddensize, outputsize_main)
    self.fc_aux1 = nn.Linear(hiddensize, outputsize_aux)
    self.tanh1 = nn.Tanh()
    self.tanh2 = nn.Tanh()

  def forward(self,x):
    out = self.fc_shared1(x)
    out = self.tanh1(out)
    out = self.fc_shared2(out)
    out = self.tanh2(out)

    #Main
    outmain = self.fc_main1(out)
    #Aux
    outaux = self.fc_aux1(out)

    return outmain, outaux
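
For reference, a quick sanity check of the output shapes (assuming the imports above):

net = Net(4, 16, 1, 1)
outmain, outaux = net(torch.randn(8, 4))
print(outmain.shape, outaux.shape)  # torch.Size([8, 1]) torch.Size([8, 1])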

Simple Hook Script (borrowed from here)

# A simple hook class that returns the input and output of a layer during forward/backward pass
class Hook():
    def __init__(self, module,name, backward=False):
        self.name = name  #Name of the layer. Useful for debugging.
        if backward==False:
            self.hook = module.register_forward_hook(self.hook_fn)
        else:
            self.hook = module.register_backward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.input = input
        self.output = output
    def close(self):
        self.hook.remove()

Hyperparameters

INPUTSIZE=4
OUTPUTSIZE_MAIN=1
OUTPUTSIZE_AUX=1
HIDDENSIZE=16
BATCHSIZE=32
NUMEPOCHS=100
learningrate=0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net(INPUTSIZE, HIDDENSIZE, OUTPUTSIZE_MAIN, OUTPUTSIZE_AUX).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learningrate)
criterion=nn.MSELoss()

# trainX, testX are n X 4 matrices containing input features where `n` is the number of instances.
# trainY, testY are n X 1 arrays containing target values for the main task.
# trainAuxY , testAuxY are also n X 1 arrays containing the auxiliary variable target values.

train_dataset = SampleDataLoader(trainX,trainY,trainAuxY)
train_loader = DataLoader(train_dataset,batch_size=BATCHSIZE)
test_dataset = SampleDataLoader(testX,testY,testAuxY)
test_loader = DataLoader(test_dataset,batch_size=BATCHSIZE)
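
Here SampleDataLoader is just a thin Dataset wrapper that returns (features, main target, aux target) triples, roughly along these lines (a sketch, not my exact class):

from torch.utils.data import Dataset, DataLoader

class SampleDataLoader(Dataset):
    def __init__(self, X, Y, AuxY):
        self.X = X
        self.Y = Y
        self.AuxY = AuxY

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # One (features, main target, aux target) triple per instance.
        return self.X[idx], self.Y[idx], self.AuxY[idx]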

Training Loop

hookF = [Hook(layer, name) for name, layer in model.named_children()]
hookB = [Hook(layer, name, backward=True) for name, layer in model.named_children()]

losses=list()
for ep in range(NUMEPOCHS):
  losses_epoch=list()
  for batchX,batchY,batchAuxY in train_loader:
    optimizer.zero_grad()
    _X = batchX.float().to(device)
    _Y = batchY.float().to(device)
    _AuxY = batchAuxY.float().to(device)

    outMain,outAux=model(_X)
    lossMain = criterion(outMain,_Y)
    lossAux = criterion(outAux,_AuxY)
    loss = lossMain + lossAux
    losses_epoch.append(loss.item())
    loss.backward()
    optimizer.step()
  
  print('***'*3+'  Forward Hooks Inputs & Outputs  '+'***'*3)
  for hook in hookF:
      print(hook.name)            
      print("Num Items Input Forward= {} ".format(len(hook.input)))
      print("Size of Each Input = {}".format([i.size() for i in hook.input]))          
      print("Num Items Output Forward= {}".format(len(hook.output)))         
      print("Size of Each Output = {}".format([o.size() for o in hook.output]))
      print('---'*17)
  
  print('\n')
  print('***'*3+'  Backward Hooks Inputs & Outputs  '+'***'*3)
  for hook in hookB: 
      print(hook.name)            
      print("Num Items Input Backward= {} ".format(len(hook.input)))
      print("Size of Each Input = {}".format([i.size() for i in hook.input]))          
      print("Num Items Output Backward= {}".format(len(hook.output)))         
      print("Size of Each Output = {}".format([o.size() for o in hook.output]))
      print('---'*17)
  print("=============================================\n\n")
  losses.append(np.mean(losses_epoch))

I have the following questions, which I haven't found clear or convincing answers to online:

  1. What is the difference between using loss.backward() vs. out.backward()? Here out (in the above model, outmain and outaux) is the output of the call to model.forward(). I have read (and may have misunderstood) that loss.backward() doesn't work with hooks, although I didn't see any errors being thrown by the hooks when I ran the code above with loss.backward(). So is it okay to use loss.backward() with hooks?

  2. How can I isolate the two specific sets of gradients to be compared, i.e. d(lossMain)/d(fc_shared1) vs. d(lossAux)/d(fc_shared1), and d(lossMain)/d(fc_shared2) vs. d(lossAux)/d(fc_shared2)? As I see it, my current implementation only gives me d(loss)/d(fc_shared1) and d(loss)/d(fc_shared2). Should I keep the losses separate and invoke lossMain.backward() and lossAux.backward() separately, or is there some other solution (a rough sketch of what I'm considering is at the end of the post)? If I do call backward() separately on each loss, would the way the losses are backpropagated through the network change in any way compared to the single combined loss, i.e. loss = lossMain + lossAux?

  3. Once the gradients are isolated and I am able to compare them, will zeroing out d(lossAux)/d(fc_shared1) and d(lossAux)/d(fc_shared2) be enough to ensure that the gradient from the auxiliary loss doesn't propagate to the shared layer(s)? The way I envision doing this is zeroing out (or overwriting) fc_shared1.weight.grad / fc_shared2.weight.grad (and the corresponding .bias.grad) before the call to optimizer.step(), whenever those gradients satisfy some condition (again, see the sketch at the end of the post).

  4. For now, I am just trying to print out the inputs and outputs of the forward and backward hook functions to better understand them, but I am having a hard time interpreting them.

For example, here is the result of printing the inputs and outputs of both the forward and backward hooks for each layer:

*********  Forward Hooks Inputs & Outputs  *********
fc_shared1
Num Items Input Forward= 1 
Size of Each Input = [torch.Size([14, 4])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
---------------------------------------------------
fc_shared2
Num Items Input Forward= 1 
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
---------------------------------------------------
fc_main1
Num Items Input Forward= 1 
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1])]
---------------------------------------------------
fc_aux1
Num Items Input Forward= 1 
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1])]
---------------------------------------------------
tanh1
Num Items Input Forward= 1 
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
---------------------------------------------------
tanh2
Num Items Input Forward= 1 
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
---------------------------------------------------


*********  Backward Hooks Inputs & Outputs  *********
fc_shared1
Num Items Input Backward= 3 
Size of Each Input = [torch.Size([16]), -1, torch.Size([4, 16])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 16])]
---------------------------------------------------
fc_shared2
Num Items Input Backward= 3 
Size of Each Input = [torch.Size([16]), torch.Size([14, 16]), torch.Size([16, 16])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 16])]
---------------------------------------------------
fc_main1
Num Items Input Backward= 3 
Size of Each Input = [torch.Size([1]), torch.Size([14, 16]), torch.Size([16, 1])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 1])]
---------------------------------------------------
fc_aux1
Num Items Input Backward= 3 
Size of Each Input = [torch.Size([1]), torch.Size([14, 16]), torch.Size([16, 1])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 1])]
---------------------------------------------------
tanh1
Num Items Input Backward= 1 
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 16])]
---------------------------------------------------
tanh2
Num Items Input Backward= 1 
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 16])]
---------------------------------------------------

The Num Items Output Forward and Num Items Input Forward counts (i.e. the lengths of whatever is stored in input and output) seem to be some function of BATCHSIZE: the numbers change when I change BATCHSIZE, but I wasn't able to work out how they are calculated.

Also, if someone could clearly explain exactly what the input and output parameters of the forward hooks, and the grad_input and grad_output parameters of the backward hooks, actually contain, it would help me greatly in my task of working with hooks in PyTorch.
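
To make questions 2 and 3 more concrete, here is the kind of thing I currently have in mind. This is an untested sketch: it isolates the per-task gradients with torch.autograd.grad instead of hooks, and names like shared_params and grads_main are my own, not from any library.

import torch.nn.functional as F

# _X, _Y, _AuxY are one batch, prepared exactly as in the training loop above.
shared_params = list(model.fc_shared1.parameters()) + list(model.fc_shared2.parameters())
main_head_params = list(model.fc_main1.parameters())
aux_head_params = list(model.fc_aux1.parameters())

optimizer.zero_grad()
outMain, outAux = model(_X)
lossMain = criterion(outMain, _Y)
lossAux = criterion(outAux, _AuxY)

# Question 2: per-task gradients w.r.t. the shared layers, kept separate.
grads_main = torch.autograd.grad(lossMain, shared_params, retain_graph=True)
grads_aux = torch.autograd.grad(lossAux, shared_params, retain_graph=True)

# Question 3: only let the auxiliary gradient reach a shared parameter when it
# points in a similar direction to the main gradient (cosine similarity >= 0).
for p, g_main, g_aux in zip(shared_params, grads_main, grads_aux):
    cos = F.cosine_similarity(g_main.flatten(), g_aux.flatten(), dim=0)
    p.grad = g_main + g_aux if cos >= 0 else g_main.clone()

# The task-specific heads still get their own gradients as usual.
grads_main_head = torch.autograd.grad(lossMain, main_head_params, retain_graph=True)
grads_aux_head = torch.autograd.grad(lossAux, aux_head_params)
for p, g in zip(main_head_params + aux_head_params, list(grads_main_head) + list(grads_aux_head)):
    p.grad = g

optimizer.step()

If something along these lines is reasonable, I would still like to understand how to get at the same per-task gradients through the hooks above.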