Weird output of forward hook when using multiple GPUs

I have tried several ways to implement this in multi-GPU mode. One of them is not so elegant but works and returns the right intermediate feature map; the other one doesn't work: it returns None on the first call and a wrong result afterwards.
I guess the problem is related to the assignment self.target_outputs = output.detach() in the hook function. Which object the forward output is assigned to seems to matter.
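
To rule out the hook mechanism itself, here is a minimal single-device sketch (a toy model of my own, not the vgg19 setup below) showing that capturing an intermediate output through a closure works fine without data_parallel:

import torch

toy = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)

captured = {}

def f_hook(module, _, output):
    # Closure over `captured`; with a single module object this
    # always holds the latest forward output.
    captured['feat'] = output.detach()

toy[1].register_forward_hook(f_hook)
toy(torch.randn(2, 8))
print(captured['feat'].shape)  # torch.Size([2, 16])

With a single device the closure approach behaves as expected, so the problem seems specific to data_parallel.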

# the one that works
import torch
from torchvision.models import vgg19  # assuming vgg19 comes from torchvision

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model

        def f_hook(module, __, output):
            # Attach the output to the hooked module itself as a buffer.
            module.register_buffer('target_outputs', output)
        self.model.features[2].register_forward_hook(f_hook)

    def forward(self, input):
        self.model(input)
        return self.model.features[2].target_outputs

    def __repr__(self):
        return "Wrapper"


# the other one, which does not work; its implementation and results match what is described above.
class Wrapper1(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper1, self).__init__()
        self.model = model
        self.target_outputs = None

        def f_hook(_, __, output):
            # Store the output on the wrapper via the closed-over `self`.
            self.target_outputs = output.detach()
        self.model.features[2].register_forward_hook(f_hook)

    def forward(self, input):
        self.model(input)
        return self.target_outputs

    def __repr__(self):
        return "Wrapper1"

# test code
if __name__ == '__main__':
    devices = [4, 5]

    model = vgg19().cuda(4)
    wrapper = Wrapper(model).cuda(4)
    input1 = torch.randn(60, 3, 224, 224).fill_(0).cuda(4)  # all-zeros batch
    out1 = torch.nn.parallel.data_parallel(wrapper, input1, devices)
    if out1 is not None:
        print(out1)
    # prints the right feature map
    input2 = torch.randn(60, 3, 224, 224).fill_(1).cuda(4)  # all-ones batch
    out2 = torch.nn.parallel.data_parallel(wrapper, input2, devices)
    if out2 is not None:
        print(out2)
    # prints the right feature map

    model = vgg19().cuda(4)
    wrapper = Wrapper1(model).cuda(4)
    input1 = torch.randn(60, 3, 224, 224).fill_(0).cuda(4)
    out1 = torch.nn.parallel.data_parallel(wrapper, input1, devices)
    if out1 is not None:
        print(out1)
    # prints nothing: out1 is None
    input2 = torch.randn(60, 3, 224, 224).fill_(1).cuda(4)
    out2 = torch.nn.parallel.data_parallel(wrapper, input2, devices)
    if out2 is not None:
        print(out2)
    # prints a wrong feature map: it corresponds to the output for input1
    # rather than input2. This is what confuses me all the time.
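
One thing I plan to check (a sketch, untested in the setup above): print the identity of the hooked module and the device of the output inside the hook. If data_parallel creates per-device replicas, each replica should report a different id() and device, while the `self` captured by Wrapper1's closure would still point at the original, un-replicated wrapper:

def debug_hook(module, _, output):
    # Hypothetical diagnostic hook: shows which module object and
    # which device each replica's forward actually runs on.
    print('module id:', id(module), 'output device:', output.device)

model.features[2].register_forward_hook(debug_hook)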

Need help! Could someone explain why this happens? Thanks in advance.
