I’m wondering if there’s a simple solution to translate:
optimizer = optim.AdamW(
    [
        {"params": gain_or_bias_params, "weight_decay": 0.},
        {"params": rest_params, "weight_decay": args.wd},
    ],
    lr=args.lr,
    betas=(args.beta1, args.beta2),
    eps=args.eps,
)
into a ZeroRedundancyOptimizer. By default the class expects an iterable of tensors, but I’m passing an iterable of dictionaries, which is causing a lot of headaches. I’ve tried several hacky workarounds, but to no avail, due to assertions that fire later on.
The translation I expected to work:
optimizer = ZeroRedundancyOptimizer(
    [
        {"params": gain_or_bias_params, "weight_decay": 0.},
        {"params": rest_params, "weight_decay": args.wd},
    ],
    optim.AdamW,
    lr=args.lr,
    betas=(args.beta1, args.beta2),
    eps=args.eps,
)
sadly doesn’t cut it.
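For reference, here is a minimal sketch of one possible workaround, not a confirmed fix: construct the ZeroRedundancyOptimizer from a flat list of tensors (the first group), then append the second group with `add_param_group`, which the class provides and which has to be called on every rank. The helper names `build_zero_adamw` and `demo_single_process`, the toy model, the hyperparameter values, and the "1-D tensors are gains or biases" split are all assumptions made for illustration. More recent PyTorch releases appear to document `params` as accepting dicts as well, so the direct form may simply work after an upgrade.

```python
import os
import tempfile

import torch.distributed as dist
from torch import nn, optim
from torch.distributed.optim import ZeroRedundancyOptimizer


def build_zero_adamw(gain_or_bias_params, rest_params, lr, betas, eps, wd):
    """Build from a flat tensor list, then append the second parameter group."""
    optimizer = ZeroRedundancyOptimizer(
        list(gain_or_bias_params),  # plain list of tensors for the constructor
        optim.AdamW,
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=0.0,  # default, used by the first group
    )
    # Must be called on every rank; overrides weight_decay for this group only.
    optimizer.add_param_group({"params": list(rest_params), "weight_decay": wd})
    return optimizer


def demo_single_process():
    """Hypothetical local check: one-process gloo group on CPU with a toy model."""
    store = os.path.join(tempfile.mkdtemp(), "store")
    dist.init_process_group(
        "gloo", init_method="file://" + store, rank=0, world_size=1
    )
    try:
        model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
        # Assumed split: 1-D tensors are gains/biases, the rest are weight matrices.
        gain_or_bias_params = [p for p in model.parameters() if p.ndim <= 1]
        rest_params = [p for p in model.parameters() if p.ndim > 1]
        optimizer = build_zero_adamw(
            gain_or_bias_params, rest_params, 1e-3, (0.9, 0.999), 1e-8, 0.1
        )
        # Report (number of tensors, weight decay) for each global group.
        return [(len(g["params"]), g["weight_decay"]) for g in optimizer.param_groups]
    finally:
        dist.destroy_process_group()
```

In a real job the process group would already be initialized by the launcher, so only `build_zero_adamw` would be needed, called with `args.lr`, `(args.beta1, args.beta2)`, `args.eps`, and `args.wd`.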
Thank you for your time,
Cade