When using torch.nn.functional.elu and torch.nn.ELU on CPU, the following two orders of operations:
- slice → ELU
- ELU → slice
produce different results. Expected: exactly the same output (same bytes). Note that [slice → clone → ELU] produces the same output as [slice → ELU].
The code block below produces the following output on a particular Linux machine:
processor: x86_64
OS: Linux-5.15.0-76-generic-x86_64-with-glibc2.29
numpy version: 1.23.5
torch version: 1.13.0+cu117
pair : nonzeros/total (expect 0/10000)
u0, u1 (F.elu): 9989/10000
u2, u3 (F.elu): 2371/10000
v0, v1 (MyELU): 0/10000
v2, v3 (MyELU): 0/10000
v0, u0 (F, My): 5173/10000
Average squared difference between F.elu (u0) and MyELU (v0): 2.6055602120322876e-15
Whereas it produces this output on a particular Windows machine:
processor: AMD64 Family 25 Model 80 Stepping 0, AuthenticAMD
OS: Windows-10-10.0.19041-SP0
numpy version: 1.20.1
torch version: 1.10.0+cu111
pair : nonzeros/total (expect 0/10000)
u0, u1 (F.elu): 0/10000
u2, u3 (F.elu): 0/10000
v0, v1 (MyELU): 0/10000
v2, v3 (MyELU): 0/10000
v0, u0 (F, My): 5171/10000
Average squared difference between F.elu (u0) and MyELU (v0): 2.589573000477685e-15
The expected result is that all fractions should be 0/10000, except for maybe the v0, u0 comparison due to implementation differences.
I also found that the results are consistent when F.elu is applied repeatedly to the same tensor object, i.e. when the identical computation is simply repeated.
# code block
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import platform
# Record the environment so that outputs from different machines can be
# compared side by side.
print(f'processor: {platform.processor()}')
print(f'OS: {platform.platform()}')
print(f'numpy version: {np.__version__}')
print(f'torch version: {torch.__version__}')

# Ask PyTorch for deterministic behaviour, so that any mismatch reported
# by this script is not attributable to a non-deterministic algorithm choice.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
class MyELU(nn.Module):
    """Reference ELU assembled from elementwise torch operations.

    Returns ``x`` where ``x >= 0`` and ``alpha * (exp(min(x, 0)) - 1)``
    elsewhere. It serves as a baseline to compare against ``F.elu``.

    Args:
        alpha: Scale of the negative branch. It is stored as a buffer named
            ``alpha``, so the module has no trainable parameters.
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self.register_buffer('alpha', torch.tensor(alpha))

    def forward(self, x):
        """Apply ELU elementwise to ``x``."""
        # Keep the exp argument <= 0 so that exp never overflows for large
        # positive inputs, whose result is taken from the `x` branch anyway.
        non_positive = torch.minimum(x, torch.tensor(0.0))
        negative_branch = self.alpha * (torch.exp(non_positive) - 1)
        return torch.where(x >= 0, x, negative_branch)
# Instantiate the reference ELU and put it in inference mode.
my_elu = MyELU(alpha=1.0)
# Module.requires_grad_ applies the flag to every parameter of the module,
# which is exactly what the hand-written loop over parameters() did.
my_elu.requires_grad_(False)
my_elu.eval()
# One entry per trial: number of elements that differ between the two
# tensors of each compared pair.
num_diffs_u0_u1 = []
num_diffs_u2_u3 = []
num_diffs_v0_v1 = []
num_diffs_v2_v3 = []
num_diffs_v0_u0 = []
# One entry per trial: squared difference between MyELU and F.elu, summed
# over all elements of the column slice.
diffs_mag_v0_u0 = []
for _ in range(10000):
    # z0 is the first column of z1 and z2 is the first row of z3; both are
    # produced by slicing, not by an explicit copy.
    z1 = torch.randn(200, 2)
    z0 = z1[:, 0:1]
    z3 = torch.randn(2, 200)
    z2 = z3[0:1, :]
    with torch.no_grad():
        # Even-numbered results: slice -> ELU.
        # Odd-numbered results:  ELU -> slice.
        u0 = F.elu(z0)
        u1 = F.elu(z1)[:, 0:1]
        u2 = F.elu(z2)
        u3 = F.elu(z3)[0:1, :]
        v0 = my_elu(z0)
        v1 = my_elu(z1)[:, 0:1]
        v2 = my_elu(z2)
        v3 = my_elu(z3)[0:1, :]
    # Exact (bitwise-value) comparison: any nonzero count is a mismatch.
    num_diffs_u0_u1.append((u0 != u1).sum().item())
    num_diffs_u2_u3.append((u2 != u3).sum().item())
    num_diffs_v0_v1.append((v0 != v1).sum().item())
    num_diffs_v2_v3.append((v2 != v3).sum().item())
    # Cross-implementation comparison; small differences are tolerated here.
    num_diffs_v0_u0.append((v0 != u0).sum().item())
    diffs_mag_v0_u0.append(((v0 - u0) ** 2).sum().item())
# Summarise, for each compared pair, in how many trials at least one
# element differed.
print('pair : nonzeros/total (expect 0/10000)')
labelled_results = [
    ('u0, u1 (F.elu)', num_diffs_u0_u1),
    ('u2, u3 (F.elu)', num_diffs_u2_u3),
    ('v0, v1 (MyELU)', num_diffs_v0_v1),
    ('v2, v3 (MyELU)', num_diffs_v2_v3),
    ('v0, u0 (F, My)', num_diffs_v0_u0),
]
for identifier, result in labelled_results:
    # Number of trials whose mismatch count is nonzero.
    sum_nonzero = (np.array(result) != 0).sum()
    print(f'{identifier}: {sum_nonzero:8}/{len(result):5}')
# NOTE: each entry of diffs_mag_v0_u0 is a per-trial *sum* of squared
# differences, so this is the mean over trials of that sum.
print(f'Average squared difference between F.elu (u0) and MyELU (v0): {np.mean(diffs_mag_v0_u0)}')