I have a transformer that currently runs with AMP, which speeds up computation. However, it has almost no effect on memory usage, because the primary bottleneck is the embedding weights.
I want to convert the embeddings to fp16 while still keeping AMP. How can I achieve this?
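Below is a minimal sketch of one possible approach, assuming PyTorch and an embedding that is *not* weight-tied to the output projection. It stores only the embedding table in fp16 with `nn.Embedding.half()`, keeps every other parameter in fp32, and still runs the forward pass under `torch.autocast`. `TinyModel` and `embedding_bytes` are hypothetical stand-ins for the real model and a memory check.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real transformer: embedding + small MLP head."""

    def __init__(self, vocab_size: int = 1000, d_model: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.hidden = nn.Linear(d_model, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # The lookup returns fp16 rows; cast them to fp32 so the rest of the
        # model sees the usual input dtype, and autocast chooses per-op precision.
        x = self.embed(tokens).float()
        x = torch.nn.functional.gelu(self.hidden(x))
        return self.head(x)


def embedding_bytes(module: nn.Embedding) -> int:
    # Memory held by the embedding table itself (elements * bytes per element).
    w = module.weight
    return w.numel() * w.element_size()


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyModel().to(device)

fp32_bytes = embedding_bytes(model.embed)
# Convert only the embedding table to fp16; the other layers stay fp32.
model.embed.half()
fp16_bytes = embedding_bytes(model.embed)

tokens = torch.randint(0, 1000, (2, 16), device=device)
with torch.autocast(device_type=device):
    logits = model(tokens)

print(fp32_bytes, fp16_bytes, tuple(logits.shape))
```

This only covers the forward pass. If the fp16 embedding is also trained, note that `GradScaler.unscale_` refuses fp16 gradients (it raises a `ValueError`), so one option is to freeze the table with `model.embed.weight.requires_grad_(False)`, or to update it outside the scaled optimizer. If the embedding is tied to the output layer, `.half()` would convert that shared weight as well.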