Network initialization with nn.Module

I am trying to define my network as an inner class, but the super() call does not seem to recognize the class name. Does anybody have a solution for this?

This is the snippet of code that raises the error:

class FashionMNISTmodel:
    class FashionNet(nn.Module):
        def __init__(self, activationfcn):
            super(FashionNet,self).__init__()

This is where I initialize the inner class:

        self.net = self.FashionNet(activationfcn = self.actfcn).to(device)

This is the error I'm getting:

NameError: name 'FashionNet' is not defined

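This works for me. The zero-argument super() does not need the class name, so the lookup that fails in your version never happens: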

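import torch
import torch.nn as nn


class FashionMNISTmodel:
    def __init__(self):
        # The nested class is reachable as an attribute of the instance.
        self.net = self.FashionNet(torch.relu)

    class FashionNet(nn.Module):
        def __init__(self, activationfcn):
            # Zero-argument super() avoids naming FashionNet inside the method.
            super().__init__()


model = FashionMNISTmodel()
print(model)
print(model.net)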
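
For anyone wondering why the original version fails: names used inside a method are looked up in the local, enclosing-function, global, and builtin scopes, and the body of the outer class is not one of those. So `super(FashionNet, self)` cannot find `FashionNet` when `__init__` runs. Here is a minimal sketch of that behaviour in plain Python. `Base`, `Outer`, `Broken`, `ZeroArg`, and `Qualified` are hypothetical names used only for illustration, with `Base` standing in for `nn.Module` so the snippet runs without PyTorch.

```python
# Base is a hypothetical stand-in for nn.Module so this runs without PyTorch.
class Base:
    def __init__(self):
        self.initialized = True


class Outer:
    class Broken(Base):
        def __init__(self):
            # 'Broken' is searched for in local, enclosing-function, global,
            # and builtin scopes. The body of Outer is not among them.
            super(Broken, self).__init__()

    class ZeroArg(Base):
        def __init__(self):
            # Python 3 zero-argument form: no class name is needed.
            super().__init__()

    class Qualified(Base):
        def __init__(self):
            # Explicit form using the qualified name, resolved at call time.
            super(Outer.Qualified, self).__init__()


try:
    Outer.Broken()
    error_message = None
except NameError as exc:
    error_message = str(exc)

zero_arg_ok = Outer.ZeroArg().initialized
qualified_ok = Outer.Qualified().initialized

print(error_message)
print(zero_arg_ok, qualified_ok)
```

If you prefer to keep the explicit two-argument form, writing the qualified name (in the original post that would presumably be `super(FashionMNISTmodel.FashionNet, self)`) should work for the same reason `Qualified` does here, but the zero-argument call is shorter and survives renaming the classes.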