AdamW + ZeroRedundancyOptimizer + Weight Decay Dictionary

I’m wondering if there’s a simple solution to translate:

optimizer = optim.AdamW(
                [
                    {"params": gain_or_bias_params, "weight_decay": 0.},
                    {"params": rest_params, "weight_decay": args.wd},
                ],
                lr=args.lr,
                betas=(args.beta1, args.beta2),
                eps=args.eps,
            )

into a ZeroRedundancyOptimizer. The class by default expects an iterable of tensors, but I’m passing an iterable of dictionaries, which is causing a lot of headaches. I’ve tried several hacky workarounds, but to no avail due to assertions that fire later on.
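For context, the question doesn’t show how `gain_or_bias_params` and `rest_params` are built; a common way (a sketch, assuming the usual “exempt biases and norm gains from weight decay” convention) is to split on parameter dimensionality:

```python
import torch.nn as nn

# Hypothetical model just for illustration.
model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))

# Assumption: 1-D parameters (biases, LayerNorm gains/biases) get no weight
# decay; matrices and higher-dimensional tensors get args.wd.
gain_or_bias_params = [p for p in model.parameters() if p.ndim < 2]
rest_params = [p for p in model.parameters() if p.ndim >= 2]
```

The exact predicate (dimensionality vs. matching on parameter names) varies between codebases.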

The expected:

optimizer = ZeroRedundancyOptimizer(
                [
                    {"params": gain_or_bias_params, "weight_decay": 0.},
                    {"params": rest_params, "weight_decay": args.wd},
                ],
                optim.AdamW,
                lr=args.lr,
                betas=(args.beta1, args.beta2),
                eps=args.eps,
            )

sadly doesn’t cut it.

Thank you for your time,
Cade

The trick is parameter groups!

optimizer = ZeroRedundancyOptimizer(
                rest_params,
                optim.AdamW,
                lr=args.lr,
                betas=(args.beta1, args.beta2),
                weight_decay=args.wd,
                eps=args.eps,
            )
optimizer.add_param_group({"params": gain_or_bias_params, "weight_decay": 0.})
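For anyone trying to reproduce this: `ZeroRedundancyOptimizer` needs an initialized process group before construction. Here is a minimal single-process sketch of the workaround above (the model, hyperparameter values, and the `ndim < 2` split are placeholders, not from the original post):

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer

# Single-rank process group so ZeRO can be constructed outside a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
gain_or_bias_params = [p for p in model.parameters() if p.ndim < 2]
rest_params = [p for p in model.parameters() if p.ndim >= 2]

# Construct ZeRO with the decayed group, then add the no-decay group.
optimizer = ZeroRedundancyOptimizer(
    rest_params,
    optim.AdamW,
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    eps=1e-8,
)
optimizer.add_param_group({"params": gain_or_bias_params, "weight_decay": 0.0})

# One optimization step to confirm the sharded state works.
loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()
dist.destroy_process_group()
```

Note that in a real multi-rank job, `add_param_group` must be called identically on every rank so the shards stay consistent.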

Awesome, glad you found a solution! Ideally, though, ZeRO would support this natively. I filed an issue to that effect: [ZeRo] Parameter group support in constructor · Issue #71347 · pytorch/pytorch · GitHub