What is the difference between defining a function in PyTorch and defining a class that inherits from nn.Module?

I’m having some trouble writing transformer code for the attention computation module.
I can define a class that computes attention in forward, like this:

class CalculateAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, mask=None, e=1e-12):
        ...  # attention computation here

I can also define a function directly to calculate attention:

def attention(q, k, v, d_k, mask=None, dropout=None):
    ...  # attention computation here
    return output

I don’t know what the difference is between the two approaches, or which one I should use.

If you have state (such as weights or other buffers) that lives between computations, write a Module, because you can keep that state as a member of the Module (e.g. self.weight).
If not, and you are writing stateless logic, write a function.
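
A minimal sketch of both in one place (the names scaled_dot_product_attention and SelfAttention below are illustrative, not taken from your code): the attention math itself is stateless, so it fits in a plain function, while the learnable q/k/v projections carry weights and therefore belong on a Module.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # Stateless: everything it needs arrives as arguments,
    # so a plain function is enough.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.matmul(F.softmax(scores, dim=-1), v)

class SelfAttention(nn.Module):
    # Stateful: the q/k/v projections hold learnable weights, so they
    # must live on a Module to be registered, saved, and optimized.
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        return scaled_dot_product_attention(q, k, v, mask)

This mirrors how torch.nn itself is organized: stateless operations live in torch.nn.functional (e.g. F.softmax), while their stateful counterparts are Modules (e.g. nn.Linear).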
