Hello,
I have been working on a paper dealing with new activation functions. While the activation functions are working, they occupy a considerable amount of memory to the point where they are practically unusable. I wondered if I was missing some programming PyTorch trick that could cut down memory usage. Let me show you an example of a legacy function we came up with but are not using anymore. It illustrates the problem well.
MIN_FRAC_ORDER = 0.1
MAX_FRAC_ORDER = 1.9
def push_from_integer(x: torch.Tensor, t=1e-6):
int_x = torch.round(x)
d = int_x - x
if torch.abs(d) > t:
return x
if d < 0:
return int_x + t
elif d >= 0:
return int_x - t
else:
raise ValueError(f"{x}, {int_x}, {d}")
def gamma(x: torch.Tensor, t: float = 0.01):
if x < t:
x = push_from_integer(x, t)
return torch.pi / (torch.sin(torch.pi * x) * gamma(1 - x))
else:
return torch.exp(torch.lgamma(x))
class NAF1(nn.Module):
__constants__ = ["_alpha", "_lambda", "num_parameters"]
num_parameters: int
_pool = []
def __init__(self, alpha: float = 1.6732, init: float = 0.0, Lambda: float = 1.0507, h: float = 1.0):
self.num_parameters = 1
self._alpha = alpha
self._lambda = Lambda
super().__init__()
self._h = h
self.init = init
self.frac_order = nn.Parameter(torch.tensor(init), requires_grad=True)
gamma_consts = [gamma(torch.tensor(el, dtype=torch.float32)) for el in range(1, 4)]
n_consts = [torch.tensor(el, dtype=torch.float32) for el in range(0, 3)]
self.register_buffer("gamma_consts", gamma_consts, persistent=False)
self.register_buffer("n_consts", n_consts, persistent=False)
NAF1._pool.append(self)
def forward(self, x):
# squeeze variable to MIN/MAX_FRAC_ORDER
frac_order = torch.sigmoid(self.frac_order) * (MAX_FRAC_ORDER - MIN_FRAC_ORDER) + MIN_FRAC_ORDER
tensor_pos = torch.where(x > 0, x, 1.0)
tensor_pos = self._lambda * torch.pow(tensor_pos, 1 - frac_order) / gamma(2-frac_order)
tensor_neg = torch.where(x <= 0, x, 1.0)
sum = self.__sum(tensor_neg)
tensor_neg = self._alpha * self._lambda * 1/torch.pow(self._h, frac_order) * sum
output = torch.where(x > 0, tensor_pos, tensor_neg)
return output
def __sum(self, x):
frac_order = torch.sigmoid(self.frac_order) * (MAX_FRAC_ORDER - MIN_FRAC_ORDER) + MIN_FRAC_ORDER
n = self.n_consts[0]
top = (1 - torch.exp(x-n*self._h)) * gamma(1+frac_order)
bot = self.gamma_consts[0] * gamma(1-n+frac_order)
acum = top/bot*torch.pow(-1, n)
n = self.n_consts[1]
top = (1 - torch.exp(x-n*self._h)) * gamma(1+frac_order)
bot = self.gamma_consts[1] * gamma(1-n+frac_order)
acum = acum + top/bot*torch.pow(-1, n)
n = self.n_consts[2]
top = (1 - torch.exp(x-n*self._h)) * gamma(1+frac_order)
bot = self.gamma_consts[2] * gamma(1-n+frac_order)
acum = acum + top/bot*torch.pow(-1, n)
return acum
The key facts here are:
- The input tensor is divided into positive and negative parts inside
forward()
function. According to this post, that is memory-consuming. - Function
__sum()
contains cyclic computation with positive or negative tensor and trainable parameter. I believe this is the main memory-consuming suspect here.
Am I right or wrong in my suspicion? Any thoughts?
Thank you.
EDIT1: 2. point edited.