Is the PyTorch NLP DropConnect implementation correct?

The implementation for basic Weight Drop in the PyTorch NLP source code is as follows:

import torch
from torch.nn import Parameter


def _weight_drop(module, weights, dropout):
    """
    Helper for `WeightDrop`.
    """

    for name_w in weights:
        w = getattr(module, name_w)
        # replace the original parameter with a '<name>_raw' copy
        del module._parameters[name_w]
        module.register_parameter(name_w + '_raw', Parameter(w))

    original_module_forward = module.forward

    def forward(*args, **kwargs):
        # on every forward pass, rebuild the weight from the raw parameter
        # with dropout applied, then call the module's original forward
        for name_w in weights:
            raw_w = getattr(module, name_w + '_raw')
            w = torch.nn.functional.dropout(raw_w, p=dropout, training=module.training)
            setattr(module, name_w, w)

        return original_module_forward(*args, **kwargs)

    setattr(module, 'forward', forward)

This helper clearly uses the dropout function from torch.nn.functional to perform the dropping of the weights.
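For reference, here is a minimal usage sketch (my own, not from the library) of the helper above applied to a plain nn.Linear, assuming the helper and its imports are in scope, just to show what happens to the weights in train and eval mode:

linear = torch.nn.Linear(4, 3)
_weight_drop(linear, ['weight'], dropout=0.5)

x = torch.randn(2, 4)

linear.train()
out = linear(x)        # the patched forward rebuilds 'weight' from 'weight_raw' with dropout
print(linear.weight)   # roughly half the entries are zero, the survivors are rescaled

linear.eval()
out = linear(x)        # F.dropout is a no-op in eval mode, so the raw weights are used as-is
print(linear.weight)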

I wasn’t able to find the actual implementation of that dropout function, but I assume it is correct since it is so widely used. In the Dropout paper, the authors describe different behaviour for the two phases, training and testing. During training, they drop the activations randomly, which I’m sure the dropout implementation handles correctly. During testing, they multiply the outgoing weights by p, so that the expected value of the activations in the next layer is the same as it was during training, which makes sense. I assume Dropout does exactly that, and that the ‘module.training’ argument is what it uses to decide which behaviour to apply.
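As a side note, a quick check shows that torch.nn.functional.dropout uses the "inverted" convention: it scales the surviving elements by 1/(1-p) during training and leaves the tensor untouched in eval mode, rather than multiplying by p at test time; the expected value works out the same either way.

import torch
import torch.nn.functional as F

x = torch.ones(8)
print(F.dropout(x, p=0.5, training=True))   # ~half the entries zeroed, survivors scaled to 2.0
print(F.dropout(x, p=0.5, training=False))  # returned unchanged: all ones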

For DropConnect, however, the authors mention the following algorithms for training and testing:
[Images: the DropConnect paper's training and inference algorithms]

Basically, the training part is the same idea as Dropout, except that it is the weights that get dropped rather than the activations. During inference, however, the process changes. The DropConnect paper argues that ‘averaging’ by multiplying the weights by p, as Dropout does, is not mathematically justified, because that averaging happens before the activation function is applied.
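To make the training-time difference concrete, here is a small sketch (my own notation, using the inverted 1/(1-p) rescaling that PyTorch uses, with p as the drop probability) of masking activations versus masking weights for a single linear layer:

import torch

p = 0.5                      # drop probability
x = torch.randn(2, 4)        # batch of 2 samples, 4 input features
W = torch.randn(3, 4)        # weights of a linear layer with 3 output units

# Dropout: mask the layer's output activations
act_mask = (torch.rand(2, 3) > p).float() / (1 - p)
y_dropout = (x @ W.t()) * act_mask

# DropConnect: mask each individual weight, then apply the layer
w_mask = (torch.rand_like(W) > p).float() / (1 - p)
y_dropconnect = x @ (W * w_mask).t()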

So the DropConnect paper does this instead: take the inputs arriving at the DropConnect layer, and compute the mean and variance of the next layer's pre-activation values from W, p and those inputs. Assuming a Gaussian, draw Z samples from that distribution (Z could be a hyperparameter, as far as I understand) to get many possible pre-activation values. Apply the activation to all of those samples, and only then average over the Z samples to obtain the input for the next weights/layer (a softmax in their case).
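If I read the paper correctly, that inference step could be sketched roughly like this; the function name, the choice of ReLU and the default Z=100 are my own illustrative assumptions, and p here is the keep probability as in the paper, not the drop probability that F.dropout expects:

import torch

def dropconnect_inference(x, W, p, activation=torch.relu, Z=100):
    # moments of the pre-activation x @ (M * W).T, with mask entries M_ij ~ Bernoulli(p)
    mean = p * (x @ W.t())
    var = p * (1 - p) * ((x ** 2) @ (W ** 2).t())
    std = var.clamp_min(1e-12).sqrt()
    # draw Z Gaussian samples of the pre-activation: shape (Z, batch, out_features)
    samples = mean + std * torch.randn(Z, *mean.shape)
    # apply the activation to every sample, then average *after* the nonlinearity
    return activation(samples).mean(dim=0)

x = torch.randn(2, 4)
W = torch.randn(3, 4)
out = dropconnect_inference(x, W, p=0.5)   # shape (2, 3), input to the next layer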

Since the original Dropout doesn't do any of this, and if that is how the PyTorch Dropout implementation works, then the DropConnect implementation linked above must essentially be ‘wrong’ (i.e. not what the authors explained in the paper).

I am not sure why DropConnect hasn’t gained much traction here or in DL research, but can someone please explain whether the above implementation is right or wrong?
And, in case it is indeed wrong, what can I do to implement DropConnect correctly?

Both DropConnect in EfficientNet and PyTorch's Dropout seem to divide by (1 - p) to help maintain the mean (although this isn't true if there are negative values, e.g. after batch norm, or swish).

The difference between dropoutNd and dropconnect (as in EfficientNet) is that the former drops some channels for each batch sample (the 1st and 2nd tensor dimensions, batch and channel), whereas dropconnect drops whole samples altogether (only the 1st, batch dimension).
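For illustration, the two masking patterns could be sketched like this for an activation tensor of shape (batch, channels, H, W); the shapes and the 1/keep_prob rescaling mirror what the two implementations do, but this is a simplified sketch, not the library code:

import torch

x = torch.randn(8, 16, 32, 32)   # (batch, channels, H, W)
p = 0.2                          # drop probability

# Dropout2d: one Bernoulli draw per (sample, channel), so whole feature maps
# are zeroed independently for each sample
channel_mask = (torch.rand(8, 16, 1, 1) > p).float() / (1 - p)
x_dropout2d = x * channel_mask

# EfficientNet-style drop_connect (stochastic depth): one draw per sample,
# so an entire sample's output for this block is zeroed
sample_mask = (torch.rand(8, 1, 1, 1) > p).float() / (1 - p)
x_dropconnect = x * sample_mask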

So, as a late answer, I don't think the NLP implementation above is equivalent to either of the two here. I believe dropout1d would be an approximation of the paper's algorithm, as it tries to maintain the statistics between training and inference and drops whole feature channels of each sample.

Is the EfficientNet implementation of DropConnect correct? It seems like they use this function here, but if you apply this function to x (the activations), then it is not DropConnect but rather just simple Dropout. Let me know if I'm missing something…

Dropout2d drops random channels from each image, while their drop_connect drops whole images from the batch.
Therefore, every image is only seen by a subset of the layers.
I am not sure what the ‘correct’ way of doing DropConnect is, but it seems to work reasonably well.

Maybe this is a different type of DropConnect that they are referencing. I was under the impression that DropConnect referred to dropping the weights of a layer during the forward pass, as described in this paper. As far as I understand, there would be no way to implement this by using the activations of a layer’s output - it would only work by masking the weights of the layer itself.

Yeah, there seems to be some confusion with names here. I believe EfficientNet's drop_connect() method linked above is really their implementation of stochastic depth.