Misaligned address issue when using half precision

I simply call model.half() and convert my input data from float32 to float16. However, the model then throws RuntimeError: cuda runtime error (74) : misaligned address error at /opt/conda/conda-bld/pytorch_1556653183467/work/aten/src/ATen/native/cuda/SoftMax.cu:545. The exact location of the error changes each time I retrain the model.
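A note for anyone debugging a similar report: CUDA kernels are launched asynchronously, so the file and line in the traceback may not be where the fault actually occurred, which could be one reason the reported location moves around. A minimal sketch for forcing synchronous launches, assuming the variable is set before CUDA is initialised in the process:

```python
import os

# Must be set before CUDA is initialised, so set it before importing torch.
# With synchronous launches the traceback should point at the failing call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402  (imported after the environment variable on purpose)

print("CUDA available:", torch.cuda.is_available())
```

The same effect can be had by exporting CUDA_LAUNCH_BLOCKING=1 in the shell before starting the script. Expect training to run slower while it is enabled.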

Do you have a code snippet to reproduce this issue?
We’ve seen a similar issue recently in apex/amp, but could not reproduce it due to lack of information.

I suspect the issue comes from the interaction between the DataParallel wrapper and half precision. I have tried apex/amp, and the same issue occurs at every optimization level except O1. The position of the error is not fixed; sometimes it occurs at the very first linear layer. I think you can reproduce it with just a linear layer, more than one GPU, and the DataParallel wrapper.

import torch
from apex import amp

# model and optimizer are assumed to be defined already
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
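The setup described above (a single linear layer, half precision, DataParallel) could be sketched roughly as follows. This is only an illustration and is not confirmed to trigger the error: the function name, layer sizes, and batch size are made up for the example, and it falls back to float32 on the CPU when no GPU is present so that it still runs.

```python
import torch
import torch.nn as nn


def run_linear_repro(batch_size=64, in_features=128, out_features=10):
    """Forward one batch through a linear layer followed by softmax."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Half precision only on the GPU; the CPU fallback stays in float32.
    dtype = torch.float16 if use_cuda else torch.float32

    model = nn.Linear(in_features, out_features).to(device=device, dtype=dtype)
    if use_cuda and torch.cuda.device_count() > 1:
        # Replicates the model on all visible GPUs and splits the batch.
        model = nn.DataParallel(model)

    x = torch.randn(batch_size, in_features, device=device, dtype=dtype)
    out = torch.softmax(model(x), dim=-1)
    if use_cuda:
        torch.cuda.synchronize()  # surface any asynchronous CUDA error here
    return out


if __name__ == "__main__":
    print(run_linear_repro().shape)
```

Running it in a loop with different batch sizes on a machine with more than one GPU would be the natural next step when trying to narrow the problem down.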

Could you rerun the code with DistributedDataParallel?

@JizeCao @ptrblck
I am encountering the same problem. Have you found any solution to the problem?

The error RuntimeError: CUDA error: misaligned address is thrown when float16 is used together with multiple GPUs.
Specifically, I am using nn.DataParallel() with torch.cuda.amp.GradScaler().
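For reference, a minimal sketch of that combination (nn.DataParallel with torch.cuda.amp.GradScaler) is shown here. TinyNet, the tensor sizes, and train_step are hypothetical stand-ins, not the poster's code. Autocast is enabled inside forward, which older PyTorch documentation recommended for DataParallel because the autocast state was thread-local; mixed precision is disabled automatically when no GPU is available.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in model; not the network from this thread."""

    def __init__(self, in_features=32, num_classes=4):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # Autocast inside forward stays active in the worker threads
        # that DataParallel spawns for each GPU.
        with torch.cuda.amp.autocast(enabled=x.is_cuda):
            return self.fc(x)


def train_step(batch_size=16):
    """Run one optimisation step and return the loss as a Python float."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = TinyNet().to(device)
    if use_cuda and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    x = torch.randn(batch_size, 32, device=device)
    y = torch.randint(0, 4, (batch_size,), device=device)

    optimizer.zero_grad()
    # Compute the loss in float32 even if the logits come back as float16.
    loss = nn.functional.cross_entropy(model(x).float(), y)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()
    return loss.item()


if __name__ == "__main__":
    print(train_step())
```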

There is a related issue on GitHub.

Thanks in advance,

As described in the linked issue, could you create a new issue with steps to reproduce it and tag me there, please?

@ptrblck Sorry for the late reply

I tried to reproduce the error with a simple vanilla network, but the error did not occur in that case.
So I suspected that it was caused by a bug in my code or by something hardware-specific.

Later on, however, I decreased the batch size, and the error no longer appears; I am now able to train the model correctly.
Therefore, I think it was probably a memory issue, such as running out of memory or something similar.
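Whether memory was really the cause is not established in this thread. One rough way to check is to compare peak GPU memory at different batch sizes. The sketch below assumes a single linear layer as a placeholder model; peak_memory_mib and the sizes are made up for the example, and the function returns 0.0 when no GPU is available.

```python
import torch
import torch.nn as nn


def peak_memory_mib(batch_size, in_features=256, out_features=256):
    """Peak allocated GPU memory in MiB for one forward/backward pass.

    Returns 0.0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0
    device = torch.device("cuda")
    model = nn.Linear(in_features, out_features).to(device)
    x = torch.randn(batch_size, in_features, device=device)

    # Restart peak tracking from the memory that is currently allocated.
    torch.cuda.reset_peak_memory_stats(device)
    model(x).sum().backward()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2**20


if __name__ == "__main__":
    for bs in (16, 64, 256):
        print(bs, peak_memory_mib(bs))
```

If the peak for the failing batch size is close to the capacity of the card, memory pressure becomes a more plausible explanation; if it is far below, the cause is more likely elsewhere.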

Thanks for the support!