Defining a module parameter as a function of other module parameters

Hello! I am currently working on a custom module inside my PyTorch model. This module creates a gaussian kernel that is convolved across the model’s input, and the parameters of the gaussian kernel are to be learned. There are to be 256 x 256 gaussian kernels generated. My init function looks like this:

    def __init__(self, channel_num = 512, size = 32, ksizeDN = 12):
        super().__init__()
        self.channel_num = channel_num
        self.size = size
        self.ksizeDN = ksizeDN

        self.thetaD = torch.nn.Parameter(uniform.Uniform(0, torch.pi).sample([self.channel_num, self.channel_num]),
                                    requires_grad=True)

        self.p = torch.nn.Parameter(uniform.Uniform(2, 6).sample([self.channel_num, self.channel_num]),
                                    requires_grad=True)

        self.sig = torch.nn.Parameter(uniform.Uniform(2, 6).sample([self.channel_num, self.channel_num]),
                                      requires_grad=True)

        self.a = torch.nn.Parameter(
            torch.abs(torch.randn(self.channel_num, self.channel_num, requires_grad=True)))
        #self.nU = torch.nn.Parameter(torch.abs(torch.randn(1, self.channel_num, 1, 1, requires_grad=True)))

        self.gaussian_bank = torch.nn.Parameter(torch.zeros(self.channel_num, self.channel_num, self.ksizeDN * 2 + 1,
                                                            self.ksizeDN * 2 + 1), requires_grad=False)
        self.x = torch.linspace(-self.ksizeDN, self.ksizeDN, self.ksizeDN * 2 + 1)
        self.y = torch.linspace(-self.ksizeDN, self.ksizeDN, self.ksizeDN * 2 + 1)
        self.xv, self.yv = torch.meshgrid(self.x, self.y)

        for i in range(self.channel_num):
            for u in range(self.channel_num):
                self.gaussian_bank[i, u, :, :] = self.get_gaussian(i, u)

The first four parameters are variables in the gaussian kernel equation (self.get_gaussian), which is called when constructing the ‘gaussian bank’ in the for loop. The gaussian bank is 256 x 256 x 25 x 25. For reference, the gaussian kernel equation is defined as:

    def get_gaussian(self, cc, oc):

        xrot = (self.xv * torch.cos(self.thetaD[cc, oc]) + self.yv * torch.sin(self.thetaD[cc, oc]))
        yrot = (-self.xv * torch.sin(self.thetaD[cc, oc]) + self.yv * torch.cos(self.thetaD[cc, oc]))
        g_kernel = torch.tensor((abs(self.a[cc, oc]) /
                                 (2 * torch.pi * self.p[cc, oc] * self.sig[cc, oc])) * \
                                torch.exp(-0.5 * ((((xrot) ** 2) / self.p[cc, oc] ** 2) +
                                                  (((yrot) ** 2) / self.sig[cc, oc] ** 2))))

        return g_kernel
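Written out as a formula, this is what get_gaussian() computes for one kernel (cc, oc), writing θ, p, σ, a for thetaD[cc, oc], p[cc, oc], sig[cc, oc], a[cc, oc], and (x, y) for the matching entries of xv, yv:

```latex
x' = x\cos\theta + y\sin\theta, \qquad y' = -x\sin\theta + y\cos\theta

G_{cc,oc}(x, y) = \frac{|a|}{2\pi\, p\, \sigma}
  \exp\!\left(-\frac{1}{2}\left(\frac{x'^2}{p^2} + \frac{y'^2}{\sigma^2}\right)\right)
```

For positive p and σ this is a gaussian with standard deviations p and σ along axes rotated by θ, scaled so that over the whole plane it integrates to |a|. On the finite (2·ksizeDN + 1)² grid with unit spacing, the sum of the kernel is only approximately |a|.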

After running the training on this, I notice that even though the parameters are changing, the ‘gaussian bank’ parameter never changes. In fact, all of the parameters (thetaD, a, sig, and p) seem to go to zero over multiple epochs, which does not seem right. How can I define the gaussian bank parameter so that the model updates the gaussian bank whenever the parameters are adjusted? I am a bit new to PyTorch, so apologies for any confusion; I am happy to answer anything I missed, and I am very appreciative of any help!

Hi Andrew!

The parameters that are to be learned should be the Parameters of
your Module. gaussian_bank should just be a regular variable – not
a Parameter – because it is not being “learned” (optimized) directly.

(I assume that gaussian_bank is what you are referring to as your
“gaussian kernel.”)

I don’t think I would even have gaussian_bank be a “property” of your
Module. (That is, I would not have it be self.gaussian_bank, but rather
just a “free” variable in your Module's forward() method.)
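A minimal sketch of that structure, under a few assumptions: the class name GaussianBankModule and the method make_gaussian_bank() are hypothetical; the double Python loop is swapped for tensor broadcasting, which builds the whole (C, C, K, K) bank in one expression; the fixed xv/yv grid is stored with register_buffer() (my choice, so it follows .to(device) without being learned); and forward() assumes the bank is used as a conv2d weight of shape (out_channels, in_channels, K, K).

```python
import torch
import torch.nn.functional as F
from torch.distributions import uniform


class GaussianBankModule(torch.nn.Module):
    """Sketch: learnable bank of rotated 2-D gaussian kernels.

    Only thetaD, p, sig and a are Parameters. The bank itself is rebuilt
    from them on every forward() call, so gradients flow back to them.
    """

    def __init__(self, channel_num=512, ksizeDN=12):
        super().__init__()
        shape = [channel_num, channel_num]
        self.thetaD = torch.nn.Parameter(uniform.Uniform(0, torch.pi).sample(shape))
        self.p = torch.nn.Parameter(uniform.Uniform(2, 6).sample(shape))
        self.sig = torch.nn.Parameter(uniform.Uniform(2, 6).sample(shape))
        self.a = torch.nn.Parameter(torch.randn(shape).abs())
        # Fixed sampling grid. Buffers follow .to(device) but are not learned.
        # indexing="ij" matches the old default behaviour of torch.meshgrid.
        coords = torch.linspace(-ksizeDN, ksizeDN, 2 * ksizeDN + 1)
        xv, yv = torch.meshgrid(coords, coords, indexing="ij")
        self.register_buffer("xv", xv)
        self.register_buffer("yv", yv)

    def make_gaussian_bank(self):
        # Reshape the (C, C) parameters to (C, C, 1, 1) so they broadcast
        # against the (K, K) grid, giving a (C, C, K, K) bank in one shot.
        theta = self.thetaD[..., None, None]
        p = self.p[..., None, None]
        sig = self.sig[..., None, None]
        a = self.a[..., None, None]
        xrot = self.xv * torch.cos(theta) + self.yv * torch.sin(theta)
        yrot = -self.xv * torch.sin(theta) + self.yv * torch.cos(theta)
        # Same formula as get_gaussian(), with no torch.tensor() wrapper.
        return (a.abs() / (2 * torch.pi * p * sig)) * torch.exp(
            -0.5 * (xrot**2 / p**2 + yrot**2 / sig**2)
        )

    def forward(self, x):
        gaussian_bank = self.make_gaussian_bank()  # plain local tensor, not a Parameter
        # Assumption: bank is a conv2d weight; padding keeps the spatial size.
        return F.conv2d(x, gaussian_bank, padding=self.xv.shape[-1] // 2)
```

Because the bank is recomputed from the current parameter values on each call, the forward pass after every optimizer.step() automatically sees the updated kernels. At 256 x 256 x 25 x 25 the bank alone is about 41 million floats (roughly 160 MB in float32) per forward pass, before any intermediates autograd keeps.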

Something seems fishy here.

You say that thetaD, a, sig, and p – the parameters from which you
derive gaussian_bank – are changing, specifically, they “go to zero.”
However, gaussian_bank “never changes.”

Looking superficially at your code, I do see:

    self.gaussian_bank[i, u, :, :] = self.get_gaussian(i, u)

So you are modifying self.gaussian_bank (inplace, because you are
indexing into it). Therefore, unless get_gaussian() is returning a constant
for some reason, gaussian_bank really should be changing, reflecting the
changes in its underlying parameters (such as thetaD).
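In the code as posted, two details would make gaussian_bank behave like a constant: the loop that fills it runs only once, inside __init__(), and get_gaussian() wraps its result in torch.tensor(...), which copies the values into a fresh tensor with no autograd history. A quick way to see the second effect, using a hypothetical scalar parameter theta in place of thetaD[cc, oc]:

```python
import warnings

import torch

# Hypothetical scalar parameter standing in for thetaD[cc, oc].
theta = torch.nn.Parameter(torch.tensor(1.0))

# Ordinary tensor arithmetic: the result stays connected to theta.
kept = 2.0 * torch.cos(theta)

# torch.tensor() on an existing tensor copies its values into a new leaf
# tensor with requires_grad=False, so the link back to theta is lost.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)  # PyTorch warns about this copy
    copied = torch.tensor(2.0 * torch.cos(theta))

print(kept.requires_grad, kept.grad_fn is not None)      # True True
print(copied.requires_grad, copied.grad_fn is not None)  # False False
```

Returning the expression directly (or using .clone() if a copy is really needed) keeps the gradient path; torch.tensor(...) or .detach() cuts it.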

Best.

K. Frank

Great, it looks like this might be leading me in the right direction. One other thing I was wondering: weight decay may not be appropriate for these parameters, since it seems to penalize them too much, which may be why they kept shooting down to zero.

Is there a way to apply weight decay to some modules and not to others? For example, I have a ResNet backbone that I want to apply weight decay to, but I don’t want to apply it to this module. Thanks for the help so far!

Hi Andrew!

My intuition would lead me to agree with you on this. The way I look at
it, there is a lot of redundancy in the weights of a typical network. There
are lots of individual scalar parameters, many of which can trade off against
one another – weight A and weight B might both be able to play the same
role as one another, so you don’t need both (but you don’t know ahead of
time which you could leave out).

Or weight A and weight B could compensate for one another so that
weight A could run off to +inf while weight B runs off to -inf. Weight
decay (and other regularization techniques) serves to “tamp down” this
redundancy.

But your gaussian-kernel parameters don’t seem to me to be like this.
Each one has a well-defined meaning, they don’t really overlap with
one another, and you can’t leave one out without restricting the structure
of your gaussian kernel.

So not applying weight decay to your gaussian-kernel parameters makes
sense to me. (You could also just apply much weaker weight decay to them
than to your “regular” network parameters, but I would lean towards no
weight decay at all.)

Yes. Use the parameter-group feature of PyTorch’s Optimizers.
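A minimal sketch of parameter groups, assuming a model with a backbone submodule (standing in for the ResNet) and a gauss submodule (standing in for the gaussian-kernel module). Both names, the tiny layers, and the hyperparameter values are hypothetical. Options given inside a group override the defaults passed as keyword arguments to the optimizer, so here only weight_decay differs between the groups:

```python
import torch


class GaussianKernelModule(torch.nn.Module):
    """Hypothetical stand-in for the gaussian-kernel module in the question."""

    def __init__(self, channel_num=8):
        super().__init__()
        shape = (channel_num, channel_num)
        self.thetaD = torch.nn.Parameter(torch.rand(shape) * torch.pi)
        self.p = torch.nn.Parameter(2 + 4 * torch.rand(shape))
        self.sig = torch.nn.Parameter(2 + 4 * torch.rand(shape))
        self.a = torch.nn.Parameter(torch.randn(shape).abs())


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-in for the ResNet backbone.
        self.backbone = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        )
        self.gauss = GaussianKernelModule(channel_num=8)


model = Net()

# One dict per parameter group; keys inside a dict override the defaults
# given as keyword arguments (lr, momentum) below.
optimizer = torch.optim.SGD(
    [
        {"params": model.backbone.parameters(), "weight_decay": 1e-4},
        {"params": model.gauss.parameters(), "weight_decay": 0.0},
    ],
    lr=0.01,
    momentum=0.9,
)

for group in optimizer.param_groups:
    print(len(group["params"]), group["weight_decay"], group["lr"])
# prints:
# 2 0.0001 0.01
# 4 0.0 0.01
```

If the kernel parameters are not collected in one submodule, the same two lists can be built by iterating over model.named_parameters() and sorting each parameter into a group by its name.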

Best.

K. Frank

Awesome, thanks for all the help!