Inject dropout into resnet (or any other network)

So I want to inject dropout into a (pretrained) ResNet, as I'm getting pretty bad overfitting (for example, by adding a dropout layer after each residual step).

I guess I could simply monkey patch the ResNet base class, but I wonder if there is an easier way to achieve this?

7 Likes

I think the cleanest approach would be to create your custom resnet class by deriving from the torchvision implementation and overriding the forward method.
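A rough sketch of what that could look like for a ResNet-18 layout (the class name and the placement of dropout after each residual stage are just illustrative):

import torch
import torch.nn as nn
from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock

class ResNet18WithDropout(ResNet):
    def __init__(self, p=0.5, **kwargs):
        # ResNet-18 layout: BasicBlock with [2, 2, 2, 2] blocks per stage
        super().__init__(BasicBlock, [2, 2, 2, 2], **kwargs)
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        # Same stem as torchvision's ResNet.forward
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Dropout after each residual stage
        x = self.dropout(self.layer1(x))
        x = self.dropout(self.layer2(x))
        x = self.dropout(self.layer3(x))
        x = self.dropout(self.layer4(x))

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

Since nn.Dropout has no parameters, the state dict of a pretrained resnet18 should still load into such a subclass.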

2 Likes

With ResNet in particular the derivation is even easier, as the actual layers are already implemented in a separate block class that gets passed to the generic ResNet.
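For example, a sketch along those lines (assuming torchvision's BasicBlock/ResNet interface): subclass the block, add dropout to it, and pass the custom block to the generic ResNet class:

import torch.nn as nn
from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock

class DropoutBasicBlock(BasicBlock):
    def __init__(self, *args, p=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        # Dropout after every residual block
        return self.dropout(super().forward(x))

# ResNet-18 layout built from the custom block
model = ResNet(DropoutBasicBlock, [2, 2, 2, 2])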

I was hoping for a general approach that I could apply to multiple models. The standard models typically don't contain dropout, as they are usually trained on big datasets. But for small training datasets, which are pretty common in practice, dropout helps a lot. So having a function that adds dropout before/after each ReLU would be very useful, e.g.:

model_with_dropout = add_dropout(model, after="relu")
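A minimal sketch of such a helper (the add_dropout name comes from the snippet above; the after= argument is dropped and only the "after every nn.ReLU" case is handled):

import torch.nn as nn
from torchvision import models

def add_dropout(model: nn.Module, p: float = 0.5) -> nn.Module:
    for name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            # Replace the ReLU with ReLU -> Dropout so dropout runs right after it
            setattr(model, name, nn.Sequential(child, nn.Dropout(p=p)))
        else:
            # Recurse into submodules (e.g. a ResNet's residual blocks)
            add_dropout(child, p=p)
    return model

model_with_dropout = add_dropout(models.resnet18(), p=0.2)

Since this uses nn.Dropout (rather than the functional form), it is automatically disabled by model.eval().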

Alternatively to my proposed approach, you could also use forward hooks and add dropout at some layers. I'm not sure if it's the best approach, but it might be similar to the one you mentioned.

Wouldn't that also make the network apply dropout during testing/inference? Edit: maybe by adding something like this.

1 Like

It of course depends on your implementation, but you could use the linked training argument like this:

import torch.nn.functional as F
import torchvision.models as models
model = models.resnet18()
model.fc.register_forward_hook(lambda m, inp, out: F.dropout(out, p=0.5, training=m.training))
2 Likes

Hi @ptrblck, could you tell me where I should put that hook? In the model's __init__, in forward, in the training loop, or somewhere else?

You could register the hook once after the model was initialized, which would be before the training loop.
Inside the __init__ would also work. I wouldn't register it in the forward, as you would re-register a hook in each iteration.
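For example (a small sketch; hooking layer4 here is an arbitrary choice):

import torch.nn.functional as F
import torchvision.models as models

model = models.resnet18()
# Register the hook once, before the training loop starts
handle = model.layer4.register_forward_hook(
    lambda m, inp, out: F.dropout(out, p=0.5, training=m.training)
)
# ... run the training loop as usual ...
# handle.remove()  # detach the hook again if needed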

1 Like

EDIT: NVM, found this discussion.

@ptrblck I went with the solution you posted elsewhere. However, if I am reading torch.nn.functional.dropout() correctly, this doesn't apply the weight scaling inference rule:

From "Deep Learning", Section 7.12 (Dropout):

A key insight (Hinton et al., 2012c) involved in dropout is that we can approximate p_ensemble by evaluating p(y | x) in one model: the model with all units, but with the weights going out of unit i multiplied by the probability of including unit i. The motivation for this modification is to capture the right expected value of the output from that unit. We call this approach the weight scaling inference rule. […]

Because we usually use an inclusion probability of 1/2, the weight scaling rule usually amounts to dividing the weights by 2 at the end of training, and then using the model as usual. Another way to achieve the same result is to multiply the states of the units by 2 during training. Either way, the goal is to make sure that the expected total input to a unit at test time is roughly the same as the expected total input to that unit at train time, even though half the units at train time are missing on average.

Or am I missing something?

1 Like

The scaling is applied internally during training (PyTorch uses "inverted" dropout, i.e. the kept activations are scaled by 1/(1 - p) at training time, so no weight rescaling is needed at inference), as seen here:

import torch
import torch.nn.functional as F

x = torch.ones(1, 10)
print(F.dropout(x, p=0.5, training=True))
> tensor([[2., 2., 0., 2., 2., 0., 2., 0., 2., 0.]])
print(F.dropout(x, p=0.5, training=False))
> tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
4 Likes

What are your parameters m and F, please?

m would be the first argument to the lambda function, so the module, while F is the functional API imported as import torch.nn.functional as F.

2 Likes

@fiendfish
I know this is an old post, but I need to do a similar thing for a trained ResNet-18 model to mitigate overfitting or premature convergence
(training will continue, as it is a federated learning setup, but I need to add the dropout at certain points of training as a regularizer).
Could you please share how you ended up doing it and how useful it was?
Thank you!

Although this is an older thread, I wanted to add that the ResNets in the timm library already have support for Dropout and DropBlock implemented.
As it is not that well documented, I thought it might save others some time if they are searching for this as well.
With timm.create_model(..., drop_rate=, drop_block_rate=) the dropout can be configured, for example:
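(A sketch, assuming a reasonably recent timm version; the model name and rates are arbitrary:)

import timm

# drop_rate adds standard dropout before the classifier,
# drop_block_rate enables DropBlock in the later residual stages
model = timm.create_model(
    "resnet50",
    pretrained=True,
    drop_rate=0.2,
    drop_block_rate=0.1,
)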

1 Like