DropConnect implementation

Can someone point out the advantages of this implementation of DropConnect over a simpler method like the following?

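```python
import torch.nn as nn
import torch.nn.functional as F

# model, optimizer, input, label, num_batches and drop_prob are assumed to be defined
for i in range(num_batches):

    # save the original weights, then replace each one with a randomly masked copy
    orig_params = []
    for n, p in model.named_parameters():
        orig_params.append(p.clone())
        p.data = F.dropout(p.data, p=drop_prob) * (1 - drop_prob)

    output = model(input)

    # restore the original weights
    for orig_p, (n, p) in zip(orig_params, model.named_parameters()):
        p.data = orig_p.data

    loss = nn.CrossEntropyLoss()(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```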

Some thoughts:

Don’t use .data these days! It’s bad for you, really! Doing the modification inside with torch.no_grad(): will do fine.

In fact, you should probably copy back the original values between the backward and the step. For a reasonably complicated model (more than one layer), the backward pass needs the weights that were used in the forward pass, so restoring them beforehand can throw off your gradients, and if you had stayed clear of .data, PyTorch would have told you.
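A minimal sketch of why the ordering matters, assuming a small two-layer `nn.Sequential` and random data (both made up for illustration): when the weights are changed in place with `copy_` under `torch.no_grad()` instead of through `.data`, restoring them before `backward()` trips autograd's version check, whereas restoring them between `backward()` and `step()` runs cleanly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def run_step(restore_before_backward, drop_prob=0.5):
    """One DropConnect-style step; returns True if autograd raised an error."""
    torch.manual_seed(0)
    # hypothetical toy model and data, only for illustration
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(16, 4)
    y = torch.randint(0, 3, (16,))

    # keep copies of the original weights, then mask the parameters in place
    orig = [p.detach().clone() for p in model.parameters()]
    with torch.no_grad():
        for p in model.parameters():
            p.copy_(F.dropout(p, p=drop_prob) * (1 - drop_prob))

    def restore():
        with torch.no_grad():
            for o, p in zip(orig, model.parameters()):
                p.copy_(o)

    loss = F.cross_entropy(model(x), y)
    optimizer.zero_grad()
    if restore_before_backward:
        restore()  # bumps the version counter of weights saved for backward
    try:
        loss.backward()
    except RuntimeError:
        return True
    if not restore_before_backward:
        restore()  # gradients were computed with the masked weights
    optimizer.step()  # and are applied to the restored original weights
    return False


print(run_step(restore_before_backward=True))   # autograd raises
print(run_step(restore_before_backward=False))  # runs cleanly
```

One thing this sketch does not change: because the mask is applied under `no_grad`, autograd never sees it, so dropped weights still receive gradients (for a linear layer the weight gradient depends on the layer's input, not on the weight values).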

You multiply by 1 - drop_prob, which seems unusual.

The convenience of a wrapper: any single instance is easily done manually, but now suppose you want to apply this to some weights rather than all. Is it still as straightforward?

The safety of a wrapper: Having a well-tested wrapper saves you from implementation mistakes. (See above.)
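As a sketch of what such a wrapper could look like, here is a hypothetical `drop_connect` context manager (not an existing PyTorch API) that masks an arbitrary list of parameters in place and restores them on exit. It assumes a plain Bernoulli keep-mask without rescaling, and the toy model and the "only `nn.Linear` weights" selection are illustrative choices:

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def drop_connect(params, drop_prob):
    """Hypothetical helper: mask `params` in place, restore them on exit.

    Wrap forward *and* backward in it and call optimizer.step() afterwards,
    so gradients are computed with the masked weights and applied to the
    original ones.
    """
    params = list(params)
    originals = [p.detach().clone() for p in params]
    with torch.no_grad():
        for p in params:
            # plain Bernoulli keep-mask: 1 with probability 1 - drop_prob
            p.mul_(torch.bernoulli(torch.full_like(p, 1 - drop_prob)))
    try:
        yield
    finally:
        with torch.no_grad():
            for o, p in zip(originals, params):
                p.copy_(o)


# select only the weight matrices of nn.Linear layers, leave biases alone
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
linear_weights = [m.weight for m in model.modules() if isinstance(m, nn.Linear)]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(16, 4), torch.randint(0, 3, (16,))
with drop_connect(linear_weights, drop_prob=0.5):
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
optimizer.step()
```

Selecting a different subset is then a matter of building a different parameter list, without touching the model's forward function.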

Best regards

Thomas


I’m forced to use p.data, because if I replace p.data with p, the weights won’t be modified. For example, if you print the values of p before and after the dropout line, you will see that only the p.data version works (zeroes out some weights). Or do you mean something else?

Good point about moving the weight restore after the backward call! Thank you.

I have to multiply by 1 - drop_prob because dropout internally scales its input by 1 / (1 - drop_prob), and if I don’t undo this, the accuracy drops sharply: with drop_prob=0.05 the model does not even converge unless I scale the weights back. I’m not sure what’s going on, but I suspect it might have something to do with batchnorm. Any ideas?
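The scaling argument can be checked directly. In training mode `F.dropout` zeroes each entry with probability `drop_prob` and divides the survivors by `1 - drop_prob`, so multiplying by `1 - drop_prob` turns it back into a plain 0/1 mask on the weights. The random tensor and `drop_prob = 0.05` below are arbitrary stand-ins:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
drop_prob = 0.05
w = torch.randn(100_000)  # stand-in for a weight tensor

dropped = F.dropout(w, p=drop_prob, training=True)
kept = dropped != 0

# survivors of F.dropout are rescaled by 1 / (1 - drop_prob) ...
print(torch.allclose(dropped[kept], w[kept] / (1 - drop_prob)))
# ... so multiplying by (1 - drop_prob) recovers the original values
rescaled = dropped * (1 - drop_prob)
print(torch.allclose(rescaled[kept], w[kept]))
# roughly a fraction 1 - drop_prob of the entries survives
print(kept.float().mean().item())
```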

What do you mean by “is it still as straightforward?” With my method it’s much easier to apply DropConnect selectively: I don’t have to create wrappers for every single layer type, and I don’t have to modify my model’s forward function. I agree with your point about safety.

@tom, should I open a bug about p not being zeroed out when it is used in the dropout call?

No. p = something doesn’t overwrite the elements of p; it assigns a new object to the name p. That’s inherent in how Python works. You want p.copy_(...).
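A small illustration of the difference, using a toy `nn.Linear` layer: rebinding `p` leaves the parameter untouched, while `copy_` inside `torch.no_grad()` writes into the parameter itself (`no_grad` is needed because autograd otherwise refuses in-place changes to a leaf tensor that requires grad).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(10, 10)  # toy layer, for illustration only
before = layer.weight.detach().clone()

# Rebinding: `p` now names a new tensor; the parameter itself is untouched.
for p in layer.parameters():
    p = F.dropout(p, p=0.5)
unchanged_after_rebind = torch.equal(layer.weight.detach(), before)

# In-place: copy_ writes into the parameter's own storage.
with torch.no_grad():
    for p in layer.parameters():
        p.copy_(F.dropout(p, p=0.5))
unchanged_after_copy = torch.equal(layer.weight.detach(), before)

print(unchanged_after_rebind, unchanged_after_copy)
```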
There are extremely few reasons to use p.data, and chances are you’re doing it wrong if you’re using it. (And people are getting serious about removing it properly, so hopefully it’ll go away soon.)

For the scaling, I don’t know. From a cursory look at the Gal and Ghahramani paper, it seems they may also use a plain Bernoulli mask. I’d probably multiply by torch.bernoulli(weight, 1-drop_prob) instead of using dropout and scaling.
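A sketch of that idea, sampling the keep-mask directly so there is no `1 / (1 - drop_prob)` factor to undo. It uses the form of `torch.bernoulli` that takes a tensor of probabilities; the weight shape and `drop_prob` are arbitrary:

```python
import torch

torch.manual_seed(0)
drop_prob = 0.05
weight = torch.randn(256, 128)  # stand-in for a layer's weight

# Each entry of the mask is 1 with probability 1 - drop_prob, else 0.
mask = torch.bernoulli(torch.full_like(weight, 1 - drop_prob))
masked_weight = weight * mask  # no rescaling to undo
```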

Best regards

Thomas

OK, that makes sense. I replaced all “p.data =” with “p.copy_” and added a no_grad() context. I can’t see any difference in performance, but if it’s safer, so be it.

I ran a few experiments with scaling, and yes, it seems scaling is necessary; otherwise batchnorm screws things up during inference. Recomputing batch statistics during inference also fixes the issue (same good accuracy with or without scaling), but that’s obviously not a solution. I’m not sure how to use torch.bernoulli; did you mean binomial? I tried generating binomial masks, but I don’t see a good way to generate them quickly on the GPU. The best I could do was mask = binomial_distr.sample(p.size).cuda(), which is very slow.

I have to admit that the bernoulli API is more than a little awkward (it should really be a factory function just like randn), but

```python
mask = torch.bernoulli(torch.randn(1, 1, device='cuda').expand(*weight.shape), 1-drop_prob)
```

should work.
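A device-agnostic variant, offered as an alternative sketch rather than the call suggested above: `Tensor.bernoulli_(p)` fills a tensor in place with 0/1 samples, so allocating it with `torch.empty_like(weight)` draws the mask directly on whatever device `weight` lives on, without hard-coding `'cuda'` or sampling on the CPU and copying. `keep_mask`, `weight` and `drop_prob` are hypothetical names:

```python
import torch


def keep_mask(weight, drop_prob):
    # empty_like allocates on weight's device with weight's dtype; bernoulli_
    # then fills every entry with 1 (probability 1 - drop_prob) or 0 there.
    return torch.empty_like(weight).bernoulli_(1 - drop_prob)


device = "cuda" if torch.cuda.is_available() else "cpu"
weight = torch.randn(512, 512, device=device)
mask = keep_mask(weight, drop_prob=0.05)
print(mask.device, mask.unique())
```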

Best regards

Thomas

Hi, I have an unrelated question: is .data equivalent to .detach()?

No, .data is riskier because it also bypasses the version counter. It has long been deprecated. The “remove .data” issue has a great technical discussion.
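A small example of what skipping the version counter means in practice: both aliases share storage with the original tensor, but an in-place change through `.detach()` is caught by autograd, while the same change through `.data` silently produces a wrong gradient (here all zeros, because sigmoid's backward reuses its zeroed output).

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

out = a.sigmoid()       # sigmoid's backward reuses its output
out.detach().zero_()    # shares storage *and* the version counter
try:
    out.sum().backward()
    detach_error = False
except RuntimeError:    # autograd notices the in-place change
    detach_error = True

out = a.sigmoid()
out.data.zero_()        # shares storage, but has its own version counter
out.sum().backward()    # no error, yet the gradient is silently wrong
print(detach_error, a.grad)
```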

Best regards

Thomas


I believe this post may have a point that the proposed DropConnect implementations relying on the dropout method are wrong (?).