I have tried several ways to implement this in multi-gpus mode, one of them is not so elegent but works, which returns right intermedia feature map, however another one does’t work which returns None and false result.
I gauss the condition is revelent to the assignment self.target_outputs = output.detach()
in hook function. What variable the forward output is assigned to matters.
# the one that works
class Wrapper(torch.nn.Module):
def __init__(self, model):
super(Wrapper, self).__init__()
self.model = model
def f_hook(module, __, output):
module.register_buffer('target_outputs', output)
self.model.features[2].register_forward_hook(f_hook)
def forward(self,input):
self.model(input)
return self.model.features[2].target_outputs
def __repr__(self):
return "Wrappper"
# another one that not works. The implemetation and the result is the same as what is stated in question part.
class Wrapper1(torch.nn.Module):
def __init__(self, model):
super(Wrapper1, self).__init__()
self.model = model
self.target_outputs = None
def f_hook(_, __, output):
self.target_outputs = output.detach()
self.model.features[2].register_forward_hook(f_hook)
def forward(self,input):
self.model(input)
return self.target_outputs
def __repr__(self):
return "Wrappper1"
# test code
if __name__ == '__main__':
devices = [4,5]
model = vgg19().cuda(4)
model = model.cuda(4)
wrapper = Wrapper(model)
wrapper = wrapper.cuda(4)
input1 = torch.randn(60,3,224,224).fill_(0).cuda(4)
out1 = torch.nn.parallel.data_parallel(wrapper, input1, devices)
print(out1) if out1 is not None else None
# print a right feature map
input2 = torch.randn(60,3,224,224).fill_(1).cuda(4)
out2 = torch.nn.parallel.data_parallel(wrapper, input2, devices)
print(out2) if out2 is not None else None
# print a right feature map
model = vgg19().cuda(4)
model = model.cuda(4)
wrapper = Wrapper1(model)
wrapper = wrapper.cuda(4)
input1 = torch.randn(60,3,224,224).fill_(0).cuda(4)
out1 = torch.nn.parallel.data_parallel(wrapper, input1, devices)
print(out1) if out1 is not None else None
# print None
input2 = torch.randn(60,3,224,224).fill_(1).cuda(4)
out2 = torch.nn.parallel.data_parallel(wrapper, input2, devices)
print(out2) if out2 is not None else None
# print a false feature map, which corresponds to output of input1 rather than output of input2.
# This is what confuses me all the time.
Need help! Is there someone could explain why this happened ? Thanks in advance.