Device mismatch error with nn.DataParallel

I resolved the issue following Omkar’s suggestion to look at https://github.com/pytorch/pytorch/issues/8637.

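The change: instead of rebinding `forward` to the `only_forward` method in `__init__` (`self.forward = self.only_forward`), I now call `only_forward` directly from inside `forward`. The rebinding broke under `nn.DataParallel` for the following reason. The bound method is stored as an instance attribute, so each per-GPU replica copies a reference that is still bound to the original module and its parameters on the default device. That is what caused the device mismatch: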

        if add_output:
            self.output = af.InternalClassifier(input_size, self.expansion*channels, num_classes) 
            self.no_output = False
        else:
            self.output = None
            #self.forward = self.only_forward
            self.no_output = True
            
    def forward(self, x):
        """Run the residual block and, if this block has an output head, it too.

        Dispatching on ``self.no_output`` here, rather than rebinding
        ``self.forward`` in ``__init__``, keeps the method resolved on each
        ``nn.DataParallel`` replica instead of on the original module.

        Returns:
            ``self.only_forward(x)`` when ``self.no_output`` is truthy.
            Otherwise a 3-tuple
            ``(layers[2](residual), 1, self.output(residual))``, where
            ``residual = layers[0](x) + layers[1](x)``.
        """
        if not self.no_output:
            # Conv path (layers[0]) plus shortcut (layers[1]).
            residual = self.layers[0](x) + self.layers[1](x)
            # layers[2] is evaluated before the output head, matching the
            # original order (relevant if layers[2] mutates in place).
            # NOTE(review): the output head receives the PRE-activation
            # tensor here, while only_output feeds it the post-activation
            # one; confirm which is intended.
            # The flag 1 (vs 0 from only_forward) presumably marks "output
            # present"; confirm against the caller.
            return self.layers[2](residual), 1, self.output(residual)
        return self.only_forward(x)

    
    def only_output(self, x):
        """Return only the output head's result for input ``x``.

        Computes ``layers[2](layers[0](x) + layers[1](x))``, i.e. conv path
        plus shortcut followed by the activation, and passes it to
        ``self.output``.
        """
        conv_out = self.layers[0](x)
        shortcut_out = self.layers[1](x)
        activated = self.layers[2](conv_out + shortcut_out)
        return self.output(activated)
    
    def only_forward(self, x):
        """Run the residual block without the output head.

        Returns:
            A 3-tuple ``(layers[2](layers[0](x) + layers[1](x)), 0, None)``.
            The ``0`` / ``None`` presumably signal "no output" to the
            caller; confirm against the caller.
        """
        summed = self.layers[0](x) + self.layers[1](x)  # conv path + shortcut
        activated = self.layers[2](summed)  # activation
        return activated, 0, None
1 Like