Looking for deep learning performance tips and tricks

I’m a happy owner of a 4090, and I’m about to train a simple classification network on 2 million 224x224 images (vision-only). I’m using a fairly large model and I want to speed up my training setup.

I’ve been in deep learning for a while, but I still haven’t found a comprehensive guide on how to train models really fast, so here is what I know:

  1. use fp16 (autocast + GradScaler): roughly halves memory consumption, plus some speed-up depending on the architecture
  2. how to use tensor cores? I have no idea if they’re being used at all, how to enable them, or how to monitor them. Do they become enabled automatically when I use fp16?
  3. I’ve also tried torch.compile; it brought a ~10% speed improvement, since PyTorch doesn’t support flash attention for cards other than the A100/H100, as I understand (is that true, though?)
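For point 1, a minimal mixed-precision training step looks roughly like the sketch below. The model, optimizer, and tensor shapes here are toy placeholders, not your actual setup; the sketch also falls back to bfloat16 on CPU so it runs anywhere, while on a 4090 it would use float16 with a real GradScaler:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy stand-ins for your real model/optimizer.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.Flatten(),
    nn.Linear(8 * 8 * 8, 10),
).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# GradScaler is only needed for float16; enabled=False makes it a no-op.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

images = torch.randn(4, 3, 8, 8, device=device)
labels = torch.randint(0, 10, (4,), device=device)

optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type=device, dtype=amp_dtype):
    loss = nn.functional.cross_entropy(model(images), labels)
scaler.scale(loss).backward()  # scales the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
scaler.update()
```

The `set_to_none=True` and the single autocast context around the forward pass (but not around `backward()`) follow the pattern recommended in the AMP docs.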

that’s it really. I also have some assumptions about what might work,
and here I’d like you to share your personal experience:

  1. Is there any benefit to using bfloat16 instead of float16?
  2. Did you try using the channels_last memory format (memory_format=torch.channels_last)?
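Both questions can be tried out in a couple of lines. A small sketch (toy conv and tensor sizes, run on CPU here just so it executes anywhere; on a 4090 you would use `device_type="cuda"`): bfloat16 has the same exponent range as fp32, so no GradScaler is needed, and channels_last only changes the physical memory layout (NHWC), not the logical shape:

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 16, 16)

# bfloat16 autocast: no GradScaler required, unlike float16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)
print(out.dtype)  # torch.bfloat16

# channels_last: same logical shape, NHWC strides underneath.
x_cl = x.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
print(x_cl.shape)  # unchanged: torch.Size([2, 3, 16, 16])
print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True
```

Both the input and the model should be converted; converting only one of them forces layout conversions inside the forward pass.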

and maybe you know some other tips and tricks on how to speed up training?
Worth mentioning that I don’t have any CPU bottleneck, that’s for sure (I’m using a Ryzen 9 5950X with a Samsung 980 Pro 2TB NVMe).

I would recommend taking a look at our performance guide for general tips to speed up your model training.
E.g. channels-last should be beneficial for mixed-precision training, so you might want to enable it. Also, torch.backends.cudnn.benchmark = True could yield another speedup (assuming you are using static shapes or a limited range of variable input shapes). These points are explained in a bit more detail in the linked docs.
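The cudnn.benchmark flag above, together with the TF32 switches (which route fp32 matmuls and convolutions through tensor cores on Ampere/Ada GPUs, answering the tensor-core question), can be set once at startup. A sketch; all flags are standard PyTorch settings and are safe to set even on a machine without a GPU:

```python
import torch

torch.backends.cudnn.benchmark = True         # autotune conv algorithms (static shapes)
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 tensor cores for fp32 matmuls
torch.backends.cudnn.allow_tf32 = True        # TF32 tensor cores for fp32 convolutions

# Equivalent higher-level knob (PyTorch 1.12+):
torch.set_float32_matmul_precision("high")
```

With fp16/bf16 autocast, tensor cores are used automatically where eligible; these TF32 flags matter for the ops that still run in fp32.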
