Can operations inside a forward hook back-propagate gradients?

Suppose I have a network with three FC layers:

class Net(nn.Module):
    """Three fully connected layers; fc1 -> fc2 is also exposed as ``self.net``.

    ``fc3`` takes fc1's 10-dimensional output, so it can be applied to fc1's
    output as a second head.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 20)
        # Reuse the already-registered layers (not new copies) in the sequential path.
        self.net = nn.Sequential(self.fc1, self.fc2)
        self.fc3 = nn.Linear(10, 30)

And I register a forward hook for the first layer:

def hook_func(module, input, output):
    """Forward hook for ``net.fc1``: feed fc1's output through ``net.fc3`` and stash it.

    The hook returns None, so fc1's own output flows on unchanged.
    """
    # NOTE(review): `object` stands in for some user-owned container here; the
    # builtin `object` itself does not accept attribute assignment.
    object.feature = net.fc3(output)


net.fc1.register_forward_hook(hook_func)

Then, in the forward pass:

# Running fc1 fires hook_func, which stores fc3's output in object.feature.
net_out = net.fc2(net.fc1(input))
# loss1 is computed on the network output net_out (fc2 applied to fc1's output).
loss1 = L1(net_out, label1)
loss2 = L2(object.feature, label2)
loss = loss1 + loss2

My question is: when I call

loss.backward()  # loss = loss1 + loss2, so this back-propagates both terms

Can the gradient of loss2 back-propagate through the hook? In other words, is the gradient the same as with:

# Explicit version without a hook: fc3 is applied to fc1's output directly.
out1 = net.fc1(input)
out2 = net.fc2(out1)
out3 = net.fc3(out1)
loss1 = L1(out2, label1)
loss2 = L2(out3, label2)
loss = loss1 + loss2
loss.backward()

When I wrap the net with nn.parallel.DistributedDataParallel, I get an error:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable). (prepare_for_backward at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:518)

It seems that the output computed in the hook is not tracked by DDP. The error disappears when I set find_unused_parameters=True, but I do not know whether the gradients are back-propagated correctly.

I have tested it with the script below, and the conclusion is that operations inside a hook do back-propagate gradients.
The error mentioned in the question is caused by a small bug.

import torch
import torch.nn as nn


class Net(nn.Module):
    """A shared ``fc1`` trunk feeding two heads.

    ``forward`` returns ``(fc2(h), fc3(h))`` with ``h = fc1(x)``; the heads have
    20 and 30 output features respectively.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the order fc1, fc2, fc3.
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 20)
        self.fc3 = nn.Linear(10, 30)

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden), self.fc3(hidden)


class Warp:
    """Capture ``fc3(fc1(x))`` through a forward hook registered on ``fc1``.

    ``model`` may be a ``DistributedDataParallel`` wrapper, in which case the
    network is read from ``model.module``, or the bare network itself. The
    network must expose ``fc1``, ``fc2`` and ``fc3`` submodules.
    """

    def __init__(self, model):
        # Latest fc3 output written by the hook; empty until fc1 has run once.
        self.feature = []
        self.model = model
        # Handle of the hook registered by this object, kept so it can be replaced.
        self._hook_handle = None

    def _network(self):
        """Return the underlying network, unwrapping ``.module`` when present."""
        return getattr(self.model, "module", self.model)

    def register_hook(self):
        """Register the fc1 forward hook, replacing any hook this object registered before.

        Returns:
            The handle returned by ``register_forward_hook`` (call ``.remove()`` to detach).
        """
        if self._hook_handle is not None:
            # Avoid stacking duplicate hooks when called more than once.
            self._hook_handle.remove()

        def hook_func(module, input, output):
            # fc1 is an nn.Linear, so ``output`` is the output tensor itself,
            # not a tuple; indexing it with [0] would keep only the first
            # sample of the batch. Returning None leaves fc1's output unchanged.
            self.feature = self._network().fc3(output)

        self._hook_handle = self._network().fc1.register_forward_hook(hook_func)
        return self._hook_handle

    def step(self, x):
        """Run fc1 -> fc2 on ``x``; return ``(fc2 output, fc3 output captured by the hook)``."""
        # NOTE(review): this calls the submodules directly rather than the
        # wrapper's forward, so a DDP wrapper's forward-time bookkeeping is
        # skipped for this pass — confirm gradient sync behaves as intended.
        network = self._network()
        x1 = network.fc2(network.fc1(x))  # running fc1 fires the hook, setting self.feature
        x2 = self.feature
        return x1, x2


# situation 1:


