Beginner: Should ReLU/sigmoid be called in the __init__ method?

I am trying to rebuild a Keras architecture in pytorch, which looks like this

    rnn_layer1 = GRU(25) (emb_seq_title_description)
    # [...]
    main_l = Dropout(0.1)(Dense(512,activation='relu') (main_l))
    main_l = Dropout(0.1)(Dense(64,activation='relu') (main_l))
    
    #output
    output = Dense(1,activation="sigmoid") (main_l)

So I tried to adjust the basic RNN example in PyTorch and add ReLUs to the Linear layers. However, I am not sure whether I can call ReLU directly in the forward method or whether I should call it in the __init__ method.

My first try looks like this:

import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)


    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = nn.ReLU(self.i2h(combined))
        output = nn.ReLU(self.i2o(combined))
        output = nn.sigmoid(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

n_hidden = 128
n_letters = 26
n_categories = 2
rnn = RNN(n_letters, n_hidden, n_categories)

However, when I look at the rnn object in Python, I don’t see the ReLUs, so maybe it’s not right to call nn.ReLU directly in the forward method…

RNN(
  (i2h): Linear(in_features=154, out_features=128, bias=True)
  (i2o): Linear(in_features=154, out_features=2, bias=True)
)

Since nn.ReLU is a class, you have to instantiate it first. This can be done in the __init__ method or, if you prefer, directly in the forward as:

hidden = nn.ReLU()(self.i2h(combined))

However, I would create an instance in __init__ and just call it in the forward method.

Alternatively, since it’s stateless, you don’t have to create an instance at all and can use the functional API directly in forward:

hidden = F.relu(...)
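
Applied to the RNN from the question, that could look roughly like this (just a sketch keeping the original structure; torch.sigmoid is used because nn.sigmoid does not exist):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = F.relu(self.i2h(combined))                 # stateless, so the functional call is fine
        output = torch.sigmoid(F.relu(self.i2o(combined)))  # same structure as the original forward
        return output, hidden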

Let me repeat the question differently: what is the PyTorch-idiomatic way to use relu(), and WHY? I think the answer is to use F.relu() in the forward() function. The WHY part is important here and I’d love to hear a full answer. I hope I’m not complicating things more than necessary.


My personal preference is to use the functional API in the forward for stateless objects, e.g. F.relu.
Since nn.ReLU doesn’t store any parameters, it’s not really necessary to define it as a module.

However, there is one exception where I prefer the module approach, and that’s when I know I will try out different activation functions in the forward pass.
So instead of repeatedly changing F.relu to another non-linearity, I use a definition like in this example:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.act = nn.ReLU()
        
    def forward(self, x):
        x = self.act(self.conv1(x))
        return x

and just switch self.act for another module.

I think it just comes down to your personal preference, since neither coding approach is bad in any way.


My point isn’t really about coding style per se, but more about correct PyTorch semantics. The motivation for my follow-up question is the same as the original post’s: what goes into __init__() and what goes into forward()? I found that confusing as well, especially coming from Keras. I think it boils down to whether a layer (module?) holds state or not, as you put it. Maybe there are other reasons I have not learned yet? For example, you can even add new layers in forward(), e.g. F.conv2d(…), but that layer’s weights are lost once forward() exits, so it can’t be used as part of the trained model, for example at prediction time. So the semantics, not the syntax, of what goes into __init__() vs. what should go into forward() is the main point of the question.


OK, I see.
Let me try to boil it down to the standard approach and possible advantages of changing this approach.

In the standard use case, you are writing some kind of model with some layers. The layers most likely hold some parameters which should be trained; nn.Conv would be an example. On the other hand, some layers, like nn.MaxPool, don’t have any trainable parameters, but usually these are still created in __init__. Other “layers”, like nn.ReLU, don’t have parameters and can also be seen as simple functions rather than proper layers (at least in my opinion).
In your forward method you are creating the logic of your forward (and thus also backward) pass.
In a very simple use case you would just call all created layers one by one, passing the output of one layer to the next. However, since the computation graph is created dynamically, you can also create some crazy forward passes, e.g. using a random number to select how often some layers are repeated, splitting your input into different parts and calling different “sub-modules” with each part, using some statistics of your input to select a specific part of your model, etc. You are not bound to any convention here, as long as all shapes etc. match.
For me this is one of the most beautiful parts of PyTorch. Basically you can let your imagination flow without worrying too much about some limited API which can only call layers in a sequential manner.
And this is also the reason to break some of these conventions I’ve mentioned before.
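
As a tiny illustration of that flexibility, something like the following forward is perfectly valid (a minimal sketch with made-up layer sizes):

import random
import torch
import torch.nn as nn

class DynamicModel(nn.Module):
    def __init__(self):
        super(DynamicModel, self).__init__()
        self.block = nn.Linear(16, 16)

    def forward(self, x):
        # The graph is rebuilt on every call, so random or data-dependent
        # control flow in forward is perfectly fine.
        for _ in range(random.randint(1, 3)):
            x = torch.relu(self.block(x))
        return x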

Think about a specific use case where you would like to use a conv layer, but for whatever reason you need to access and maybe manipulate its weight often. The first approach of creating the layer in __init__ and applying it in forward would certainly work. However, the weight access might be a bit cumbersome.
So how about we just store the filter weights as nn.Parameters in __init__ and just use the functional API (F.conv2d) in the forward method. Would that work at all? Sure! Since you’ve properly registered the filter weights in __init__, they will be trained as long as they are used somewhere in the computation during your forward pass.
As you can see, these are somewhat advanced use cases, and I wouldn’t say they are breaking any kind of PyTorch semantics. Using the functional API is totally fine for advanced use cases. I would not recommend using the functional API for every layer from now on, though; it’s much easier to use nn.Modules in most use cases.
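
A minimal sketch of that idea (the shapes are chosen arbitrarily for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ManualConv(nn.Module):
    def __init__(self):
        super(ManualConv, self).__init__()
        # Registering nn.Parameters directly makes them show up in model.parameters()
        self.weight = nn.Parameter(torch.randn(6, 3, 3, 3))  # out_channels, in_channels, kH, kW
        self.bias = nn.Parameter(torch.zeros(6))

    def forward(self, x):
        # The weights are easy to inspect or manipulate here before being used
        return F.conv2d(x, self.weight, self.bias, padding=1)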

Have a look at the implementations of torchvision.models.
You will see all kinds of different coding styles depending on the complexity of the problem. While simpler models might be implemented using some nn.Sequential blocks in __init__ and just calling them in forward (e.g. AlexNet), other models will be implemented in a different manner using more functional calls (e.g. Inception, since it’s a bit more complicated to split and merge the activations as well as to get the aux loss).

To sum it up, my two cents are: use whatever feels good and easy for you. Although there are some “standard” approaches for some use cases, I have to say that even after working with PyTorch for a while now, I probably change one or another coding habit every few weeks (because I suddenly have the feeling the code logic is easier to follow with the new approach :wink: ). If you ask 10 devs to implement Inception, you’ll probably get 10 different, but all beautiful and useful, implementations.


Thank you for the time and effort in providing a great answer. Now that I understand your advice more deeply I think I gained a lot more intuition into PyTorch. Much appreciated.


What about the following use case: I need a convolution layer (therefore learnable, unlike nn.MaxPool which you already mentioned) where the kernel_size parameter is the shape [height, width] of the input tensor, which is unknown at __init__ time. This results in a pooling effect, so that the output tensor has an NxCx1x1 shape.

Is it still possible to implement this, using the __init__ and forward behaviour?

(Thinking out loud: could this be replaced by a generic convolution in the __init__ and a global pooling layer in the forward method?)

If you have static input shapes, that would be possible, but it would also be kind of the “standard approach”, just initializing the conv layer a bit later.
I guess you are dealing with variable sized inputs and would like to use the corresponding conv kernel “on the fly”.
In that case it’s a bit trickier, since the possible approaches all have some shortcomings.

  1. You could pass your new input to your model and check its shape in forward. If the corresponding kernel was already created, just use it to process the input and train; otherwise create this kernel first and do the same (a rough sketch follows below this list).
    Depending on the variety of your input shapes, this approach would create a lot of kernels which are hardly being trained. E.g. if a specific size occurs only once in your training data, that kernel will be used just once during training. Also, what should happen if you get new shapes in your validation/test dataset?

  2. Use a single kernel and manipulate it, i.e. somehow increase/decrease its spatial size. In that case I’m not sure how the kernel will behave if there are large differences between the input sizes. Not only will the filter see the features at a completely different resolution, but if you just append some random values to the filter or crop it, the weights will have some “discrepancy”, e.g. while the center was already trained, the borders are completely random weights. I’m not sure if it would be a good idea to interpolate the kernel weights to another size based on the input size, but that seems to me like a valid approach to try out.
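
A rough sketch of the first approach (just to illustrate the idea; kernels created after the optimizer was constructed would also need to be added to it):

import torch
import torch.nn as nn

class LazyKernelConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(LazyKernelConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernels = nn.ModuleDict()  # one conv per spatial size seen so far

    def forward(self, x):
        key = '{}x{}'.format(x.size(2), x.size(3))
        if key not in self.kernels:
            # Create the matching kernel the first time this input size appears
            self.kernels[key] = nn.Conv2d(
                self.in_channels, self.out_channels,
                kernel_size=(x.size(2), x.size(3))).to(x.device)
        return self.kernels[key](x)  # output shape: [N, out_channels, 1, 1]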

If you just need global pooling, have a look at the adaptive pooling layers.
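
For example (a minimal sketch):

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((1, 1))    # output spatial size fixed to 1x1
out = pool(torch.randn(2, 16, 7, 9))   # -> shape [2, 16, 1, 1], independent of H and W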


Thanks for the very detailed answer :slight_smile:

Actually, it’s more of a “convenience” hack than anything, since the inputs have static size. However, the architecture of the network might change, and the shape would depend on these changes, since the “pooling-conv” layer is just before the last layer (FC). Also, it might be required to use the network with a dataset of different dimensions.

Having this shape computed “on the fly” would save me from having to calculate it manually (e.g. by failing a forward pass and checking the dimensions, or actually thinking about convolution padding, strides and their effect on the output shape) when either change occurs.

I’ve settled on the following approach for now, since the inputs have static size: run a partial forward on a tensor (optional argument) in the __init__, get the resulting size, and initialize the “pooling-conv” layer accordingly. This way I can recompute the value as necessary.
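
Roughly like this (a sketch; the module and argument names are just illustrative):

import torch
import torch.nn as nn

class PoolingConvNet(nn.Module):
    def __init__(self, features, num_classes, example_input):
        super(PoolingConvNet, self).__init__()
        self.features = features
        # Run a partial forward pass once to discover the feature map's spatial size
        with torch.no_grad():
            feat = self.features(example_input)
        n, c, h, w = feat.shape
        # "pooling-conv": the kernel covers the whole feature map -> NxCx1x1 output
        self.pool_conv = nn.Conv2d(c, c, kernel_size=(h, w))
        self.fc = nn.Linear(c, num_classes)

    def forward(self, x):
        x = self.pool_conv(self.features(x))
        return self.fc(x.flatten(1))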

Some trial and error is still required, but that will do for the moment!

Could I use the nn.Conv2d module directly in the forward function without stating it in __init__? E.g.:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.act = nn.ReLU()

    def forward(self, x):
        x = nn.Conv2d(3, 6, 3, 1, 1)(x)  # conv layer created inside forward
        x = self.act(x)
        return x

Could this still be learnable while training?

You would reinitialize the module, including its parameters, in each forward pass, so you would at most update them once and then overwrite them in the next forward pass.


Oh, Thanks!
So the learnable parameters should be held in the __init__ part of MyModel, and during the forward and backward passes these parameters can be updated progressively. So MyModel can be written either this way:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv_like = nn.convlike()  # placeholder: a module which has its own parameters

    def forward(self, x):
        x = self.conv_like(x)
        return x

or this way:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.func_params = params  # placeholder: parameters which will be used in nn.functional calls

    def forward(self, x):
        x = nn.functional.funs(x, self.func_params)
        return x

Other nn modules or nn.functional functions which do not have learnable parameters can be placed wherever you like.
Is that correct?


Yes, your explanations and assumptions are correct.


Thank you very much for this use case. I understand in essence what is being illustrated in these lines. But because I am still a novice in coding, could a very quick code example be given, so that I can understand how the weights of a conv can be set as parameters in __init__ and then updated when using F.conv2d? Thank you for this seemingly redundant request, but it helps me a lot.

This code shows how to initialize a conv in the __init__ method and use its parameters in the forward:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        conv = nn.Conv2d(1, 6, 3)
        # reuse the conv's initialized parameters and register them on this module
        self.weight = conv.weight
        self.bias = conv.bias

    def forward(self, x):
        x = F.conv2d(x, self.weight, self.bias)
        return x

x = torch.randn(1, 1, 4, 4)
model = MyModel()
out = model(x)

You would have to add potential padding via F.pad if it’s used in your conv layer.
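
For example (a sketch, assuming the original conv used padding=1):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4)
weight = torch.randn(6, 1, 3, 3)
bias = torch.zeros(6)

x = F.pad(x, (1, 1, 1, 1))       # (left, right, top, bottom), i.e. padding=1 on both dims
out = F.conv2d(x, weight, bias)  # -> shape [1, 6, 4, 4]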


Thank you very much. So if I modify the code as follows,


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        conv = nn.Conv2d(1, 6, 3)
        self.weight = conv.weight
        self.bias = conv.bias

    def modify_conv_weights(self, weights, bias):
        # DO SOME OPERATIONS
        return weights, bias

    def forward(self, x):
        x = F.conv2d(x, self.weight, self.bias)
        self.weight, self.bias = self.modify_conv_weights(self.weight, self.bias)
        return x

x = torch.randn(1, 1, 4, 4)
model = MyModel()
out = model(x)

This way, will I be able to modify the weights of the conv layer, or will the new weights in each forward pass have no correspondence to their previous values?

Yours sincerely


Depending on the operations in modify_conv_weights, you might break the computation graph and just reassign new parameters, so that the weight and bias would not be trained by the optimizer.


Oh, I see. If the operation is a function which operates on the incoming weights, say taking the log of the weights or normalising them so that their L2 norm is 1, I hope that does not break the computation graph?


These operations would create a non-leaf tensor, so the .grad attribute would disappear.
In the end you’ve replaced the original parameter which was used to compute the output activation, so the backward pass won’t be able to compute valid gradients for it anymore, no?
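
To illustrate (a minimal sketch):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 6, 3)
print(conv.weight.is_leaf)        # True: a registered leaf parameter, .grad gets populated
w = torch.log(conv.weight.abs())  # result of an operation: a non-leaf tensor
print(w.is_leaf)                  # False: .grad would not be populated for it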

Could you explain your use case a bit?
