This issue was tentatively resolved by fixing a stride mismatch warning: "UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance. grad.sizes() = [1, 4, 12], strides() = [1, 12, 1] bucket_view.sizes() = [1, 4, 12], strides() = [48, 12, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:327.)"
Not sure why this would be related or if it is coincidental, but I am no longer experiencing the memory explosion on new H100 jobs.