Custom layer weights do not move to 'mps' device

Hello,

I am having trouble tracing why model.to(device) is not moving my whole model to the mps device on an M1 MacBook Pro.

I am creating custom layer blocks to use in an encoder-decoder setting as follows:

import torch.nn as nn

class Conv2dLayer(nn.Module):
    def __init__(self, nin, nout, kernel_size, stride_sz, pad_sz,
                 activation='relu'):
        super(Conv2dLayer, self).__init__()
        self.activation = activation
        self.conv2d = nn.Conv2d(nin, nout, kernel_size, stride_sz, pad_sz)
        if self.activation == 'relu':
            self.activ = nn.ReLU(inplace=True)
        elif self.activation == 'lrelu':
            self.activ = nn.LeakyReLU(0.2, inplace=True)
        self.bn = nn.BatchNorm2d(nout)
        
    def forward(self, input):
        x = self.conv2d(input)
        x = self.activ(x)
        x = self.bn(x)
        return x

I have a similar layer with ConvTranspose2d. These layers are used to build a conv encoder and a deconv decoder, which are then combined in the autoencoder class. The encoder is as follows:

class ConvEncoder(nn.Module):

    def __init__(self, encoder_dict):
        super(ConvEncoder, self).__init__()
        self.in_channels = encoder_dict["n_channels"]
        .....

        self.c1 = Conv2dLayer(self.in_channels, self.layer_nodes[0], kernel_size=(4,4), 
                              stride_sz=(3,3), pad_sz=(1,1), activation=self.activation)
        .....
    
    def forward(self, x):

        out1 = self.c1(x)
        .....
        return out

And the final model is something like this:

class CNN_AE(nn.Module):

    def __init__(self, encoder_dict, decoder_dict):
        super(CNN_AE, self).__init__()
        self.encoder = ConvEncoder(encoder_dict)
        self.decoder = ConvDecoder(decoder_dict)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)

        return x
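To check whether the nesting itself is the problem, here is a minimal, self-contained sketch of the same three-level structure (custom layer, encoder/decoder, autoencoder). The channel counts, kernel sizes, and the decoder layer `ConvTranspose2dLayer` are hypothetical stand-ins for the elided parts. The point it illustrates is that every submodule assigned as a plain attribute is registered, so a single `model.to(device)` should move all of it.

```python
import torch
import torch.nn as nn


class Conv2dLayer(nn.Module):
    def __init__(self, nin, nout, kernel_size, stride_sz, pad_sz):
        super().__init__()
        # Submodules assigned as attributes are registered automatically
        self.conv2d = nn.Conv2d(nin, nout, kernel_size, stride_sz, pad_sz)
        self.activ = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x):
        return self.bn(self.activ(self.conv2d(x)))


class ConvTranspose2dLayer(nn.Module):
    # Hypothetical decoder counterpart of Conv2dLayer
    def __init__(self, nin, nout, kernel_size, stride_sz, pad_sz):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(nin, nout, kernel_size, stride_sz, pad_sz)
        self.activ = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x):
        return self.bn(self.activ(self.deconv(x)))


class ConvEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.c1 = Conv2dLayer(1, 8, kernel_size=4, stride_sz=2, pad_sz=1)

    def forward(self, x):
        return self.c1(x)


class ConvDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.d1 = ConvTranspose2dLayer(8, 1, kernel_size=4, stride_sz=2, pad_sz=1)

    def forward(self, x):
        return self.d1(x)


class CNN_AE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ConvEncoder()
        self.decoder = ConvDecoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Use MPS when available, otherwise fall back to the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = CNN_AE().to(device)

# Every nested parameter should report the same device
for name, param in model.named_parameters():
    print(f"{name} device: {param.device}")

out = model(torch.randn(2, 1, 16, 16, device=device))
print(out.shape)  # torch.Size([2, 1, 16, 16])
```

If this sketch moves cleanly but the real model does not, the difference is likely in how the elided layers are stored or when the parameters are inspected.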

I am getting the following error:

RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same

which traces back to the forward method of my custom layer. When I print all the named parameters of the full model, I get the expected result: the device is mps:0. However, when I print the named parameters of the first layer within the encoder like so:

        for name, param in self.c1.named_parameters():
            print(f"{name} device: {param.device}")

I get the following result:

conv2d.weight device: cpu
conv2d.bias device: cpu
bn.weight device: cpu
bn.bias device: cpu
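When a mismatch like this appears, it can help to list every parameter and buffer that is not on the expected device, with its fully qualified name. The helper below (`find_off_device` is a hypothetical name, not a PyTorch API) is a minimal sketch of that check. Note that layers which were never registered on their parent will not appear in this listing at all, because `named_parameters()` and `named_buffers()` only walk registered submodules.

```python
import torch
import torch.nn as nn


def find_off_device(model: nn.Module, device: torch.device):
    """Return names of registered parameters/buffers whose device type differs from `device`."""
    wrong = []
    for name, p in model.named_parameters():
        if p.device.type != device.type:
            wrong.append(name)
    for name, b in model.named_buffers():
        if b.device.type != device.type:
            wrong.append(name)
    return wrong


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4)).to(device)
print(find_off_device(model, device))  # []
```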

While the situation is similar to the question "Why model.to(device) wouldn't put tensors on a custom layer to the same device?", in that question the OP wanted to move a tensor created inside the model to the device, which was solved by registering buffers. Here, I can't trace what is going on. I am also worried that the custom layers are not registered in the grad chain and therefore may not be updated, even if I train on the CPU. The only workaround I have found so far is to move everything to the device inside the forward function of the encoder or of the custom layer, which is hacky.
Any thoughts?


Hey!

One thing to be careful about is that only modules assigned directly as attributes of their parent are detected. So if you do mod.child = [child1, child2], the children won't be detected. You can put them in an nn.ModuleList to fix it.
A full example would be necessary to pin down the cause in your case.
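A minimal sketch of that pitfall: the same two `nn.Linear` children stored once in a plain Python list and once in an `nn.ModuleList` (the class names are hypothetical). Since an MPS device may not be available, the sketch uses `.to(torch.float64)` instead of `.to("mps")`; both go through the same recursive conversion, which only visits registered submodules.

```python
import torch
import torch.nn as nn


class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain list: the layers are NOT registered as submodules
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 2)]


class ModuleListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])


plain = PlainListNet().to(torch.float64)
registered = ModuleListNet().to(torch.float64)

print(len(list(plain.parameters())))       # 0: nothing is registered
print(len(list(registered.parameters())))  # 4: two weights and two biases
print(plain.layers[0].weight.dtype)        # torch.float32: left untouched
print(registered.layers[0].weight.dtype)   # torch.float64: converted
```

The same registration decides what `model.parameters()` hands to an optimizer, so layers hidden in a plain list would also not be trained.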


Hello, thank you for the reply! Could you kindly explain how that would work (or point me to a tutorial/documentation) in case I want to set up models with many different custom layers?

Say I have layer classes like the following:

class LinearLayer(nn.Module):
    def __init__(self, nin, nout, activation='relu', dropout_perc=0.1):
        super(LinearLayer, self).__init__()
        self.activation = activation
        self.fc = nn.Linear(nin, nout)
        if self.activation == 'relu':
            self.activ = nn.ReLU(inplace=True)
        elif self.activation == 'lrelu':
            self.activ = nn.LeakyReLU(0.2, inplace=True)
        elif self.activation == 'tanh':
            self.activ = nn.Tanh()
        self.dropout = nn.Dropout(dropout_perc)
        
    def forward(self, input):
        x = self.fc(input)
        x = self.activ(x)
        x = self.dropout(x)
        return x

If I wanted to create a network that contains many instances of this layer class along with other layers such as LSTM/CNN/attention layers, what is the proper way to do that so that everything is registered?
Would creating a ModuleList and appending layers to it work? For instance, take the following:

class Encoder(nn.Module):
    def __init__(self, *args):
        super().__init__()
        arg1 = ....
        self.network = nn.ModuleList([])
        for i in range(len(layers)):
            self.network.append(LinearLayer(in_dim, out_dim))
        self.lstm = nn.LSTM(input_size=in_size, hidden_size=hid_size)

    def forward(self, x):
        # nn.ModuleList is not callable: apply each layer in turn
        for layer in self.network:
            x = layer(x)
        # nn.LSTM returns (output, (h_n, c_n)); keep only the output
        x, _ = self.lstm(x)
        return x
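A runnable version of that encoder might look like the sketch below. It assumes hypothetical sizes (`layer_dims`, `hid_size`) in place of the elided arguments, a simplified `LinearLayer` with ReLU only, and `batch_first=True` for the LSTM. Beyond registration, two details matter: an `nn.ModuleList` is a container, not a callable, so `forward` loops over it, and `nn.LSTM` returns a tuple.

```python
import torch
import torch.nn as nn


class LinearLayer(nn.Module):
    def __init__(self, nin, nout, dropout_perc=0.1):
        super().__init__()
        self.fc = nn.Linear(nin, nout)
        self.activ = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout_perc)

    def forward(self, x):
        return self.dropout(self.activ(self.fc(x)))


class Encoder(nn.Module):
    def __init__(self, layer_dims, hid_size):
        super().__init__()
        # Hypothetical layer_dims, e.g. [16, 32, 32]: consecutive pairs define each LinearLayer
        self.network = nn.ModuleList()
        for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:]):
            self.network.append(LinearLayer(d_in, d_out))
        self.lstm = nn.LSTM(input_size=layer_dims[-1], hidden_size=hid_size, batch_first=True)

    def forward(self, x):
        # Apply the registered layers in order
        for layer in self.network:
            x = layer(x)
        # Keep the per-step outputs, drop (h_n, c_n)
        x, _ = self.lstm(x)
        return x


enc = Encoder(layer_dims=[16, 32, 32], hid_size=64)
x = torch.randn(8, 10, 16)  # (batch, seq_len, features) because batch_first=True
print(enc(x).shape)  # torch.Size([8, 10, 64])
```

Printing `enc.named_parameters()` should show names such as `network.0.fc.weight`, which confirms the layers are registered.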

Following on from that, if I wanted to create a network with many encoders stacked one after the other, does the same reasoning hold? That is, should I create a ModuleList and then append networks that themselves contain ModuleLists? If you could create a dummy example based on the dummy code above, that would be awesome.

Thanks in advance!

Yes, your example above registers everything correctly: appending to the ModuleList does what you expect.
You can check that moving such a Module to mps works fine.
If you change it to self.network = [], the rest of the initialization code will still run fine, but you'll see that the move to mps no longer works for those layers.
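Stacking encoders works the same way, since registration is recursive: a `ModuleList` of modules that themselves hold `ModuleList`s registers every layer. A minimal sketch, with a hypothetical simplified `Encoder` and a hypothetical `use_module_list` switch that swaps in a plain list to show the failure mode:

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    # Hypothetical encoder: a ModuleList of Linear + ReLU blocks
    def __init__(self, dim, n_layers):
        super().__init__()
        self.network = nn.ModuleList()
        for _ in range(n_layers):
            self.network.append(nn.Sequential(nn.Linear(dim, dim), nn.ReLU()))

    def forward(self, x):
        for layer in self.network:
            x = layer(x)
        return x


class StackedEncoders(nn.Module):
    def __init__(self, dim, n_encoders, n_layers, use_module_list=True):
        super().__init__()
        encoders = [Encoder(dim, n_layers) for _ in range(n_encoders)]
        # Nested ModuleLists are fine; a plain list hides every encoder
        self.encoders = nn.ModuleList(encoders) if use_module_list else encoders

    def forward(self, x):
        for enc in self.encoders:
            x = enc(x)
        return x


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
good = StackedEncoders(dim=8, n_encoders=3, n_layers=2).to(device)
bad = StackedEncoders(dim=8, n_encoders=3, n_layers=2, use_module_list=False).to(device)

print(len(list(good.parameters())))  # 12: 3 encoders x 2 layers x (weight, bias)
print(len(list(bad.parameters())))   # 0: nothing is registered
print(good(torch.randn(4, 8, device=device)).shape)  # torch.Size([4, 8])
```

With the plain list, `bad.to(device)` silently leaves every encoder on the CPU; on an MPS machine, calling `bad` on an MPS input would then fail with a device-mismatch error at the first unmoved layer.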
