Reproduce the same network configuration (with dropout active) over multiple steps: zero out the same elements with dropout repeatedly

Hi everyone,
I have a network architecture with Dropout layers.
From the official documentation of nn.Dropout:

During training, randomly zeroes some of the elements of the input tensor with probability p using samples from a Bernoulli distribution.

I need to keep the same dropout configuration fixed for N consecutive steps; by “fix the same configuration” I mean:

  • At step 0, randomly zero some of the elements of the input tensor with probability p using samples from a Bernoulli distribution.
  • For the following N steps, zero the same elements chosen at step 0 (so the effective network seen at each of these steps is exactly the same).
  • At step N+1, restart from the beginning (randomly zero a new set of elements of the input tensor with probability p using samples from a Bernoulli distribution).
  • Repeat.

How can I do this?
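To make the intended schedule concrete, here is a minimal sketch of the behaviour I am after (just an illustration with a fixed-shape input, using a functional-style dropout where the mask is sampled explicitly; p, N and the shapes are placeholders):

import torch

p = 0.5      # dropout probability
N = 10       # number of consecutive steps that share the same mask
mask = None

for step in range(100):
    x = torch.randn(4, 16)                               # stand-in for the layer input
    if step % N == 0:                                     # at steps 0, N, 2N, ... draw a new mask
        mask = torch.bernoulli(torch.full_like(x, 1 - p))
    y = x * mask / (1 - p)                                # the same mask is reused in between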

[EDIT]

There are some examples of similar problems (example_1, Example_2, Example_3).
However, I still have some doubts about how to proceed:

  • In my network I have something like this:
cfg = {
    'VGG16': [64, 'Dp', 64, 'M', 128, 'Dp', 128, 'M', 256, 'Dp', 256, 'Dp', 256, 'M', 512, 'Dp', 512, 'Dp', 512, 'M', 512, 'Dp', 512, 'Dp', 512, 'A', 'Dp'],  # all dropouts depend on a single parameter (useful for hyper-parameter optimization)
}


class VGG(nn.Module, NetVariables, OrthoInit):
    def __init__(self, params):
        self.params = params.copy()

        nn.Module.__init__(self)
        NetVariables.__init__(self, self.params)
        OrthoInit.__init__(self)

        self.features = self._make_layers(cfg['VGG16'])
        self.classifier = nn.Linear(512, self.num_classes)

        self.weights_init()  # apply the orthogonal initialization

    def forward(self, x):
        outs = {} 
        L2 = self.features(x)
        outs['l2'] = L2
        Out = L2.view(L2.size(0), -1)
        Out = self.classifier(Out)
        outs['out'] = Out
        return outs

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                
            elif x=='A':
                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
                
            elif x == 'D3':
                layers += [nn.Dropout(0.3)]

            elif x == 'D4':
                layers += [nn.Dropout(0.4)]            

            elif x == 'D5':
                layers += [nn.Dropout(0.5)]   
                
            elif x == 'Dp':
                layers += [nn.Dropout(self.params['dropout_p'])] 
                
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.Tanh(),
                           nn.GroupNorm(int(x / self.params['group_factor']), x)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers) 

How can I adapt the code so that:

  1. Dropout in nn.Sequential is replaced by a layer that acts as nn.Identity when model.training == False (evaluation mode) and as a fixed-mask layer when model.training == True (training mode)?

  2. The mask mentioned in point 1 stays constant and is resampled only when triggered by an external flag?

  1. Either use the proposed approach from your cross post or initialize the layers first and use a switch in the custom forward method.

  2. I would write a custom dropout layer, which accepts an additional flag to resample the mask.
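A minimal sketch of such a layer (FrozenDropout and the resample flag are just illustrative names; it assumes the input shape stays constant while the mask is frozen):

import torch
import torch.nn as nn

class FrozenDropout(nn.Module):
    """Dropout whose mask is kept fixed until resample=True is passed."""
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
        self.mask = None

    def forward(self, x, resample=False):
        if not self.training:                 # behaves like nn.Identity in eval mode
            return x
        if resample or self.mask is None:
            # draw a new Bernoulli keep-mask on the same device/dtype as x
            self.mask = torch.bernoulli(torch.full_like(x, 1 - self.p))
        return x * self.mask / (1 - self.p)   # inverted-dropout scaling

Since its forward signature differs from nn.Dropout, it has to be called explicitly in a custom forward method rather than from inside an nn.Sequential.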


Thank you @ptrblck! Your advice was very helpful.
I report here the explicit code to make it available for users who will need it.
If anyone notices something that seems wrong, please don’t hesitate to report it.

import copy

import torch
import torch.nn as nn


class VGG_Custom_Dropout(nn.Module, NetVariables, OrthoInit):
    def __init__(self, params):
        self.params = params.copy()

        nn.Module.__init__(self)
        NetVariables.__init__(self, self.params)
        OrthoInit.__init__(self)

        # create a dict of modules with an Identity placeholder in each dropout position
        self.ModuleDict = self._make_layers(cfg['VGG16'])
        self.classifier = nn.Linear(512, self.num_classes)
        self.mask_dict = {}

        self.weights_init()  # apply the orthogonal initialization

    def forward(self, x, Mask_Flag, TrainMode_Flag):
        outs = {}
        # iterate over the ModuleDict; keys starting with '!' mark the dropout positions,
        # where the Identity placeholder is replaced by a mask multiplication in training mode
        for key in self.ModuleDict:
            if key.startswith('!'):
                if not TrainMode_Flag:  # in eval mode the dropout position acts as an identity layer
                    x = self.ModuleDict[key](x)
                else:
                    if Mask_Flag == 1:  # this flag triggers the resampling of the masks
                        self.mask_dict[key] = torch.distributions.Bernoulli(
                            probs=(1 - self.params['dropout_p'])
                        ).sample(x.size()).to(x.device)  # keep the mask on the same device as x

                    x = x * self.mask_dict[key] * 1 / (1 - self.params['dropout_p'])  # dropout with inverted scaling

            else:  # for modules other than dropout, regular forward
                x = self.ModuleDict[key](x)

        L2 = x
        outs['l2'] = L2
        Out = L2.view(L2.size(0), -1)
        Out = self.classifier(Out)
        outs['out'] = Out
        return outs

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        ModuleDict = nn.ModuleDict()
        NumberKey=0
        
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                
            elif x=='A':
                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
                 
                
            elif x == 'Dp':
                # close the current block of layers and insert an Identity placeholder;
                # the '!' key prefix marks a dropout position handled manually in forward
                ModuleDict[str(NumberKey)] = nn.Sequential(*copy.deepcopy(layers))
                NumberKey += 1
                ModuleDict['!' + str(NumberKey)] = nn.Identity()
                NumberKey += 1
                layers = []
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           # nn.ReLU(inplace=True),
                           # nn.LeakyReLU(negative_slope=0.1, inplace=False),
                           nn.Tanh(),
                           # nn.BatchNorm2d(x),  # BatchNorm is excluded for now: incompatible with PCNGD, GD and PCNSGD, where I forward sample by sample
                           nn.GroupNorm(int(x / self.params['group_factor']), x),
                           # nn.GroupNorm(1, x),
                           ]
                in_channels = x
        NumberKey+=1
        ModuleDict[str(NumberKey)] = nn.AvgPool2d(kernel_size=1, stride=1)
        return ModuleDict
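For completeness, an illustrative training loop for the N-step schedule (the optimiser, loss, dummy batch and N below are placeholders; a 32x32 input is assumed so that the flattened features match the 512-unit classifier):

import torch
import torch.nn as nn

model = VGG_Custom_Dropout(params)            # params as used elsewhere in this post
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

N = 5                                         # number of steps sharing the same masks
model.train()
for step in range(100):
    data = torch.randn(8, 3, 32, 32)          # dummy CIFAR-like batch
    target = torch.randint(0, model.num_classes, (8,))
    resample = 1 if step % N == 0 else 0      # Mask_Flag: 1 -> draw new masks
    outs = model(data, Mask_Flag=resample, TrainMode_Flag=True)
    loss = criterion(outs['out'], target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()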

In this way I have control over where the mask changes and when it is substituted by an identity layer (in evaluation mode). Am I missing something? For example, should I detach the mask multiplication?
I think it should not make a difference: even if autograd records that operation, I am turning off those nodes, so no signal flows through them. Is that right?
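A quick toy check (not from the network above, just an illustration) supports this: the sampled mask does not require gradients, and the gradient at the zeroed positions comes out exactly zero, so no explicit detach() is needed:

import torch

p = 0.5
x = torch.randn(4, 4, requires_grad=True)
mask = torch.bernoulli(torch.full_like(x, 1 - p))   # requires_grad is False
out = (x * mask / (1 - p)).sum()
out.backward()

print(mask.requires_grad)    # False: there is nothing to detach
print(x.grad[mask == 0])     # all zeros: no signal flows through dropped units
print(x.grad[mask == 1])     # 1 / (1 - p) everywhere else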

You could replace TrainMode_Flag with self.training, as the latter flag is changed by calling model.train() and model.eval(). Besides that I don’t see any obvious issues.
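For reference, a minimal demonstration that self.training is toggled recursively by model.train() and model.eval(), so no extra argument needs to be passed through forward:

import torch.nn as nn

# self.training is maintained per module and flipped recursively by
# model.train() / model.eval(), so a custom forward can simply check it.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))
print(model.training, model[1].training)   # True True (default)

model.eval()
print(model.training, model[1].training)   # False False

model.train()
print(model.training, model[1].training)   # True True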
