I saw from this post that within batchnorm layers the mean and std are buffers and not parameters. Does this mean they are handled automatically by DataParallel? I'm working out how to have DataParallel handle some of my own buffers, and I'd really like that to happen automatically as well, rather than building a custom DataParallel and DistributedDataParallel.
Just found this in the DataParallel documentation. So does this mean the running stats aren't actually handled, and it just uses the mean and std calculated on device[0], deciding that will be close enough to the full mean and std and ignoring what's happening on the other devices? And most importantly, even if the other devices' values are ignored for the calculation, will those devices use the mean and std calculated on device[0], or will their mean and std just never be initialized?
.. warning::
In each forward, :attr:`module` is **replicated** on each device, so any
updates to the running module in ``forward`` will be lost. For example,
if :attr:`module` has a counter attribute that is incremented in each
``forward``, it will always stay at the initial value because the update
is done on the replicas which are destroyed after ``forward``. However,
:class:`~torch.nn.DataParallel` guarantees that the replica on
``device[0]`` will have its parameters and buffers sharing storage with
the base parallelized :attr:`module`. So **in-place** updates to the
parameters or buffers on ``device[0]`` will be recorded. E.g.,
:class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
rely on this behavior to update the buffers.
nn.DataParallel will update the running stats on device[0] and will scatter the state_dict to all replicas in the forward pass.
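To make this concrete, here's a minimal sketch (the CounterModule name and shapes are just placeholders of mine, and it assumes at least two visible GPUs) of why an in-place buffer update is kept while rebinding the attribute would be lost: the replica on device[0] shares buffer storage with the base module, so writing into that storage survives the forward pass.

import torch
import torch.nn as nn

class CounterModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('counter', torch.zeros(1))

    def forward(self, x):
        # In-place: writes into the storage the device[0] replica shares
        # with the base module, so the increment is recorded there.
        self.counter += 1.0
        # Rebinding instead, e.g. self.counter = self.counter + 1.0,
        # would only reassign the attribute on the short-lived replica,
        # which is thrown away after forward.
        return x

if torch.cuda.device_count() >= 2:  # sketch assumes at least 2 GPUs
    model = CounterModule().cuda()
    dp = nn.DataParallel(model)
    dp(torch.randn(8, 4).cuda())
    print(model.counter)  # expected: tensor([1.], device='cuda:0')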
This actually doesn’t seem to be working now that I’m trying to use a default DataParallel instead of my custom one. One part of my code is simply:
module.buffer = module.buffer + 1.0
print('counter')
print(module.buffer)
This works on one GPU, and it works with my custom DataParallel that gathers the replica values back to the main module. But with the default DataParallel it just repeatedly prints 1, as if it's starting from 0 each time. If it matters, this counter buffer is incremented inside a torch.autograd.Function with a custom backward().
I added a
print(id(module.buffer))
which seemed to show that the buffer was a different object on every batch. But this pointed me to the solution: doing the update in place seems to fix the problem:
module.buffer += 1.0
But I'm still curious why this is required. And I have other buffers that I'm not sure I can use in-place operations for, since they are being set to the values of other buffers.
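Here's a rough sketch of what I mean by one buffer being set from another, and of the in-place alternative I'm wondering about (the module and buffer names are just placeholders, and it assumes the buffers have matching shapes): torch.Tensor.copy_ writes into the existing buffer storage instead of rebinding the attribute.

import torch
import torch.nn as nn

class TwoBufferModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('latest_stat', torch.zeros(4))
        self.register_buffer('running_stat', torch.zeros(4))

    def forward(self, x):
        # Rebinding would be lost on the replicas and would detach the
        # attribute from the storage shared with the base module:
        #   self.running_stat = self.latest_stat.clone()
        # An in-place copy writes into the existing buffer storage instead:
        self.latest_stat.copy_(x.detach().mean(dim=0))
        self.running_stat.copy_(self.latest_stat)
        return x

If that behaves the same way as the counter, then presumably only the values computed on device[0] get recorded, the same as BatchNorm's running stats.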