Softplus derivative

Looking at the PyTorch sources, I see two implementations:

caffe (no longer used?)

const float nexpY = exp(-Y[i]);
dX[i] = dY[i] * (1 - nexpY);
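A scalar Python sketch of the caffe rule. It computes the input gradient from the forward output Y alone and checks it against the textbook derivative sigmoid(X). Plain `math` functions stand in for the tensor ops, and the names `softplus`, `caffe_softplus_grad` and `sigmoid` are illustrative only.

```python
import math

def softplus(x: float) -> float:
    # log(1 + e^x), written with log1p for accuracy
    return math.log1p(math.exp(x))

def caffe_softplus_grad(dy: float, y: float) -> float:
    # caffe-style backward: uses only the forward output y, never the input x
    return dy * (1.0 - math.exp(-y))

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

# the gradient computed from y matches d/dx softplus(x) = sigmoid(x)
for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert math.isclose(caffe_softplus_grad(1.0, softplus(x)), sigmoid(x), rel_tol=1e-9)
```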

aten (a=X, b=Y)

scalar_t z = std::exp(b * beta);
return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
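A scalar sketch of the quoted aten kernel, reading it the way the derivation below does (a as the incoming gradient dY, b as the forward output Y). The defaults beta=1 and threshold=20 mirror those of `F.softplus`; the function name is hypothetical.

```python
import math

def aten_softplus_backward(a: float, b: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    """Hypothetical scalar version of the aten kernel: a = dY, b = Y."""
    # above the threshold softplus is treated as the identity, so the gradient passes through
    if b * beta > threshold:
        return a
    z = math.exp(b * beta)
    # (z - 1) / z equals sigmoid(beta * x), computed from the output alone
    return a * (z - 1.0) / z

# usage: compare with sigmoid(beta * x) for an arbitrary input below the threshold
x, beta = 0.7, 2.0
y = math.log1p(math.exp(beta * x)) / beta
print(aten_softplus_backward(1.0, y, beta), 1.0 / (1.0 + math.exp(-beta * x)))
```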

The below-threshold case is equivalent to caffe:

dX = dY * (exp(Y)-1)/exp(Y) = dY * (1 - exp(-Y))
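Spelling out why this works: Y is a function of X alone, so the derivative can be rewritten purely in terms of Y. The first two lines are the beta=1 case; the last line is the general-beta form, which is what the kernel's z = exp(b*beta) computes.

```latex
\begin{aligned}
Y &= \log\left(1 + e^{X}\right)
  &&\Rightarrow\quad e^{-Y} = \frac{1}{1 + e^{X}} \\
\frac{\partial Y}{\partial X} &= \frac{e^{X}}{1 + e^{X}}
  = 1 - e^{-Y} = \frac{e^{Y} - 1}{e^{Y}} \\
Y_\beta &= \tfrac{1}{\beta}\log\left(1 + e^{\beta X}\right)
  &&\Rightarrow\quad \frac{\partial Y_\beta}{\partial X}
  = \frac{e^{\beta X}}{1 + e^{\beta X}}
  = \frac{e^{\beta Y_\beta} - 1}{e^{\beta Y_\beta}}
\end{aligned}
```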

Notice that the caffe version doesn’t need the X tensor. So, are there any problems with using it (in the most common case, with beta=1 and no Hessian needed)? Is the caffe implementation built in or exported to Python?

To be honest, I don’t think self (the input to the forward) is actually used.

My bad, “a” above binds to grad_output. So I presume capturing “self” qualifies as a bug, then.

On second look, it seems self is captured for softplus_double_backward. An interesting tradeoff: a (presumably) simpler formula there versus extra memory in the regular case…
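A side note on that tradeoff, as a scalar sanity check only: with beta=1 the second derivative sigmoid(X)(1 - sigmoid(X)) can also be written through Y, because sigmoid(X) = 1 - exp(-Y). Whether PyTorch's double backward could actually be restructured this way is not established here; the function names are illustrative.

```python
import math

def softplus(x: float) -> float:
    return math.log1p(math.exp(x))

def second_derivative_from_output(y: float) -> float:
    # beta=1: sigmoid(X) = 1 - exp(-Y), so d2Y/dX2 = s * (1 - s) needs only Y
    s = 1.0 - math.exp(-y)
    return s * (1.0 - s)

def second_derivative_from_input(x: float) -> float:
    # reference: derivative of sigmoid(x) is sigmoid(x) * (1 - sigmoid(x))
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 - s)

for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(x, second_derivative_from_output(softplus(x)), second_derivative_from_input(x))
```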

So one thing to check is whether the JIT decomposes this more nicely (I don’t know, but it might, since it has its own differentiation scheme and eliminates unneeded variables).

Nope. Here is my test script and an alternative version:

import torch
from torch import jit
from torch.nn import functional as F

# scripting had no effect on results
# (assumed: these two were @jit.script-decorated, given their names and the jit import)
@jit.script
def jit_softplus(x):
    return F.softplus(x)

@jit.script
def fused_softplus(x):
    return x.exp().add(1.0).log()

class LowMemSoftplus(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = F.softplus(x)
        # save only the output: the caller keeps y alive anyway, so this costs no extra memory
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dLoss_dOutput):
        softplus_of_x, = ctx.saved_tensors
        # caffe formula (beta=1): d softplus(x)/dx = 1 - exp(-softplus(x))
        dOutput_dInput = softplus_of_x.neg().exp_()
        dOutput_dInput.neg_().add_(1.0)  # in place: 1 - exp(-y)
        return dLoss_dOutput * dOutput_dInput

for title, f in {
    "nograd": F.softplus, "grad": F.softplus,
    "jitgrad": jit_softplus, "jitfuse": fused_softplus,
    "custom": LowMemSoftplus.apply,
}.items():
    mem0 = torch.cuda.memory_allocated()
    p = torch.ones(5000, device="cuda").requires_grad_(title != "nograd")  # this leaf tensor is kept alive due to .grad
    x = p[:, None] @ p[None, :]  # non-leaf tensor with 25M elements
    y = f(x)
    del x, f
    print(title, ":", torch.cuda.memory_allocated() - mem0)
    del y, p
    assert torch.cuda.memory_allocated() == mem0


nograd : 100683776
grad : 201347072
jitgrad : 201347072
jitfuse : 302010368
custom : 100683776
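A back-of-the-envelope check of the byte counts above. It assumes (my reading of the CUDA caching allocator, not something stated in the thread) that a 100 MB tensor is reported as a whole segment rounded up to 2 MiB and that the small leaf p is rounded up to 512 bytes; it then counts how many 25M-element tensors stay alive in each case.

```python
MIB = 1024 * 1024

def round_up(n: int, multiple: int) -> int:
    # integer ceiling to the next multiple
    return -(-n // multiple) * multiple

# each 5000 x 5000 float32 tensor (x, y, or a saved intermediate)
tensor_bytes = 5000 * 5000 * 4               # 100,000,000 bytes
# assumption: a block this large occupies its whole segment, rounded up to 2 MiB
big_block = round_up(tensor_bytes, 2 * MIB)  # 100,663,296 bytes
# assumption: small allocations (the 5000-element leaf p) round up to 512 bytes
p_block = round_up(5000 * 4, 512)            # 20,480 bytes

# number of 25M-element tensors still alive after `del x` in each case
alive = {
    "nograd": 1,   # y only
    "grad": 2,     # y plus the input x saved by softplus' backward
    "jitgrad": 2,  # same as "grad"
    "jitfuse": 3,  # y plus exp(x) and exp(x) + 1, saved for the exp and log backward
    "custom": 1,   # y only; saving the output costs nothing extra
}
expected = {title: n * big_block + p_block for title, n in alive.items()}
for title, nbytes in expected.items():
    print(title, ":", nbytes)
```

Under these assumptions the script prints the same five numbers as the measurement, which is consistent with "grad" and "jitgrad" each holding one extra copy of x.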