I am training a custom implementation of NVAE (GitHub: NVlabs/NVAE, the official PyTorch implementation of "NVAE: A Deep Hierarchical Variational Autoencoder", NeurIPS 2020 spotlight paper) in PyTorch.
I am running in a distributed environment with 4 GPUs and DDP, using SyncBatchNorm for normalization.
Example of SyncBatchNorm initialization in a residual block of my autoencoder:
```python
import torch
import torch.nn as nn
from torch.nn import SiLU, SyncBatchNorm

# Conv2D (with a weight_norm flag), SkipDown and SE are custom modules from my NVAE code (not shown)


class ResidualCellEncoder(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, downsampling: bool, use_SE: bool):
        """
        FIG 4.B of the paper
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param downsampling: if True, halve the spatial resolution (stride 2)
        :param use_SE: if True, append a squeeze-and-excitation block
        """
        super().__init__()
        # skip connection has convs if downscaling
        if downsampling:
            stride = 2
            self.skip_connection = SkipDown(in_channels, out_channels, stride)
        else:
            stride = 1
            self.skip_connection = nn.Identity()
        # (BN - SWISH) + conv 3x3 + (BN - SWISH) + conv 3x3 + SE
        # downsampling in the first conv, depending on stride
        self.residual = nn.Sequential(
            SyncBatchNorm(in_channels, eps=1e-5, momentum=0.05),
            SiLU(),
            Conv2D(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True, weight_norm=True),
            SyncBatchNorm(out_channels, eps=1e-5, momentum=0.05),
            SiLU(),
            Conv2D(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, weight_norm=True)
        )
        if use_SE:
            self.residual.append(SE(out_channels, out_channels))

    def forward(self, x: torch.Tensor):
        # scaled residual branch added to the (possibly downsampled) skip connection
        residual = 0.1 * self.residual(x)
        x = self.skip_connection(x)
        return x + residual
```
I have other SyncBatchNorm layers in the autoencoder, all initialized like the ones above (with custom eps and momentum).
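For reference, with `momentum=0.05` every training-mode forward pass updates the running statistics as an exponential moving average, `running = (1 - momentum) * running + momentum * batch_stat`, where the variance term uses the unbiased batch variance. The sketch below checks this on a plain `nn.BatchNorm2d` on CPU. That substitution is an assumption made so the snippet runs without a process group; `SyncBatchNorm` applies the same update rule but computes the batch statistics across all ranks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.05
bn = nn.BatchNorm2d(3, eps=1e-5, momentum=momentum)  # running_mean=0, running_var=1 at init
x = torch.randn(8, 3, 4, 4)

bn.train()
bn(x)  # one training-mode forward pass updates the buffers in place

# Per-channel batch statistics over the (N, H, W) dimensions;
# torch.var uses the unbiased estimator by default
batch_mean = x.mean(dim=(0, 2, 3))
batch_var_unbiased = x.var(dim=(0, 2, 3))

expected_mean = (1 - momentum) * torch.zeros(3) + momentum * batch_mean
expected_var = (1 - momentum) * torch.ones(3) + momentum * batch_var_unbiased

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))
```

With a small momentum the running statistics lag well behind the current batch statistics, so they only settle after many steps.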
Below is the code I use to save the model:
```python
# Save checkpoint (after validation)
if WORLD_RANK == 0:
    print('[INFO] Saving Checkpoint')
    checkpoint_dict = {
        'epoch': epoch + 1,
        'global_step': global_step + 1,
        'configuration': config,
        'state_dict': ddp_model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
        'grad_scaler': grad_scalar.state_dict()
    }
    ckpt_file = f"{args.checkpoint_base_path}/{args.run_name}/epoch={epoch:02d}.pt"
    torch.save(checkpoint_dict, ckpt_file)
```
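One way to check whether `torch.save`/`torch.load` themselves change anything is to round-trip a checkpoint with the same structure in memory and compare every `state_dict` entry by key, buffers included. This is a minimal sketch, assuming a small toy model in place of the (not shown) `AutoEncoder` and leaving out the configuration and grad scaler entries.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for ddp_model.module (hypothetical; the real AutoEncoder is not shown)
model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.BatchNorm2d(4, momentum=0.05), nn.SiLU())
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so the BatchNorm buffers move away from their initial values
model.train()
loss = model(torch.randn(8, 3, 8, 8)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

checkpoint_dict = {
    "epoch": 1,
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
buffer = io.BytesIO()  # in-memory file standing in for ckpt_file
torch.save(checkpoint_dict, buffer)
buffer.seek(0)
loaded = torch.load(buffer, map_location="cpu")

# Every tensor, including running_mean / running_var / num_batches_tracked, should match exactly
mismatched = [
    key for key, value in model.state_dict().items()
    if not torch.equal(value, loaded["state_dict"][key])
]
print("mismatched keys:", mismatched)
```

Comparing by key rather than by position means the check does not depend on the order in which layers or files are visited.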
The issue is the following: the `running_mean` and `running_var` buffers are not saved correctly in the state dict.
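Strictly speaking, `running_mean` and `running_var` are registered buffers rather than parameters: they receive no gradients, but they are still included in `state_dict()` together with `num_batches_tracked`. A minimal check on a standalone layer (constructing `SyncBatchNorm` does not need an initialized process group):

```python
import torch.nn as nn

bn = nn.SyncBatchNorm(8, eps=1e-5, momentum=0.05)

param_names = [name for name, _ in bn.named_parameters()]
buffer_names = [name for name, _ in bn.named_buffers()]
state_keys = list(bn.state_dict().keys())

print("parameters:", param_names)
print("buffers:", buffer_names)
print("state_dict keys:", state_keys)
```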
I first noticed something was wrong when I tried to load a pre-trained model to continue training: the loss on the first batch after loading is much higher than the last loss recorded before training stopped.
To confirm that the problem is SyncBatchNorm, I manually saved each layer's `running_mean` and `running_var` immediately after saving the state dict:
```python
import numpy as np
import torch

# One file per rank and per SyncBatchNorm layer; i is the layer's position in named_modules()
for i, (n, layer) in enumerate(ddp_model.module.named_modules()):
    if isinstance(layer, torch.nn.SyncBatchNorm):
        file_name = f"{args.checkpoint_base_path}/{args.run_name}/rank={WORLD_RANK}_I={i}.npy"
        np.save(file_name, torch.stack([layer.running_mean, layer.running_var]).cpu().numpy())
```
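Because the `I={i}` index identifies a layer only by its position in `named_modules()`, an alternative that keeps each statistic tied to its layer is to key it by the module's qualified name, which is also the prefix used in `state_dict()`. A minimal sketch, assuming a toy model and an in-memory buffer in place of a hypothetical per-rank file:

```python
import io

import torch
import torch.nn as nn

# Toy stand-in for ddp_model.module (hypothetical)
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.SyncBatchNorm(4), nn.SiLU(), nn.SyncBatchNorm(4))

# Collect the statistics keyed by the module's qualified name, not its position
stats = {
    name: torch.stack([layer.running_mean, layer.running_var]).cpu()
    for name, layer in model.named_modules()
    if isinstance(layer, nn.SyncBatchNorm)
}

buffer = io.BytesIO()  # stands in for a hypothetical per-rank file, e.g. f"rank={WORLD_RANK}_stats.pt"
torch.save(stats, buffer)
buffer.seek(0)
loaded_stats = torch.load(buffer)

# Each entry is matched to the state_dict by name rather than by file order
state = model.state_dict()
for name, pair in loaded_stats.items():
    assert torch.equal(pair[0], state[f"{name}.running_mean"])
    assert torch.equal(pair[1], state[f"{name}.running_var"])
print(sorted(loaded_stats))
```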
Then, in a separate script, I load all the numpy files and the state dict and check two things:
- that every rank saved the same numpy arrays (means/vars);
- that the means and vars saved in the `rank_0` numpy arrays match the ones in the state dict.
Here is the code:
```python
import os

import numpy as np
import torch

# AutoEncoder is my model class (not shown)

input_files = sorted(os.listdir('./cifar10_4x3_test'))
ckpt_file = None
rank_0, rank_1, rank_2, rank_3 = [], [], [], []
for file in input_files:
    if file.endswith('.npy'):
        ar = torch.tensor(np.load('./cifar10_4x3_test/' + file))
        if 'rank=0' in file:
            rank_0.append(ar)
        elif 'rank=1' in file:
            rank_1.append(ar)
        elif 'rank=2' in file:
            rank_2.append(ar)
        else:
            rank_3.append(ar)
    else:
        ckpt_file = './cifar10_4x3_test/' + file

# assert sync batch norm works (all ranks have the same stuff)
for r0, r1, r2, r3 in zip(rank_0, rank_1, rank_2, rank_3):
    assert r0[0].sum() == r1[0].sum() == r2[0].sum() == r3[0].sum()
    assert r0[1].sum() == r1[1].sum() == r2[1].sum() == r3[1].sum()

# assert what you load from checkpoint is correct
checkpoint = torch.load(ckpt_file, map_location='cpu')
config = checkpoint['configuration']
model = AutoEncoder(config['autoencoder'], config['resolution'])
model.load_state_dict(checkpoint['state_dict'])
count_n = 0
for _, layer in model.named_modules():
    if isinstance(layer, torch.nn.SyncBatchNorm):
        mean = layer.running_mean
        var = layer.running_var
        # torch.equal compares whole tensors; a bare `==` on multi-element tensors cannot be used in an assert
        assert torch.equal(rank_0[count_n][0], mean)
        assert torch.equal(rank_0[count_n][1], var)
        count_n += 1
```
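One detail of the loading code worth checking: `sorted()` compares file names as strings, so with unpadded indices `rank=0_I=10.npy` sorts before `rank=0_I=2.npy`, and the order of `rank_0` may then differ from the order in which `named_modules()` yields the layers. The cross-rank check would be unaffected, since every rank's list is reordered in the same way. A minimal sketch of sorting by the numeric `I=` value, assuming the file-name pattern used above (`layer_index` is a hypothetical helper):

```python
import re


def layer_index(file_name: str) -> int:
    """Extract the integer after 'I=' in names like 'rank=0_I=12.npy'."""
    match = re.search(r"I=(\d+)\.npy$", file_name)
    if match is None:
        raise ValueError(f"unexpected file name: {file_name}")
    return int(match.group(1))


files = ["rank=0_I=2.npy", "rank=0_I=10.npy", "rank=0_I=1.npy"]

lexicographic = sorted(files)               # plain string comparison
numeric = sorted(files, key=layer_index)    # ordered by layer index

print(lexicographic)
print(numeric)
```

Applied to the loader, this key should only be used on the `.npy` names, since the `.pt` checkpoint name has no `I=` field.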
The first assertion passes, the second does not!
This means that the running means and vars saved as numpy arrays differ from the ones saved with `torch.save`. I guess something is wrong in my saving/loading procedure; can someone help?
Thank you