I think the cleanest approach would be to create a custom ResNet class by deriving from the torchvision implementation and overriding the forward method.
With ResNet in particular the derivation is even easier, as the actual layers are already implemented in a separate block class that gets passed to the base ResNet.
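A minimal sketch of that idea (the class name, the placement of dropout after each residual stage, and the drop probability are my own assumptions, not anything prescribed by torchvision):

```python
import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, BasicBlock

class ResNet18WithDropout(ResNet):
    def __init__(self, drop_p=0.5, **kwargs):
        # BasicBlock and [2, 2, 2, 2] reproduce the resnet18 configuration
        super().__init__(BasicBlock, [2, 2, 2, 2], **kwargs)
        self.dropout = nn.Dropout(p=drop_p)  # drop_p is an assumed hyperparameter

    def forward(self, x):
        # Same flow as the parent forward, with dropout after each stage
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.dropout(self.layer1(x))
        x = self.dropout(self.layer2(x))
        x = self.dropout(self.layer3(x))
        x = self.dropout(self.layer4(x))
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

model = ResNet18WithDropout(drop_p=0.2, num_classes=10)
```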
I was hoping that there is a general approach that I could apply to multiple models. The standard models typically don't contain dropout, as they are usually trained with big datasets. But for small training datasets, which are pretty common in practice, dropout helps a lot. So having a function that adds dropout before/after each ReLU would be very useful.
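Something like this rough sketch is what I have in mind (it assumes the model defines its activations as nn.ReLU modules rather than calling F.relu inside forward):

```python
import torch.nn as nn
from torchvision.models import resnet18

def append_dropout(module, p=0.2):
    # Recursively replace every nn.ReLU with ReLU followed by dropout.
    # Only works if the activations are nn.ReLU modules, not F.relu calls.
    # nn.Dropout2d drops whole feature maps; plain nn.Dropout would zero
    # individual activations instead.
    for name, child in module.named_children():
        if len(list(child.children())) > 0:
            append_dropout(child, p)
        if isinstance(child, nn.ReLU):
            setattr(module, name, nn.Sequential(child, nn.Dropout2d(p=p)))

model = resnet18()
append_dropout(model, p=0.2)
```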
Alternatively to my proposed approach, you could also use forward hooks and add dropout at some layers. I'm not sure if it's the best approach, but it might be similar to the approach you mentioned.
You could register the hook once after the model was initialized, which would be before the training loop.
Inside the __init__ would also work. I wouldn't register it in the forward, as you would re-register the hook in each iteration.
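For example, a rough sketch of the hook variant (the choice of layers and the rate are assumptions; F.dropout follows the module's training flag, so it is inactive in eval mode):

```python
import torch.nn.functional as F
from torchvision.models import resnet18

model = resnet18()
drop_p = 0.2  # assumed rate, tune for your dataset

def dropout_hook(module, inputs, output):
    # Returning a value from a forward hook replaces the layer's output.
    # F.dropout uses module.training, so it is disabled after model.eval().
    return F.dropout(output, p=drop_p, training=module.training)

# Register once after the model is created, i.e. before the training loop
handles = [stage.register_forward_hook(dropout_hook)
           for stage in (model.layer1, model.layer2, model.layer3, model.layer4)]

# handles[i].remove() detaches a hook again if it is no longer needed
```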
A key insight (Hinton et al., 2012c) involved in dropout is that we can approximate p_ensemble by evaluating p(y | x) in one model: the model with all units, but with the weights going out of unit i multiplied by the probability of including unit i. The motivation for this modification is to capture the right expected value of the output from that unit. We call this approach the weight scaling inference rule. […]
Because we usually use an inclusion probability of 1/2, the weight scaling rule usually amounts to dividing the weights by 2 at the end of training, and then using the model as usual. Another way to achieve the same result is to multiply the states of the units by 2 during training. Either way, the goal is to make sure that the expected total input to a unit at test time is roughly the same as the expected total input to that unit at train time, even though half the units at train time are missing on average.
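For what it's worth, PyTorch's nn.Dropout implements the second variant (scaling the surviving activations by 1/(1 - p) during training, often called inverted dropout), so no weight rescaling is needed at test time. A quick check:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()
print(drop(x))  # surviving units are scaled by 1/(1 - p) = 2, dropped units are 0

drop.eval()
print(drop(x))  # identity at evaluation time, weights stay untouched
```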
@fiendfish
I know this is an old post, but I need to do a similar thing for a trained ResNet18 model to mitigate overfitting or premature convergence
(training will continue, since this is a federated learning setup, but I need to apply dropout at certain points of training as a regularizer).
Could you please share how you ended up doing it and how useful it was?
Thank you!
Although this is an older thread, I wanted to add that the ResNets in the timm library already have support for Dropout and DropBlock implemented.
As this is not that well documented, I thought it might save others some time if they are searching for this as well.
With timm.create_model(..., drop_rate=..., drop_block_rate=...) the dropout can be configured.
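For example (the concrete rates below are just placeholders; as far as I can tell, drop_rate adds dropout before the classifier and drop_block_rate enables DropBlock inside the later residual stages):

```python
import timm

# drop_rate / drop_block_rate values here are arbitrary examples
model = timm.create_model('resnet18', pretrained=True,
                          drop_rate=0.2, drop_block_rate=0.1)
```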