nn.RNN with ReLU backprops differently under CPU and GPU

nn.RNN with the relu nonlinearity gives different gradients on CPU and on GPU.

Other recurrent modules do not have this problem.

Below is a minimal script to reproduce. If you copy and paste the script and save it as rnn_test.py, you can run it with the following commands:

python3 rnn_test.py --model rnn --nonlinearity relu
python3 rnn_test.py --model rnn --nonlinearity relu --use_gpu
python3 rnn_test.py --model rnn --nonlinearity tanh
python3 rnn_test.py --model rnn --nonlinearity tanh --use_gpu
python3 rnn_test.py --model gru
python3 rnn_test.py --model gru --use_gpu
python3 rnn_test.py --model lstm
python3 rnn_test.py --model lstm --use_gpu

You will probably notice that the outputs of each pair of commands (one with --use_gpu and one without) match, except for the first pair, which uses the relu nonlinearity.

I’m using Python 3.5.2 and PyTorch 0.4.1.

Script:

import torch
import torch.nn as nn
from torch.autograd import Variable
import random
from copy import deepcopy
import argparse

parser = argparse.ArgumentParser(description="rnn cpu and gpu tests")
parser.add_argument('--use_gpu', action='store_true')
parser.add_argument('--model', type=str, default='rnn', choices=['rnn', 'gru', 'lstm'])
parser.add_argument('--nonlinearity', type=str, default='relu', choices=['relu', 'tanh'])
args = parser.parse_args()
use_gpu = args.use_gpu
model = args.model
nonlinearity = args.nonlinearity
print('use gpu.' if use_gpu else 'use cpu.')

torch.cuda.manual_seed(0)
torch.manual_seed(0)
random.seed(0)

## manually create input, target, initial hidden state and criterion
input = Variable(torch.randn(100, 64, 1).cuda()) if use_gpu else Variable(torch.randn(100, 64, 1)) # dim = (seq_len, batch_size, input_size)
target = Variable(torch.randint(low=0, high=1, size=(64, ), dtype=torch.long).cuda()) if use_gpu else Variable(torch.randint(low=0, high=1, size=(64, ), dtype=torch.long)) # dim = (batch_size, ); high is exclusive, so every target is class 0
hx0 = Variable(torch.randn(64, 20).cuda()) if use_gpu else Variable(torch.randn(64, 20)) # dim = (batch_size, hidden_size)
if model == 'lstm':
    c0 = Variable(torch.zeros(64, 20).cuda()) if use_gpu else Variable(torch.zeros(64, 20)) # dim = (batch_size, hidden_size)
criterion = nn.CrossEntropyLoss() # use cross entropy loss


## first network, its output and rnn gradients
if model=='rnn':
    rnn1 = nn.RNNCell(1, 20, nonlinearity=nonlinearity, bias=False).cuda() if use_gpu else nn.RNNCell(1, 20, nonlinearity=nonlinearity, bias=False)
elif model=='gru':
    rnn1 = nn.GRUCell(1, 20, bias=False).cuda() if use_gpu else nn.GRUCell(1, 20, bias=False)
elif model=='lstm':
    rnn1 = nn.LSTMCell(1, 20, bias=False).cuda() if use_gpu else nn.LSTMCell(1, 20, bias=False)
linear1 = nn.Linear(20, 2, bias=False).cuda() if use_gpu else nn.Linear(20, 2, bias=False)

# no bias and eye init to make sure two networks have the same parameters
for name, param in rnn1.named_parameters():
    if 'weight' in name:
        nn.init.eye_(param)
for name, param in linear1.named_parameters():
    if 'weight' in name:
        nn.init.eye_(param)

# run the net
hx1 = deepcopy(hx0)
if model=='lstm':
    c1 = deepcopy(c0)
output1 = []
for i in range(100):
    if model != 'lstm':
        hx1 = rnn1(input[i], hx1)
    else:
        hx1, c1 = rnn1(input[i], (hx1, c1))
    output1.append(hx1)
logit1 = linear1(hx1)
loss1 = criterion(logit1, target)

# calculate gradients and sum of gradient norms
grad_params1 = torch.autograd.grad(loss1, rnn1.parameters(), create_graph=True, retain_graph=True, allow_unused=True)
grad_norm1 = 0
for idx in range(len(grad_params1)):
    grad_norm1 += torch.norm(grad_params1[idx])
print('rnn1 - loss: %f' % (loss1))
print('rnn1 - sum of gradient norm is: %f' % (grad_norm1))
print('---')

## second network, its output and rnn gradients
if model=='rnn':
    rnn2 = nn.RNN(1, 20, nonlinearity=nonlinearity, bias=False).cuda() if use_gpu else nn.RNN(1, 20, nonlinearity=nonlinearity, bias=False)
elif model=='gru':
    rnn2 = nn.GRU(1, 20, bias=False).cuda() if use_gpu else nn.GRU(1, 20, bias=False)
elif model=='lstm':
    rnn2 = nn.LSTM(1, 20, bias=False).cuda() if use_gpu else nn.LSTM(1, 20, bias=False)
linear2 = nn.Linear(20, 2, bias=False).cuda() if use_gpu else nn.Linear(20, 2, bias=False)

# same init as the first network
for name, param in rnn2.named_parameters():
    if 'weight' in name:
        nn.init.eye_(param)
for name, param in linear2.named_parameters():
    if 'weight' in name:
        nn.init.eye_(param)
        
# run the net 
if model != 'lstm':
    output2, hx2 = rnn2(input, hx0.unsqueeze(0))
else:
    output2, (hx2, _) = rnn2(input, (hx0.unsqueeze(0), c0.unsqueeze(0)))
logit2 = linear2(hx2[-1])
loss2 = criterion(logit2, target)

# calculate gradients and sum of gradient norms
grad_params2 = torch.autograd.grad(loss2, rnn2.parameters(), create_graph=True, retain_graph=True, allow_unused=True)
grad_norm2 = 0
for idx in range(len(grad_params2)):
    grad_norm2 += torch.norm(grad_params2[idx])
print('rnn2 - loss: %f' % (loss2))
print('rnn2 - sum of gradient norm is: %f' % (grad_norm2))

For anyone coming here, this is a bug.
We are tracking it here: https://github.com/pytorch/pytorch/issues/11662

Thanks a lot for reporting @moonlightlane

For anyone stumbling upon this:

Turns out it isn’t a bug per se, but rather a corner case of ReLU’s gradient at the (non-differentiable) point 0: CuDNN takes it to be one (the limit from the right), while “native” PyTorch has been switched to zero (the limit from the left).
Demonstration code is in the bug report.
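For convenience, here is a minimal sketch (not the exact snippet from the bug report) showing the convention native autograd uses at that point:

import torch

# ReLU is not differentiable at 0. Native PyTorch autograd returns 0 there
# (the limit from the left); the CuDNN RNN kernels take it to be 1 (the limit
# from the right), as described above, which is where the CPU/GPU discrepancy
# comes from.
x = torch.zeros(1, requires_grad=True)
y = torch.relu(x)
y.backward()
print(x.grad)  # tensor([0.]) with native autograd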

Now, with a ReLU-RNN I would expect that corner case to be somewhat relevant and it makes me wonder if that makes ReLU-RNN actually a bit of a bad idea.

Best regards

Thomas

Yeah, and the difference is more drastic when you are not using a bias. But even if you use a bias, there are situations where the outputs differ… I guess there is no single correct value for the derivative of ReLU at zero, since you can define it either way.
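If you wanted the non-CuDNN path to follow CuDNN’s convention, one option is a custom autograd Function whose backward treats the derivative at zero as one. This is just a sketch; the name ReLURightGrad is made up for illustration:

import torch

class ReLURightGrad(torch.autograd.Function):
    # ReLU whose derivative at 0 is taken to be 1 (the limit from the right)
    # instead of 0, matching the CuDNN convention described above.

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * (x >= 0).to(grad_output.dtype)

x = torch.zeros(1, requires_grad=True)
ReLURightGrad.apply(x).backward()
print(x.grad)  # tensor([1.]), instead of tensor([0.]) from torch.relu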