Custom trainable activation function memory consumption

Hello,

I have been working on a paper dealing with new activation functions. While the activation functions work, they occupy a considerable amount of memory, to the point where they are practically unusable. I wondered if I was missing some PyTorch programming trick that could cut down the memory usage. Let me show you an example of a legacy function we came up with but are no longer using; it illustrates the problem well.

import torch
import torch.nn as nn

MIN_FRAC_ORDER = 0.1
MAX_FRAC_ORDER = 1.9

def push_from_integer(x: torch.Tensor, t=1e-6):
    # Nudge a scalar tensor away from integer values so that sin(pi * x)
    # in the reflection formula below never evaluates to zero.
    int_x = torch.round(x)
    d = int_x - x
    if torch.abs(d) > t:
        return x
    if d < 0:
        return int_x + t
    elif d >= 0:
        return int_x - t
    else:
        # only reachable if x is NaN
        raise ValueError(f"{x}, {int_x}, {d}")


def gamma(x: torch.Tensor, t: float = 0.01):
    # Gamma function for a scalar tensor. Small or negative arguments go through
    # the reflection formula Gamma(x) = pi / (sin(pi * x) * Gamma(1 - x)),
    # everything else through exp(lgamma(x)).
    if x < t:
        x = push_from_integer(x, t)
        return torch.pi / (torch.sin(torch.pi * x) * gamma(1 - x))
    else:
        return torch.exp(torch.lgamma(x))


class NAF1(nn.Module):
    __constants__ = ["_alpha", "_lambda", "num_parameters"]
    num_parameters: int
    _pool = []
  
    def __init__(self, alpha: float = 1.6732, init: float = 0.0, Lambda: float = 1.0507, h: float = 1.0):
        super().__init__()
        self.num_parameters = 1
        self._alpha = alpha
        self._lambda = Lambda
        self._h = h
        self.init = init
        self.frac_order = nn.Parameter(torch.tensor(init), requires_grad=True)
        # register_buffer expects a tensor, so the precomputed constants are stacked
        gamma_consts = torch.stack([gamma(torch.tensor(el, dtype=torch.float32)) for el in range(1, 4)])
        n_consts = torch.stack([torch.tensor(el, dtype=torch.float32) for el in range(0, 3)])
        self.register_buffer("gamma_consts", gamma_consts, persistent=False)
        self.register_buffer("n_consts", n_consts, persistent=False)
        NAF1._pool.append(self)
  
    def forward(self, x):
        # squeeze the trainable parameter into (MIN_FRAC_ORDER, MAX_FRAC_ORDER)
        frac_order = torch.sigmoid(self.frac_order) * (MAX_FRAC_ORDER - MIN_FRAC_ORDER) + MIN_FRAC_ORDER

        # positive part; negative entries are replaced by 1.0 so pow() stays finite
        tensor_pos = torch.where(x > 0, x, 1.0)
        tensor_pos = self._lambda * torch.pow(tensor_pos, 1 - frac_order) / gamma(2 - frac_order)

        # negative part
        tensor_neg = torch.where(x <= 0, x, 1.0)
        neg_sum = self.__sum(tensor_neg)
        tensor_neg = self._alpha * self._lambda * 1 / torch.pow(self._h, frac_order) * neg_sum
        output = torch.where(x > 0, tensor_pos, tensor_neg)
        return output
  
    def __sum(self, x):
        frac_order = torch.sigmoid(self.frac_order) * (MAX_FRAC_ORDER - MIN_FRAC_ORDER) + MIN_FRAC_ORDER
  
        n = self.n_consts[0]
        top = (1 - torch.exp(x-n*self._h)) * gamma(1+frac_order)
        bot = self.gamma_consts[0] * gamma(1-n+frac_order)
        acum = top/bot*torch.pow(-1, n)
  
        n = self.n_consts[1]
        top = (1 - torch.exp(x-n*self._h)) * gamma(1+frac_order)
        bot = self.gamma_consts[1] * gamma(1-n+frac_order)
        acum = acum + top/bot*torch.pow(-1, n)
  
        n = self.n_consts[2]
        top = (1 - torch.exp(x-n*self._h)) * gamma(1+frac_order)
        bot = self.gamma_consts[2] * gamma(1-n+frac_order)
        acum = acum + top/bot*torch.pow(-1, n)
  
        return acum
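
For context, NAF1 is meant to be a drop-in replacement for activations such as nn.ReLU. An illustrative toy model (the layer sizes here are arbitrary, purely to show the usage):

# illustrative only: architecture and sizes are arbitrary
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    NAF1(),
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),
)
out = model(torch.randn(8, 3, 32, 32))  # -> shape (8, 10)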

The key facts here are:

  1. The input tensor is divided into positive and negative parts inside the forward() function. According to this post, that is memory-consuming.
  2. The function __sum() repeats the same computation three times with the positive or negative tensor and the trainable parameter (see the loop sketch below). I believe this is the main memory-consuming suspect here.

Am I right or wrong in my suspicion? Any thoughts?
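
To make point 2 concrete, here is a rough, untested sketch of the same __sum() written as a loop, with frac_order passed in from forward() and the gamma(1 + frac_order) factor hoisted out because it is identical in all three summands. It does not get rid of the large per-element intermediates; it just shows the repeated structure I mean:

    def __sum(self, x, frac_order):
        # sketch of a NAF1 method; forward() would pass its frac_order in
        gamma_top = gamma(1 + frac_order)  # identical in every summand
        acum = torch.zeros_like(x)
        for i in range(3):
            n = self.n_consts[i]
            top = (1 - torch.exp(x - n * self._h)) * gamma_top
            bot = self.gamma_consts[i] * gamma(1 - n + frac_order)
            acum = acum + top / bot * torch.pow(-1, n)
        return acum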

Thank you.

EDIT1: point 2 edited.

How much memory is this method using now and what’s your expectation? Did you estimate the memory usage by summing all tensor shapes needed for the forward and backward computation? If so, how large is the divergence between your estimation and the real usage?

How much memory is this method using now and what’s your expectation?

Currently, we have two functions: one that splits the input tensor into positive and negative parts (as shown in the forward() function above) and one that does not, i.e., it works with the whole input tensor.

Models with the activation function that separates the input have the following memory consumption: 1955MiB (simple CNN), 10517MiB (ResNet-18), OOM (ResNet-50), and 26059MiB (EffNetB0-like architecture).

Models with the activation function that does not separate the input into positive and negative parts have the following memory consumption: 1897MiB (simple CNN), 5623MiB (ResNet-18), 20503MiB (ResNet-50), and 15775MiB (EffNetB0-like architecture).

For reference, using ReLU, we get 1867MiB (simple CNN), 2961MiB (ResNet-18), 6081MiB (ResNet-50), and 5017MiB (EffNetB0-like architecture).

These values are taken by running a training loop on CIFAR-10 with batch size 128.
As for my expectations, I would be happy to reduce memory consumption as much as possible. I'm aware that, most likely, not much can be done. I just wanted to check with you whether I missed something obvious (I'm a PyTorch newbie).
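
For what it's worth, here is a minimal sketch of how the peak allocator usage could also be read from inside PyTorch for a single training step (it assumes a CUDA device and that model, criterion, optimizer, and one CIFAR-10 batch images/labels already exist); this only counts memory allocated by PyTorch itself, so it will be somewhat lower than what nvidia-smi reports:

# hypothetical measurement snippet; model, criterion, optimizer, images, labels are assumed to exist
torch.cuda.reset_peak_memory_stats()
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()
torch.cuda.synchronize()
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")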

Did you estimate the memory usage by summing all tensor shapes needed for the forward and backward computation?

I did not. It didn’t cross my mind. Does that mean that, for example, in the function __sum(), each time I assign the variable top a new value (in this case, 3 times), a new tensor is created that contributes to the memory usage? That would explain a lot.

Not only for the “top” variable. Each operation that cannot be differentiated in place (like relu) creates a copy that has to be stored for the backward pass. But even relu has to store at least the positions that have been zeroed out.
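
One common way to trade compute for memory in exactly this situation is activation checkpointing: the intermediate tensors of the activation are recomputed during the backward pass instead of being kept alive. A rough sketch, assuming a module like NAF1 from above:

import torch.utils.checkpoint as cp

class CheckpointedNAF1(nn.Module):
    # wraps the activation so its intermediates are recomputed in backward
    # instead of being stored after the forward pass
    def __init__(self):
        super().__init__()
        self.act = NAF1()

    def forward(self, x):
        # use_reentrant=False is the recommended variant in recent PyTorch versions
        return cp.checkpoint(self.act, x, use_reentrant=False)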

Can arithmetic operations such as multiplication, addition, and so on be done in place?

Deleted my garbage from before

The correct answer is here.

You only need the incoming gradients for addition. But for concatenation you need the inputs. In short: it should not store intermediate data.


Thank you for the explanation, I consider this thread answered.

Ignore it. It was wrong. Seems like my linear algebra is too far in the past :smiley:

:smiley: fair, we are in the same boat, so to speak.

Sorry, I was a bit in a hurry. I hope I can think more clearly now :smiley:

The gradient of x + y is (1, 1), which shows that it does not depend on the inputs or on the output.

The gradient of x * y is (y, x), which indeed depends on the input variables.

A third example is relu. Its gradient does not depend on the variables directly, but on their sign.

A fourth one is sigmoid, whose derivative is sigmoid * (1 - sigmoid), which obviously depends only on the output instead of the input variable.
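
If you want to see this directly, newer PyTorch versions let you hook into what autograd saves for the backward pass. A small sketch (the exact counts may differ between versions):

import torch
from torch.autograd.graph import saved_tensors_hooks

def saved_for_backward(fn):
    # count the tensors autograd packs away during the forward pass of fn
    saved = []
    with saved_tensors_hooks(lambda t: saved.append(t.shape) or t, lambda t: t):
        fn(torch.randn(4, requires_grad=True), torch.randn(4, requires_grad=True))
    return len(saved)

print(saved_for_backward(lambda a, b: a + b))             # addition saves nothing
print(saved_for_backward(lambda a, b: a * b))             # multiplication saves both inputs
print(saved_for_backward(lambda a, b: torch.relu(a)))     # relu saves its output (the sign information)
print(saved_for_backward(lambda a, b: torch.sigmoid(a)))  # sigmoid saves its output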

Sorry for the confusion