What does `foreach` do in AdamW?

Hello! AdamW has a foreach parameter, which is documented as:

foreach (bool, optional) – whether foreach implementation of optimizer is used (default: None)

I tried searching for what the “foreach implementation” is, but couldn’t find anything. Could anyone explain what this is?

Thank you in advance!

There are some functions with the prefix _foreach_, such as torch._foreach_exp and torch._foreach_add, that take one or more lists of tensors. They apply the corresponding native function, such as torch.exp or torch.add, to each element of the input list(s). If certain conditions are met, for example all tensors are on the same device and have the same dtype, we can expect far fewer CUDA kernel launches than when iterating over the input lists and calling the torch function on each tensor.
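To make the calling convention concrete, here is a minimal sketch run on CPU tensors. The list names `xs` and `ys` are made up for illustration, and since the `_foreach_` functions are private, their exact behavior (including the container type they return) may differ between PyTorch versions.

```python
import torch

# Two lists of tensors; entries in the same position must be compatible,
# but different positions may have different shapes.
xs = [torch.ones(2), torch.ones(3)]
ys = [torch.full((2,), 2.0), torch.full((3,), 2.0)]

# One call applies torch.add to every (x, y) pair.
sums = torch._foreach_add(xs, ys)

# One call applies torch.exp to every tensor in xs.
exps = torch._foreach_exp(xs)

for s, e in zip(sums, exps):
    print(s.shape, e.shape)
```

On CPU this only demonstrates the semantics; the reduction in kernel launches described above applies to CUDA tensors.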

I see! Thank you for your clarification.

Hey, thank you for your answer. Could you give a source, where torch._foreach_add is explained in greater detail?

  • Is it an inplace operation for one of the two arguments?
  • Is there any way to replace this operation by Pytorch functions, looping directly over both lists?

@Max_Unhold late response, not sure if still relevant, but will answer for the record:

The foreach ops are private in PyTorch today, so the best way I have found to check what’s supported is to look through native_functions.yaml: pytorch/aten/src/ATen/native/native_functions.yaml on the main branch of pytorch/pytorch on GitHub.

Is it an inplace operation for one of the two arguments?

It can be! torch._foreach_add is out of place, but torch._foreach_add_ (note the underscore) is inplace on the leftmost argument.
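A small sketch of the difference, assuming CPU tensors and made-up list names `xs` and `ys`:

```python
import torch

xs = [torch.zeros(2), torch.zeros(3)]
ys = [torch.ones(2), torch.ones(3)]

# Out of place: new tensors are returned, xs is left untouched.
out = torch._foreach_add(xs, ys)
xs_unchanged = all(torch.equal(x, torch.zeros_like(x)) for x in xs)

# In place (trailing underscore): the tensors in xs, the leftmost
# argument, are overwritten with the sums; ys is not modified.
torch._foreach_add_(xs, ys)
xs_updated = all(torch.equal(x, y) for x, y in zip(xs, ys))

print(xs_unchanged, xs_updated)
```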

Is there any way to replace this operation by Pytorch functions, looping directly over both lists?

If I understand the question correctly, yes! foreach_add is a performance optimization over add. Semantically, foreach_blah should boil down to doing blah in a for loop. To give an example, cs = torch._foreach_add(xs, ys) should compute the same values as cs = [x + y for (x, y) in zip(xs, ys)]. (Don’t quote me on the python haha I didn’t actually run the code.)
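For the record, here is a runnable sketch of that equivalence. The helpers `add_lists` and `add_lists_` are hypothetical names for the plain-loop replacements, not PyTorch functions, and the comparison assumes CPU float tensors.

```python
import torch

def add_lists(xs, ys):
    """Out-of-place loop version: returns new tensors, inputs untouched."""
    return [x + y for x, y in zip(xs, ys)]

def add_lists_(xs, ys):
    """In-place loop version: each tensor in xs is updated via add_."""
    for x, y in zip(xs, ys):
        x.add_(y)

xs = [torch.arange(3.0), torch.arange(4.0)]
ys = [torch.ones(3), torch.ones(4)]

# Out-of-place: loop vs. foreach.
loop_result = add_lists(xs, ys)
foreach_result = torch._foreach_add(xs, ys)

# In-place: run each variant on its own copy of xs.
xs_loop = [x.clone() for x in xs]
xs_foreach = [x.clone() for x in xs]
add_lists_(xs_loop, ys)
torch._foreach_add_(xs_foreach, ys)

same_out = all(torch.equal(a, b) for a, b in zip(loop_result, foreach_result))
same_inplace = all(torch.equal(a, b) for a, b in zip(xs_loop, xs_foreach))
print(same_out, same_inplace)
```

The loop versions give up the batching that makes the foreach ops faster on GPU, but they only rely on public PyTorch functions.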