I'm trying to use AMP with PyTorch 1.6 to speed up my training code, but I run into a problem when I use nn.DataParallel.

I printed some intermediate variables and found that the tensors are torch.float16 on one GPU, but torch.float32 on two GPUs.

Does mixed-precision training support nn.DataParallel models?
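For reference, here is a minimal CPU-only sketch of the pattern I'm using (all names are placeholders, and CPU autocast with bfloat16 stands in for torch.cuda.amp.autocast so it runs anywhere): an autocast region entered in the main thread, with the forward computation running in a separate worker thread, the way nn.DataParallel dispatches each replica.

```python
import threading
import torch

results = {}

def replica_forward(x, w):
    # Simulates a DataParallel replica's forward running in a worker thread.
    # Autocast state is thread-local, so the main thread's autocast context
    # does not apply here and the matmul stays in float32.
    results["worker_thread"] = (x @ w).dtype

x = torch.randn(4, 4)
w = torch.randn(4, 4)

# CPU autocast is used here only so the sketch runs without a GPU; the
# thread-local behavior is the same mechanism as torch.cuda.amp.autocast.
with torch.autocast("cpu", dtype=torch.bfloat16):
    results["main_thread"] = (x @ w).dtype  # autocast active in this thread
    t = threading.Thread(target=replica_forward, args=(x, w))
    t.start()
    t.join()

print(results)
```

In this sketch the main thread's matmul is autocast to a low-precision dtype while the worker thread's matmul stays float32, which looks like the same one-GPU vs. two-GPU difference I'm printing above.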

On one GPU:

```
fake_image_orig: torch.float16
gen loss: torch.float32
discriminator_out dtype: torch.float16
pred_fake: torch.float16
amp discriminor
discriminator_out dtype: torch.float16
self.get_zero_tensor(input) dtype: torch.float16
input dtype: torch.float16
self.get_zero_tensor(input) dtype: torch.float16
```

On two GPUs:

```
discriminator_out dtype: torch.float32
self.get_zero_tensor(input) dtype: torch.float32
input dtype: torch.float32
self.get_zero_tensor(input) dtype: torch.float32
input dtype: torch.float32
discriminator_out dtype: torch.float32
self.get_zero_tensor(input) dtype: torch.float32
input dtype: torch.float32
self.get_zero_tensor(input) dtype: torch.float32
input dtype: torch.float32
fake_image_orig: torch.float32
gen loss: torch.float32
discriminator_out dtype: torch.float32
pred_fake: torch.float32
fake_image_orig: torch.float32
gen loss: torch.float32
discriminator_out dtype: torch.float32
pred_fake: torch.float32
fake_image_orig: torch.float32
gen loss: torch.float32
fake_image_orig: torch.float32
gen loss: torch.float32
discriminator_out dtype: torch.float32
```