I am using resnet50 and need to insert an additional model into it. The new model uses 3D convolutions, so I first have to reshape the output of an intermediate layer, pass it to the new model, reshape the new model's output back, and then pass it on to the remaining layers of resnet50. I decided to use hooks for this. Should I modify the network as in the code snippet below? Will PyTorch backprop work as usual, i.e. will it automatically compute the gradients for NEWMODEL as if it were part of the original model?
def new_model_insertion(name):
    def hook(model, input, output):
        return NEWMODEL(output)
    return hook
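To make the question concrete, here is a minimal sketch of what I have in mind. Everything in it is an assumption for illustration: `NewModel` is a hypothetical stand-in for NEWMODEL (a single `nn.Conv3d` that keeps the tensor shape), and a small `nn.Sequential` stands in for resnet50 so it runs quickly. It relies on the documented behaviour that a forward hook which returns a value replaces the module's output, and then checks whether gradients reach both the inserted model and the layers before it.

```python
import torch
import torch.nn as nn


class NewModel(nn.Module):
    """Hypothetical stand-in for NEWMODEL: a shape-preserving 3D conv."""

    def __init__(self):
        super().__init__()
        self.conv3d = nn.Conv3d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        # (N, C, H, W) -> (N, 1, C, H, W): treat channels as the depth axis
        y = self.conv3d(x.unsqueeze(1))
        # (N, 1, C, H, W) -> (N, C, H, W) so later 2D layers accept it
        return y.squeeze(1)


# Small stand-in for resnet50 (assumption, only to keep the sketch fast)
backbone = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)
new_model = NewModel()


def new_model_insertion(extra):
    def hook(module, inputs, output):
        # Returning a value from a forward hook replaces the module's output
        return extra(output)
    return hook


# Hook the intermediate layer whose output should go through the new model
handle = backbone[1].register_forward_hook(new_model_insertion(new_model))

x = torch.randn(2, 3, 16, 16)
loss = backbone(x).sum()
loss.backward()

# Did autograd reach the inserted model and the layer before the hook?
grad_reached_new_model = all(p.grad is not None for p in new_model.parameters())
grad_reached_early_layer = backbone[0].weight.grad is not None
print(grad_reached_new_model, grad_reached_early_layer)

handle.remove()  # detach the hook when it is no longer needed
```

One thing I noticed with this setup: the parameters of the inserted model are not registered in the hooked network, so they do not show up in `backbone.parameters()` or its `state_dict()`. If that is right, I would have to pass `new_model.parameters()` to the optimizer and save its weights separately.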
Thanks a lot.