Is it okay to reuse activation function modules in the network architecture?

Does it make any discernible difference whether activation function modules are reused within a neural network model?

Specifically, is it expected that training results differ depending on whether you reuse such modules or not?

Example model without reusing ReLUs:

import torch.nn as nn

class NormalModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_block(3, 64)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = self.conv_block(64, 1)

    def conv_block(self, chin, chout):
        return nn.Sequential(
            nn.Conv2d(in_channels=chin, out_channels=chout, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(chout),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        return x

Note above that two separate nn.ReLU objects are instantiated, each stored only once in the NormalModel class, and each applied to data only once in the forward pass.

Example model with reused ReLUs:

class ReusedModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.act_func = nn.ReLU(inplace=True)
        self.conv1 = self.conv_block(3, 64)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = self.conv_block(64, 1)

    def conv_block(self, chin, chout):
        return nn.Sequential(
            nn.Conv2d(in_channels=chin, out_channels=chout, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(chout),
            self.act_func,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        return x

Note above that the single instantiated nn.ReLU object is stored three times within the ReusedModel class and applied twice to different data (of different sizes) during the forward pass. The naive rationale for why this shouldn’t matter is that nn.ReLU internally just calls F.relu() anyway and stores no other state in the class, so how could it make a difference?
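
For reference, nn.ReLU is essentially a thin, stateless wrapper around the functional call; a simplified paraphrase of its implementation (not the exact source):

import torch.nn as nn
import torch.nn.functional as F

class ReLU(nn.Module):
    # Simplified paraphrase: the only stored attribute is the `inplace`
    # flag, and forward just dispatches to F.relu.
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input):
        return F.relu(input, inplace=self.inplace)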

With a more complicated model than the example above, I repeatably obtain worse training performance if I reuse the ReLUs, compared to doing things the ‘normal’ way. I have tested it many times now, and it keeps coming out the same. Why could this be?

PyTorch:
Version = 1.4.0a0+7f73f1d
Git commit = 7f73f1d591afba823daa4a99a939217fb54d7688
Compiled with CUDA version = 10.1
Compiled with cuDNN version = 7.6.5
Compiled with NCCL version = 2.4.8
OpenMP available = True
MKL available = True
MKL-DNN available = True


This shouldn’t matter and it’s interesting that you see a difference.
How large is the difference in the final accuracy and how reproducible is this effect?

Is it possible that backprop through the same ReLU layers could be a problem in more complex architectures?

Although ReLU does not have learnable parameters, shouldn’t it still affect the backprop differently if we reuse the same ReLU?

No, it shouldn’t as ReLU is just calling into a stateless function (max(0, x)).
It would be comparable to reusing a multiplication, which also shouldn’t change the outcome of a model.
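
To make this concrete, a small sanity check (shapes are arbitrary) showing that one shared nn.ReLU instance applied to two different tensors produces the same gradients as two separate instances:

import torch
import torch.nn as nn

torch.manual_seed(0)
a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(2, 16, requires_grad=True)

# One shared nn.ReLU instance applied to two different tensors.
shared = nn.ReLU()
(shared(a).sum() + shared(b).sum()).backward()
grad_a_shared, grad_b_shared = a.grad.clone(), b.grad.clone()

# Two separate nn.ReLU instances, same computation.
a.grad, b.grad = None, None
relu1, relu2 = nn.ReLU(), nn.ReLU()
(relu1(a).sum() + relu2(b).sum()).backward()

print(torch.equal(grad_a_shared, a.grad))  # True
print(torch.equal(grad_b_shared, b.grad))  # True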

Thank you for the clarification.

Please note that this shouldn’t dismiss the original question, and we would need to look into this behavior if the issue is reproducible. :wink:

I thought it might cause a problem because the ReLU is registered multiple times with the model, and certain operations might then happen on it too often or too rarely during backprop - I really don’t know, because I don’t know the details of how nn.Module/autograd/etc. are implemented.

I had a model that was working well, so I cleaned up the code and added a few extra options. A few days later I ran more tests and started noticing that the training was more ‘chaotic’ (the losses behaved more erratically during training and converged less consistently towards the target), and the training results never matched the best results I had seen until then, despite many attempts. I was surprised and kept digging until I noticed that purely changing whether the ReLUs were reused seemed to make the difference.

In the meantime I must have trained about 50 times, half with reuse and half without, and for the last 36 of those (most of them done since posting the question; all on the same machine, in essentially direct succession) I have the detailed results.

The maximum achievable training results seem to be comparable, but with ReLU reuse (‘Shared’) the results are less predictable (noticeably worse training results than expected occur more frequently). I use a custom version of ReduceLROnPlateau as a learning rate scheduler, and I observe that it regularly ‘saves’ runs with ReLU reuse from getting stuck on higher losses by waiting around long enough at a higher learning rate that the run can “by chance” dig its way out of the hole it’s in. So even though the final results of Shared runs are sometimes okay, the path there is long and difficult.

I’m not sure how to quantify my observations from the 50 different training logs, and the observed differences, with cold hard comparable numbers. But in any case, the worst-case FullSize performance for Individual (no ReLU reuse) was 20.47% (lower is better), compared to three of the nine Shared runs that were worse (up to 21.51%). And the worst-case HalfSize performance for Individual was 21.98%, compared to three of the nine Shared runs that were worse (up to 23.81%). This is far from the consistent 20-22% that Individual delivers.

Any ideas? Or just statistical noise and bad luck?

A surefire empirical way to test this would be to make the code completely reproducible, changing only the way the ReLU module is used. To make the code reproducible, use the lines below.

torch.manual_seed(5)  # Any arbitrary number
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Do note that the only change in the model should be how the ReLU layers are used. You can even change the forward function as shown below to rule out anything unexpected.

def forward(self, x, same_relu=True):
    x = self.pool(self.relu(self.conv1(x)))
    x = self.conv2(x)
    if same_relu:
        x = self.relu(x)
    else:
        x = self.relu_2(x)
    return x

The order of ActFunc, which seems to be sorted by the loss, looks pretty random.
You have some clusters there, which might come from other hyperparameters, e.g. the random init.

As @charan_Vjy said, try to compare the outputs/results with a deterministic run.

Good idea, I’m trying to get my runs to be deterministic at the moment, as suggested. I have:

random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

I’m not sure if it’s necessary, but I also set num_workers=0 to avoid possible non-determinism from multiple data loader workers. The losses look very reproducible, but after 5 epochs the losses from multiple runs start to differ in the fourth decimal place, indicating that some small source of non-determinism remains.
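
For reference, num_workers > 0 can also be made reproducible by seeding each worker through the DataLoader’s worker_init_fn hook; a minimal sketch (the dataset and batch size below are placeholders):

import random

import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # Each worker derives its seed from the (already seeded) main process RNG,
    # so runs stay reproducible even with num_workers > 0.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# loader = DataLoader(dataset, batch_size=32, num_workers=4,
#                     worker_init_fn=seed_worker)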

I looked at: Reproducibility — PyTorch 2.1 documentation
This says:

A number of operations have backwards that use atomicAdd , in particular torch.nn.functional.embedding_bag() , torch.nn.functional.ctc_loss() and many forms of pooling, padding, and sampling. There currently is no simple way of avoiding non-determinism in these functions.

I use Conv2d with padding=1, BatchNorm2d, ReLU, ConvTranspose2d, and MaxPool2d. Are any of these candidates for unavoidable non-determinism?

I also ran with export PYTHONHASHSEED=0. What else could I be missing?
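
One way to probe individual layer types is to run an identical forward/backward twice on the GPU and compare the gradients bitwise; a rough sketch (the layer choice and tensor sizes are arbitrary, and a layer that looks deterministic within one process can in principle still vary between process launches):

import torch
import torch.nn as nn

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(0)

def backward_twice(layer, x):
    # Run two identical forward/backward passes and check whether the
    # weight and input gradients match bitwise.
    results = []
    for _ in range(2):
        inp = x.clone().requires_grad_()
        layer.zero_grad()
        layer(inp).sum().backward()
        results.append((layer.weight.grad.clone(), inp.grad.clone()))
    return (torch.equal(results[0][0], results[1][0])
            and torch.equal(results[0][1], results[1][1]))

x = torch.randn(8, 16, 32, 32, device='cuda')

# ConvTranspose2d is one of the layer types whose backward may use atomicAdd.
layer = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1).cuda()
print('ConvTranspose2d grads identical:', backward_twice(layer, x))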

So I’m still having trouble getting two successive GPU training runs (identical process launches) of only 5 epochs to output exactly the same resulting model in the end. It’s close-ish, but not the same.

If I run the exact same thing on CPU then I get the same outputs to the 16th decimal place, so all good.
If I run the training on the cuda:0 device however, the losses start to differ with time, but only slightly.

I can confirm:

  • PYTHONHASHSEED is 0 in the environment
  • All the manual seeding etc from the previous post is happening prior to lazy initialisation of CUDA, and before any model/tensor/training code is reached (there’s some command line argument processing code etc prior to it)
  • Everything is running strictly in a single python process (no extra data loader workers or whatever)
  • Prior to training I use torch.rand() to generate a ‘random’ input to the model and evaluate the untrained initial model on this input. Both the input and output tensors are completely deterministic and the same value in each run (judged by printing .mean(), .std(), .median(), .min(), .max() and a fixed element of the tensor to 16 decimal places => all always identical)
  • After training for 5 epochs I do exactly the same again, and the ‘random’ input tensor is identical in each run (i.e. torch.rand() is confirmed to still be perfectly deterministic after training), but the evaluated output tensor differs in the second significant digit and beyond in the printed mean/std/etc values

As I said, the exact same code running on the cpu device is completely deterministic from start to finish, so what’s the difference when running on CUDA, given that the cuDNN backend is set to deterministic mode and benchmark mode is off?
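
To make “close-ish, but not the same” concrete, the final weights of two runs can be compared bitwise. A minimal sketch, assuming each run ends with torch.save(model.state_dict(), ...) and the file names below are placeholders:

import torch

sd_a = torch.load('run_a.pt', map_location='cpu')
sd_b = torch.load('run_b.pt', map_location='cpu')

# Report every parameter/buffer that is not bitwise identical.
for name, tensor_a in sd_a.items():
    tensor_b = sd_b[name]
    if not torch.equal(tensor_a, tensor_b):
        diff = (tensor_a.float() - tensor_b.float()).abs().max().item()
        print(f'{name}: max abs diff {diff:.3e}')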

Try setting torch.cuda.manual_seed(<same arbitrary number>). Also, do share the results of the same ReLU vs. different ReLUs.

As expected, additionally calling torch.cuda.manual_seed or torch.cuda.manual_seed_all does not make a difference, because torch.manual_seed already calls it internally. torch.rand(*size, device='cuda') is provenly operating deterministically before and after training, so I think the RNG is alright. Because the errors are only slight and gradual, I suspect it is a non-deterministic order-of-operations issue (or similar), e.g. non-determinism in the order in which 100 numbers are internally added, leading to different floating point errors. What could be the cause? Could it be that atomicAdd issue, for example?
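
To illustrate why accumulation order alone can shift the result: floating-point addition is not associative, so summing the same numbers in a different order typically changes the last bits.

import random

random.seed(0)
vals = [random.uniform(-1.0, 1.0) for _ in range(100_000)]

s1 = sum(vals)          # one summation order
random.shuffle(vals)
s2 = sum(vals)          # the same numbers in a different order

print(s1, s2)
print(s1 - s2)          # typically small but non-zero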

Running on CPU, ReLU reuse and non-reuse lead to the same results, but as far as I can tell this does not allow us to safely conclude that the same holds when running on CUDA, as demonstrated by the difference in behaviour already exhibited by the lack of determinism. It is a good first sanity check though.

Any ideas?

Some layers might not yield deterministic results due to the usage of atomicAdd.
You could try to remove these layers from your model and recheck the results or use this simple code snippet to verify that both relu approaches yield the same result:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.backends.cudnn.deterministic = True
torch.manual_seed(2809)

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)
        self.linear = nn.Linear(12*24*24, 100)
        self.act = nn.ReLU()

    def forward(self, x, use_act):
        if use_act:
            out = self.act(self.conv1(x))
            out = self.act(self.conv2(out))
        else:
            out = F.relu(self.conv1(x))
            out = F.relu(self.conv2(out))

        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

model = MyModel().cuda()
x = torch.randn(64, 3, 24, 24).cuda()

out = model(x, use_act=True)
out.mean().backward()

out_reference = out.clone()
grads_ref = []
for p in model.parameters():
    grads_ref.append(p.grad.clone())


model.zero_grad()
out = model(x, use_act=False)
out.mean().backward()

grads = []
for p in model.parameters():
    grads.append(p.grad.clone())

# Compare
print('output allclose: {}, max abs diff: {}'.format(torch.allclose(out_reference, out), (out_reference - out).abs().max()))
print('grads allclose: {}, max abs diff: {}'.format(
    all([torch.allclose(gr, g) for gr, g in zip(grads_ref, grads)]), max([(gr - g).abs().max() for gr, g in zip(grads_ref, grads)])))

> output allclose: True, max abs diff: 0.0
> grads allclose: True, max abs diff: 0.0

I continually tweaked my network until I was finally able to get deterministic results. I removed the MaxPools first (even though I don’t think they use atomicAdd), and as expected that changed nothing, but when I then also removed the ConvTranspose2d layers (leaving only Conv2d, BatchNorm2d, and ReLU in the entire network), the results were deterministic.

Comparing ReLU reuse in this configuration yields exactly the same results. Thus, I can chalk the seemingly different results I received up to statistical noise/bad luck.

Conclusion: Simple stateless ReLUs can be reused without changing anything (neither memory use nor the numerical results).


Although ReLU does not have learnable parameters, shouldn’t it still affect the backprop differently if we reuse the same ReLU?

No, it shouldn’t as ReLU is just calling into a stateless function (max(0, x)).

During back-prop we evaluate the ReLU’s derivative at the input location. In other words, if x is negative the slope is 0; if x is positive, the slope is 1. I don’t know the details of how back-prop is implemented in PyTorch, but I’m concerned that by reusing the same ReLU object the associated partial derivative may be overwritten, or added to, by prior invocations of the ReLU during back-propagation, resulting in incorrect calculations.

I think this demonstrates that there is no effect on the back-prop evaluations when referencing the same activation function object in a model.

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.activation = torch.nn.LeakyReLU(negative_slope = 0.01)
        self.w1 = torch.nn.Parameter(torch.Tensor([-1]))
        self.w2 = torch.nn.Parameter(torch.Tensor([-1]))
    
    def forward(self, x: torch.Tensor):
        a = self.w1 * x
        b = self.activation(a)
        
        c = self.w2 * b
        d = self.activation(c)
        
        # Store tensor reference & retain gradients to analyze post back-prop.
        self.a, self.b, self.c, self.d = a, b, c, d
        a.retain_grad()
        b.retain_grad()
        c.retain_grad()
        d.retain_grad()
        
        return d

x = 1
y = 5

# We expect the following values, particularly db/da and dd/dc assuming that 
# the activation function derivative evaluation is not affected by re-using the
# same activation function object:
#     a = w1 * x = -1 * 1 = -1
#     b = LeakyReLU(a) = LeakyReLU(-1) = -0.01
#     c = w2 * b = -1 * -0.01 = 0.01
#     d = LeakyReLU(c) = LeakyReLU(0.01) = 0.01
#     db/da = 0.01 (given a < 0 & LeakyReLU parameterized w/ negative slope = 0.01)
#     dd/dc = 1.0 (given c > 0)


model = Model()
y_hat = model(torch.Tensor([x]))
loss = (y - y_hat) ** 2
loss.backward()


a, b, c, d = [round(_, 4) for _ in (model.a.item(), model.b.item(), model.c.item(), model.d.item())]
a_grad, b_grad, c_grad, d_grad = [round(_, 4) for _ in (model.a.grad.item(), model.b.grad.item(), model.c.grad.item(), model.d.grad.item())]

print(f"{a = }, {b = }, {c = }, {d = }")
print(f"{a_grad = }, {b_grad = }, {c_grad = }, {d_grad = }", "\n")

# (dLoss/dd) * dd/dc = dLoss/dc
# dd/dc = (dLoss/dc) / (dLoss/dd)

# (dLoss/db) * db/da = dLoss/da
# db/da = (dLoss/da) / (dLoss/db)

db_da = round(a_grad / b_grad, 4)
dd_dc = round(c_grad / d_grad, 4)
print(f"{db_da = }, {dd_dc = }")