Implementation of arbitrary differentiable functions


#1

Is it possible to just use arbitrary differentiable/supported functions to build other functions, without having to implement their backward as described in the examples?

One thing I really liked about TF is how you can just create an arbitrary compute graph of differentiable pieces. It’s not obvious how to do that here, unless I’m missing something.

Let’s say I want to quickly implement a GELU: y = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
Can I just do that using the differentiable tanh() and pow()? Or will I have to create a special class and describe backward()?


(James Bradbury) #2

That code implements GELU (well, with tanh replaced by F.tanh etc.). If you want to wrap it in a Python def, that works too.
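A minimal sketch of that idea, assuming current PyTorch (where Variable has been merged into Tensor and torch.tanh replaces the deprecated F.tanh): the GELU formula from the question wrapped in a plain def. Autograd records each tensor op, so no backward is written by hand. The constant name SQRT_2_OVER_PI is illustrative.

```python
import math
import torch

# Plain Python float for sqrt(2 / pi); no numpy scalar enters the expression.
SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

def gelu(x):
    # tanh approximation of GELU, composed only of differentiable tensor ops
    return 0.5 * x * (1.0 + torch.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)))

x = torch.linspace(-3.0, 3.0, steps=7, requires_grad=True)
y = gelu(x)
y.sum().backward()  # gradients flow back through tanh, pow, mul and add
print(x.grad)
```

Calling backward() on any scalar built from gelu's output is enough; the graph of tanh, pow, mul and add nodes supplies the derivative.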


#3

I see, so what then is the reason behind “def backward(self, grad_output):” when implementing your own modules? Why isn’t that redundant?


(James Bradbury) #4

You need to define backward if you’re implementing your own autograd.Function classes, not your own Module classes. The difference is that the code in Module.forward operates on Variables using differentiable operations like F.tanh and other Modules, whereas you only need to define a new autograd.Function subclass if you want a totally new operation that can’t be written in terms of existing differentiable ops. Defining an autograd.Function instead of composing existing differentiable operations can also help if the forward or backward pass would benefit significantly from a custom C implementation.
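To illustrate the Module side of that distinction, here is a sketch assuming current PyTorch: a module that defines only forward(), used inside a toy network. The class names TanhGELU and MLP and the layer sizes are hypothetical; current PyTorch also ships nn.GELU(approximate="tanh"), so this is purely for illustration.

```python
import math
import torch
import torch.nn as nn

class TanhGELU(nn.Module):
    """Hypothetical module: only forward() is defined; autograd derives backward."""

    def forward(self, x):
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * x ** 3)))

class MLP(nn.Module):
    """Toy two-layer network; the sizes are arbitrary."""

    def __init__(self, d_in=4, d_hidden=8, d_out=2):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_hidden)
        self.act = TanhGELU()
        self.fc2 = nn.Linear(d_hidden, d_out)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

torch.manual_seed(0)
net = MLP()
loss = net(torch.randn(3, 4)).pow(2).mean()
loss.backward()  # works although no backward() is defined above
```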

Ultimately everything you use in a module is defined in terms of autograd.Functions (e.g. F.tanh implements forward and backward) but you rarely have to define one yourself.
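For the rare case where you do write one, here is a sketch of a custom autograd.Function for the same GELU approximation, with a hand-derived backward. It assumes the current API, where forward and backward are static methods that take a ctx object and the op is invoked via .apply; the old-style def backward(self, grad_output) discussed in this thread is no longer used. TanhGELUFunction is a hypothetical name.

```python
import math
import torch

C = math.sqrt(2.0 / math.pi)
A = 0.044715

class TanhGELUFunction(torch.autograd.Function):
    """Hypothetical custom op: tanh-approximated GELU with a manual backward."""

    @staticmethod
    def forward(ctx, x):
        t = torch.tanh(C * (x + A * x ** 3))
        ctx.save_for_backward(x, t)  # keep what backward needs
        return 0.5 * x * (1.0 + t)

    @staticmethod
    def backward(ctx, grad_output):
        x, t = ctx.saved_tensors
        du_dx = C * (1.0 + 3.0 * A * x ** 2)
        # d/dx [0.5 x (1 + tanh u)] = 0.5 (1 + t) + 0.5 x (1 - t^2) du/dx
        dy_dx = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t ** 2) * du_dx
        return grad_output * dy_dx  # chain rule with the incoming gradient

x = torch.randn(6, dtype=torch.double, requires_grad=True)
y = TanhGELUFunction.apply(x)
y.sum().backward()  # autograd calls TanhGELUFunction.backward
```

torch.autograd.gradcheck compares such a hand-written backward against finite differences, which is a cheap way to catch derivation mistakes.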


#5

It seems that you can’t use numpy constants in such a construct; it leads to it being stuck on the CPU:

import numpy as np
import torch.nn.functional as F

def gelu(x):
	return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x*x*x)))

However, if you replace the square root with 0.79788456080, everything works fine. Is that intentional?


(Adam Paszke) #6

What’s the error you’re getting on the GPU? You’re only doing scalar ops from numpy.


#7

I’m not getting any error. It just never seems to get around to training: the GPU isn’t busy and a single CPU core is at 100%.


(Adam Paszke) #8

And what’s the stack trace once you interrupt the script?


#9
^CProcess Process-1:
Traceback (most recent call last):
  File "/home/rrr/anaconda/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/rrr/anaconda/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/home/rrr/anaconda/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 26, in _worker_loop
    r = index_queue.get()
  File "/home/rrr/anaconda/lib/python2.7/multiprocessing/queues.py", line 378, in get
    return recv()
  File "/home/rrr/anaconda/lib/python2.7/site-packages/torch/multiprocessing/queue.py", line 21, in recv
    buf = self.recv_bytes()
KeyboardInterrupt

Traceback (most recent call last):
  File "main.py", line 121, in <module>
    output = net(input)
  File "/home/rrr/anaconda/lib/python2.7/site-packages/torch/nn/modules/module.py", line 202, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 79, in forward
    x = gelu(self.fc1(x))
  File "main.py", line 61, in gelu
    return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x*x*x)))
  File "/home/rrr/anaconda/lib/python2.7/site-packages/torch/autograd/variable.py", line 818, in __iter__
    return iter(map(lambda i: self[i], range(self.size(0))))
  File "/home/rrr/anaconda/lib/python2.7/site-packages/torch/autograd/variable.py", line 818, in <lambda>
    return iter(map(lambda i: self[i], range(self.size(0))))
  File "/home/rrr/anaconda/lib/python2.7/site-packages/torch/autograd/variable.py", line 68, in __getitem__
    return Index(key)(self)
  File "/home/rrr/anaconda/lib/python2.7/site-packages/torch/autograd/_functions/tensor.py", line 16, in forward
    result = i.index(self.index)
KeyboardInterrupt

(Adam Paszke) #10

I think it’s an unrelated problem. You sent a tensor from one process to another, but the sender died before the receiver managed to take it out of the queue. You either have to ensure that the sender stays alive as long as its tensors are in a queue, or switch to the file_system sharing strategy (not really recommended).
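For reference, switching to the file_system strategy mentioned above is a one-line call in torch.multiprocessing; this sketch assumes it runs once, before any worker processes are created. The file_system strategy identifies shared-memory regions by file name instead of passing file descriptors, which is why it is the fallback rather than the preferred fix.

```python
import torch.multiprocessing as mp

# Strategies available on this platform (on Linux: file_descriptor and file_system).
print(mp.get_all_sharing_strategies())

# Fallback: share tensors via named shared-memory files instead of file descriptors.
mp.set_sharing_strategy("file_system")
print(mp.get_sharing_strategy())
```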


#11

Well, I’m not sure what the problem is, but changing np.sqrt(2 / np.pi) to its value immediately fixes the problem, so I can’t say it’s unrelated.


(Adam Paszke) #12

I looked into the snippet and it seems that numpy is trying to be overly smart. np.sqrt(2 / np.pi) returns a numpy.float64 object, not a regular float. Because it is the left operand of the multiplication and implements __mul__, it gets to decide what happens first. Apparently it starts treating the Variable like a sequence, but Variables are not regular sequences: indexing one returns another Variable that can be indexed again, as many times as you want. That’s why numpy keeps adding dimensions until it hits its limit, and returns a very deeply nested list that contains single-element Variables.
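The type difference is easy to check with numpy alone. This sketch only inspects the constant; it does not reproduce the hang, which depended on the 2017 Variable implementation.

```python
import math
import numpy as np

c = np.sqrt(2 / np.pi)
print(type(c))                  # numpy.float64, not a plain float
# numpy.float64 subclasses float but defines its own arithmetic, so in
# `c * other` numpy's __mul__ is tried before `other` gets a say.
print(isinstance(c, float))
print(np.float64.__mul__ is float.__mul__)
print(type(float(c)))           # float() yields a plain Python float
print(type(math.sqrt(2 / math.pi)))  # the math module returns a float directly
```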

If you convert it to a regular float, or reverse the order (i.e. put the constant after the expression involving x), the result should be OK.
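A sketch of the two workarounds suggested above, written against current PyTorch tensors. It assumes modern versions; recent releases may handle a numpy scalar on the left differently from the 2017 behavior described here, but both forms are harmless either way. The function names are hypothetical.

```python
import numpy as np
import torch

def gelu_float_constant(x):
    # Workaround 1: convert numpy.float64 to a plain Python float first.
    c = float(np.sqrt(2 / np.pi))
    return 0.5 * x * (1 + torch.tanh(c * (x + 0.044715 * x * x * x)))

def gelu_constant_on_right(x):
    # Workaround 2: keep the tensor on the left so the tensor's __mul__ runs.
    c = np.sqrt(2 / np.pi)
    return 0.5 * x * (1 + torch.tanh((x + 0.044715 * x * x * x) * c))

x = torch.linspace(-3.0, 3.0, steps=7)
a = gelu_float_constant(x)
b = gelu_constant_on_right(x)
print(a)
print(b)
```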

I think the only fix we can make is to add scalar types to torch. We’ve been talking about that for some time now, and it’s probably going to happen, but more likely in the farther future.