def inter_main(local_rank):
    """Compare fc1's weight gradient from the DDP forward path with the hook path.

    Three forward/backward passes accumulate into ``fc1.weight.grad``. The first
    two go through the DDP wrapper's ``forward``. The third goes through
    ``Warp.step``, whose fc3 branch is produced by a forward hook on fc1. On
    rank 0 each pass's gradient is recovered by subtracting the previously
    accumulated totals, and the three are printed side by side.
    """
    print('start rank: ' + str(local_rank))
    device = torch.cuda.current_device()
    ddp_net = torch.nn.parallel.DistributedDataParallel(
        module=Net().cuda(),
        device_ids=[device],
        output_device=device,
        find_unused_parameters=True,
    )
    sample = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).cuda().float()
    warp = Warp(ddp_net)
    warp.register_hook()
    is_rank0 = local_rank == 0

    def forward_backward(forward_fn):
        # Sum of both head means; backward() accumulates into each .grad.
        head_a, head_b = forward_fn(sample)
        objective = head_a.mean() + head_b.mean()
        objective.backward()
        return objective

    # Baseline: two identical passes through the DDP wrapper.
    forward_backward(ddp_net)
    if is_rank0:
        grad_baseline = ddp_net.module.fc1.weight.grad.clone()

    forward_backward(ddp_net)
    if is_rank0:
        grad_baseline2 = ddp_net.module.fc1.weight.grad.clone() - grad_baseline

    # Hook path: the fc3 output comes from the hook registered on fc1.
    hook_objective = forward_backward(warp.step)
    if not is_rank0:
        return
    grad_hook = (
        ddp_net.module.fc1.weight.grad.clone() - grad_baseline - grad_baseline2
    )
    print(hook_objective)
    print(torch.stack((grad_hook, grad_baseline, grad_baseline2), dim=-1))


def inter_main_bk(local_rank):
    """Baseline only: one DDP forward/backward pass, then print fc1's weight gradient on rank 0."""
    print('start rank: ' + str(local_rank))
    device = torch.cuda.current_device()
    model = torch.nn.parallel.DistributedDataParallel(
        module=Net().cuda(),
        device_ids=[device],
        output_device=device,
        find_unused_parameters=True,
    )
    inputs = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).cuda().float()
    head_a, head_b = model(inputs)
    (head_a.mean() + head_b.mean()).backward()
    if local_rank != 0:
        return
    print(model.module.fc1.weight.grad)


def run(
        local_rank, num_proc, func, init_method, shard_id, num_shards, backend
):
    """
    Runs a function from a child process.

    Initializes the default process group, selects the CUDA device for this
    process, calls ``func(local_rank)``, and destroys the process group
    afterwards (even if ``func`` raises).

    Args:
        local_rank (int): rank of the current process on the current machine.
        num_proc (int): number of processes per machine.
        func (function): function to execute in each process; called as
            ``func(local_rank)``.
        init_method (string): method to initialize the distributed training.
            TCP initialization: requires a network address reachable from all
            processes, followed by the port.
            Shared file-system initialization: makes use of a file system that
            is shared and visible from all machines. The URL should start with
            file:// and contain a path to a non-existent file on a shared file
            system.
        shard_id (int): the rank of the current machine.
        num_shards (int): number of overall machines for the distributed
            training job.
        backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
            supported, each with different capabilities. Details can be found
            here:
            https://pytorch.org/docs/stable/distributed.html
    """
    # Global rank layout: machines are contiguous blocks of num_proc ranks.
    world_size = num_proc * num_shards
    rank = shard_id * num_proc + local_rank

    # Initialize the process group; any exception propagates with its original traceback.
    torch.distributed.init_process_group(
        backend=backend,
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    try:
        torch.cuda.set_device(local_rank)
        func(local_rank)
    finally:
        # Tear down the process group so it is not left alive after func finishes or fails.
        torch.distributed.destroy_process_group()


def main(num_proc=2, func=None, init_method="tcp://localhost:9999", backend='nccl'):
    """Spawn ``num_proc`` worker processes on this machine, each running ``run``.

    Args:
        num_proc: number of processes to spawn. The same value is passed to
            ``run`` so the world size always matches the number of workers
            (a mismatch would leave process-group initialization waiting).
        func: per-process function, called as ``func(local_rank)``; defaults
            to ``inter_main``.
        init_method: process-group initialization URL.
        backend: distributed backend name.
    """
    if func is None:
        func = inter_main
    torch.multiprocessing.spawn(
        run,
        nprocs=num_proc,
        args=(
            num_proc,
            func,
            init_method,
            0,  # shard_id: this machine
            1,  # num_shards: single-machine job
            backend,
        ),
        daemon=False,
    )


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    main()