PyTorch: replace the batch norm layers in ResNet50 with group norm layers

Hi, can anybody tell me how to replace the batch norm layers in ResNet50 with group norm layers using PyTorch?

To start, I would take a look at the existing reference implementation in torchvision (torchvision.models.resnet in the Torchvision main documentation) and see if you can simply modify the handling of the norm_layer argument to handle GroupNorm, since GroupNorm requires the number of groups in addition to the number of channels.