I have a standard multi-task DNN model with 2 shared layers followed by a single layer for each of two separate but related regression tasks (one main task, one auxiliary task). My question is how to selectively backpropagate gradients from the auxiliary task based on some condition. A relevant paper in this space is "Adapting Auxiliary Losses Using Gradient Similarity", in which the authors backpropagate gradients from the auxiliary task only when they point in a similar direction to those of the main task.
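To make the idea concrete, here is a minimal sketch of one way such gating could look. It is written against a hypothetical toy model, not the code below, and it is not claimed to be the paper's exact procedure. The names `shared`, `head_main`, `head_aux`, the random data, and the `cos >= 0` threshold are all assumptions made for illustration. It uses `torch.autograd.grad` to obtain each loss's gradient with respect to the shared parameters separately, compares them with a cosine similarity, and only then writes the chosen gradient into `.grad`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy stand-ins for the shared trunk and the two task heads.
shared = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 16), nn.Tanh())
head_main = nn.Linear(16, 1)
head_aux = nn.Linear(16, 1)

# Random placeholder data (assumption: batch of 32, 4 features, 1 target per task).
x = torch.randn(32, 4)
y_main = torch.randn(32, 1)
y_aux = torch.randn(32, 1)

features = shared(x)
loss_main = F.mse_loss(head_main(features), y_main)
loss_aux = F.mse_loss(head_aux(features), y_aux)

shared_params = list(shared.parameters())

# Per-task gradients w.r.t. the shared parameters only.
# retain_graph=True keeps the graph alive for the later calls.
g_main = torch.autograd.grad(loss_main, shared_params, retain_graph=True)
g_aux = torch.autograd.grad(loss_aux, shared_params, retain_graph=True)

# Flatten both gradient sets into single vectors and compare their directions.
flat_main = torch.cat([g.reshape(-1) for g in g_main])
flat_aux = torch.cat([g.reshape(-1) for g in g_aux])
cos = F.cosine_similarity(flat_main, flat_aux, dim=0)

# Assumed gating rule: keep the auxiliary gradient only if it does not
# point away from the main-task gradient.
use_aux = bool(cos >= 0)
for p, gm, ga in zip(shared_params, g_main, g_aux):
    p.grad = gm + ga if use_aux else gm.clone()

# Each head only ever receives the gradient of its own loss.
for head, loss in ((head_main, loss_main), (head_aux, loss_aux)):
    head_params = list(head.parameters())
    head_grads = torch.autograd.grad(loss, head_params, retain_graph=True)
    for p, g in zip(head_params, head_grads):
        p.grad = g

print("cosine similarity:", cos.item(), "| auxiliary gradient used:", use_aux)
```

With the `.grad` fields filled in this way, an optimizer constructed over all three modules' parameters can call `step()` as usual. Because `torch.autograd.grad` returns the gradients instead of accumulating them, the two tasks' contributions never get mixed before the comparison.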
Here is the implementation I have so far:
Multi-Task DNN Model
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class Net(nn.Module):
    def __init__(self, inputsize, hiddensize, outputsize_main, outputsize_aux):
        super(Net, self).__init__()
        # Layers shared by both tasks
        self.fc_shared1 = nn.Linear(inputsize, hiddensize)
        self.fc_shared2 = nn.Linear(hiddensize, hiddensize)
        # One task-specific output layer per task
        self.fc_main1 = nn.Linear(hiddensize, outputsize_main)
        self.fc_aux1 = nn.Linear(hiddensize, outputsize_aux)
        self.tanh1 = nn.Tanh()
        self.tanh2 = nn.Tanh()

    def forward(self, x):
        out = self.fc_shared1(x)
        out = self.tanh1(out)
        out = self.fc_shared2(out)
        out = self.tanh2(out)
        # Main
        outmain = self.fc_main1(out)
        # Aux
        outaux = self.fc_aux1(out)
        return outmain, outaux
Simple Hook Script (borrowed from here)
# A simple hook class that records the input and output of a layer during the forward/backward pass
class Hook():
    def __init__(self, module, name, backward=False):
        self.name = name  # Name of the layer. Useful for debugging.
        if not backward:
            self.hook = module.register_forward_hook(self.hook_fn)
        else:
            # Note: register_backward_hook is deprecated in recent PyTorch
            # releases in favor of register_full_backward_hook.
            self.hook = module.register_backward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Forward hook: called with (module, input, output).
        # Backward hook: called with (module, grad_input, grad_output).
        self.input = input
        self.output = output

    def close(self):
        self.hook.remove()
Hyperparameters
INPUTSIZE=4
OUTPUTSIZE_MAIN=1
OUTPUTSIZE_AUX=1
HIDDENSIZE=16
BATCHSIZE=32
NUMEPOCHS=100
learningrate=0.001
model=Net(INPUTSIZE,HIDDENSIZE,OUTPUTSIZE_MAIN,OUTPUTSIZE_AUX)
optimizer=torch.optim.Adam(model.parameters(),lr=learningrate)
criterion=nn.MSELoss()
# trainX, testX are n X 4 matrices containing input features where `n` is the number of instances.
# trainY, testY are n X 1 arrays containing target values for the main task.
# trainAuxY , testAuxY are also n X 1 arrays containing the auxiliary variable target values.
train_dataset = SampleDataLoader(trainX,trainY,trainAuxY)
train_loader = DataLoader(train_dataset,batch_size=BATCHSIZE)
test_dataset = SampleDataLoader(testX,testY,testAuxY)
test_loader = DataLoader(test_dataset,batch_size=BATCHSIZE)
Training Loop
hookF = [Hook(layer, name) for name, layer in list(model._modules.items())]
hookB = [Hook(layer, name, backward=True) for name, layer in list(model._modules.items())]
# `device` is not defined in the snippets above; reuse whichever device the model is on.
device = next(model.parameters()).device
losses = list()
for ep in range(NUMEPOCHS):
    losses_epoch = list()
    for batchX, batchY, batchAuxY in train_loader:
        optimizer.zero_grad()
        # Plain tensors suffice; torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
        _X = batchX.float().to(device)
        _Y = batchY.float().to(device)
        _AuxY = batchAuxY.float().to(device)
        outMain, outAux = model(_X)
        lossMain = criterion(outMain, _Y)
        lossAux = criterion(outAux, _AuxY)
        loss = lossMain + lossAux
        losses_epoch.append(loss.item())
        loss.backward()
        optimizer.step()
        print('***'*3 + ' Forward Hooks Inputs & Outputs ' + '***'*3)
        for hook in hookF:
            print(hook.name)
            print("Num Items Input Forward= {} ".format(len(hook.input)))
            print("Size of Each Input = {}".format([i.size() for i in hook.input]))
            print("Num Items Output Forward= {}".format(len(hook.output)))
            print("Size of Each Output = {}".format([o.size() for o in hook.output]))
            print('---'*17)
        print('\n')
        print('***'*3 + ' Backward Hooks Inputs & Outputs ' + '***'*3)
        for hook in hookB:
            print(hook.name)
            print("Num Items Input Backward= {} ".format(len(hook.input)))
            print("Size of Each Input = {}".format([i.size() for i in hook.input]))
            print("Num Items Output Backward= {}".format(len(hook.output)))
            print("Size of Each Output = {}".format([o.size() for o in hook.output]))
            print('---'*17)
        break  # stop after the first batch while inspecting the hooks
    print("=============================================\n\n")
    losses.append(np.mean(losses_epoch))
I have the following questions that I haven’t found a clear / convincing answer to online:
- What is the difference between using `loss.backward()` and `out.backward()`? Here `out` (in the model above, `outmain` and `outaux`) is the output of the call to `model.forward()`. I have read (and may have misunderstood) that `loss.backward()` doesn't work with hooks, although the hooks threw no errors when I ran the code above with `loss.backward()`. So is it okay to use `loss.backward()` with hooks?
- How can I isolate the two specific sets of gradients to be compared? In this case these are `d(lossMain)/d(fc_shared1)` vs. `d(lossAux)/d(fc_shared1)`, and `d(lossMain)/d(fc_shared2)` vs. `d(lossAux)/d(fc_shared2)`. As I see it, my current implementation will only return `d(loss)/d(fc_shared1)` and `d(loss)/d(fc_shared2)`. Should I keep the losses separate and invoke `lossMain.backward()` and `lossAux.backward()` separately, or is there some other solution? If I call `backward()` on each loss separately, would the way the losses are backpropagated through the network change in any way compared to the single-loss case, i.e. `loss = lossMain + lossAux`?
- Once the gradients are isolated and I can compare them, will zeroing out `d(lossAux)/d(fc_shared1)` and `d(lossAux)/d(fc_shared2)` be enough to ensure that the gradient from the auxiliary loss won't propagate to the shared layer(s)? The way I envision doing this is zeroing out `fc_shared2.autograd.grad` or `fc_shared1.autograd.grad` before the call to `optimizer.step()` whenever they satisfy some condition.
- For now, I am just trying to print the inputs and outputs of the forward and backward hook functions to better understand them, but I am having a hard time interpreting them.
For example, here is the result of printing both the inputs and outputs of the forward hook function for each layer:
********* Forward Hooks Inputs & Outputs *********
fc_shared1
Num Items Input Forward= 1
Size of Each Input = [torch.Size([14, 4])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
---------------------------------------------------
fc_shared2
Num Items Input Forward= 1
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
---------------------------------------------------
fc_main1
Num Items Input Forward= 1
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1])]
---------------------------------------------------
fc_aux1
Num Items Input Forward= 1
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1]), torch.Size([1])]
---------------------------------------------------
tanh1
Num Items Input Forward= 1
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
---------------------------------------------------
tanh2
Num Items Input Forward= 1
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Forward= 14
Size of Each Output = [torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
---------------------------------------------------
********* Backward Hooks Inputs & Outputs *********
fc_shared1
Num Items Input Backward= 3
Size of Each Input = [torch.Size([16]), -1, torch.Size([4, 16])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 16])]
---------------------------------------------------
fc_shared2
Num Items Input Backward= 3
Size of Each Input = [torch.Size([16]), torch.Size([14, 16]), torch.Size([16, 16])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 16])]
---------------------------------------------------
fc_main1
Num Items Input Backward= 3
Size of Each Input = [torch.Size([1]), torch.Size([14, 16]), torch.Size([16, 1])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 1])]
---------------------------------------------------
fc_aux1
Num Items Input Backward= 3
Size of Each Input = [torch.Size([1]), torch.Size([14, 16]), torch.Size([16, 1])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 1])]
---------------------------------------------------
tanh1
Num Items Input Backward= 1
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 16])]
---------------------------------------------------
tanh2
Num Items Input Backward= 1
Size of Each Input = [torch.Size([14, 16])]
Num Items Output Backward= 1
Size of Each Output = [torch.Size([14, 16])]
---------------------------------------------------
The `Num Items Output Forward` and `Num Items Input Forward` values, i.e. the sizes of the tuples stored in `input` and `output`, seem to be some function of `BATCHSIZE` (the number changes when I change `BATCHSIZE`), but I wasn’t able to find out how they are calculated.
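A small standalone check may help narrow this down. The sketch below uses a hypothetical single `nn.Linear` layer and a batch of 14 rows (both assumptions, chosen to mirror the shapes printed above). It records what a forward hook receives: `input` arrives as a tuple of the positional arguments to `forward()`, while `output` for a linear layer is a single tensor, so `len()` and iteration on it run over its first (batch) dimension.

```python
import torch
import torch.nn as nn

captured = {}

def forward_hook(module, inputs, output):
    # inputs: tuple of the positional arguments passed to forward()
    # output: whatever forward() returned (one tensor for nn.Linear)
    captured["inputs"] = inputs
    captured["output"] = output

layer = nn.Linear(4, 16)  # hypothetical standalone layer
handle = layer.register_forward_hook(forward_hook)
layer(torch.randn(14, 4))  # batch of 14 rows, 4 features each
handle.remove()

# The input is a tuple holding one tensor of shape [14, 4].
print(type(captured["inputs"]).__name__, len(captured["inputs"]))
print([t.shape for t in captured["inputs"]])

# The output is a tensor of shape [14, 16]; len() reports its first dimension.
print(type(captured["output"]).__name__, len(captured["output"]))
print(captured["output"].shape)

# Iterating over the tensor yields one row of shape [16] per batch element,
# which is what a list comprehension over `hook.output` would collect.
row_shapes = [row.shape for row in captured["output"]]
print(len(row_shapes), row_shapes[0])
```

If this matches what happens in the training loop, printing `hook.output.size()` directly, rather than iterating over `hook.output`, would show the full `[batch, features]` shape.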
Also, if someone could clearly explain exactly what the `input`, `output` and `grad_input`, `grad_output` parameters represent in the forward and backward hooks respectively, it would help me greatly in working with hooks in PyTorch.
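For the backward side, the following sketch may be a useful reference point. It assumes a PyTorch version that provides `register_full_backward_hook` (the documented replacement for the deprecated `register_backward_hook` used above), and it uses a hypothetical standalone layer and `out.sum()` as a stand-in loss. With the full backward hook, `grad_output` holds the gradient with respect to the module's output and `grad_input` holds the gradient with respect to the module's input, each as a tuple.

```python
import torch
import torch.nn as nn

grads = {}

def backward_hook(module, grad_input, grad_output):
    # grad_output: tuple of gradients w.r.t. the module's outputs
    # grad_input:  tuple of gradients w.r.t. the module's inputs
    grads["grad_input"] = grad_input
    grads["grad_output"] = grad_output

layer = nn.Linear(4, 16)  # hypothetical standalone layer
handle = layer.register_full_backward_hook(backward_hook)

# requires_grad=True so that a gradient w.r.t. the input actually exists.
x = torch.randn(14, 4, requires_grad=True)
out = layer(x)
out.sum().backward()  # stand-in scalar loss
handle.remove()

# One entry per module output: same shape as `out`, i.e. [14, 16].
print([g.shape for g in grads["grad_output"]])
# One entry per module input: same shape as `x`, i.e. [14, 4].
print([g.shape for g in grads["grad_input"]])
# The hook's grad_input matches what autograd stores in x.grad.
print(torch.allclose(grads["grad_input"][0], x.grad))
```

The tuples reported by the deprecated `register_backward_hook` in the output above have a different layout (three entries for the linear layers), so the shapes here are not expected to line up one-to-one with that printout.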