Deterministic behaviour on GPU

Hi,
I've noticed some strange behaviour when I use max pooling in this model with the GPU version.
I build a batch of identical samples. Each sample has temporal features to be processed with 2D convolutions, so I fold the temporal dimension into the batch dimension.
If I use max_pool, the output is different for the different elements of the batch; without it, they are exactly the same.
This doesn't happen on the CPU, only on the GPU.
It happens on a Titan GTX, a 1080 Ti and a Quadro P6000, on 2 different computers and setups.
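
In short, the script below builds a batch from copies of one sample and checks that every batch element produces the same output. A minimal stand-in for that check (just a plain conv on CUDA instead of my model, for illustration):

import torch
from torch import nn

# Simplified stand-in for the model below: any conv is enough to show the idea.
conv = nn.Conv2d(1, 32, kernel_size=3).cuda()

x = torch.rand(1, 1, 64, 64, device='cuda')
batch = x.repeat(16, 1, 1, 1)        # 16 identical copies of the same sample

with torch.no_grad():
    out = conv(batch)

# With identical inputs I would expect identical outputs for every batch element.
print(all(torch.equal(out[0], out[i]) for i in range(1, out.shape[0])))

The full script and its output: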

import subprocess
import sys

import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram, Spectrogram

N_FFT = 512
N_MELS = 256
HOP_LENGTH = 130
AUDIO_FRAMERATE = 16000


def get_sys_info():
    """Return a string with the Python/PyTorch/CUDA/cuDNN versions and the nvidia-smi device list."""
    result = subprocess.Popen(["nvidia-smi", "--format=csv",
                               "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"],
                              stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    nvidia = result.stdout.readlines().copy()
    nvidia = [str(x) for x in nvidia]
    nvidia = [x[2:-3] + '\r\t' for x in nvidia]
    acum = ''
    for x in nvidia:
        acum = acum + x

    return ('  Python VERSION: {0} \n\t'
            '  pyTorch VERSION: {1} \n\t'
            '  CUDA VERSION: {2}\n\t'
            '  CUDNN VERSION: {3} \n\t'
            '  Number CUDA Devices: {4} \n\t'
            '  Devices: {5}\n\t'
            'Active CUDA Device: GPU {6} \n\t'
            'Available devices {7} \n\t'
            'Current cuda device {8} \n\t'.format(sys.version, torch.__version__, torch.version.cuda,
                                                  torch.backends.cudnn.version(), torch.cuda.device_count(),
                                                  acum, torch.cuda.current_device(), torch.cuda.device_count(),
                                                  torch.cuda.current_device()))


def make_audio_block(filter_in, filters_out, lrn, max_pool=None, padding=0, stride=1, kernel_size=(3, 3)):
    layers = [nn.Conv2d(filter_in, filters_out, kernel_size=kernel_size, padding=padding, stride=stride)]
    layers.append(nn.ReLU(False))
    if max_pool is not None:
        layers.append(nn.MaxPool2d(max_pool))
    return nn.Sequential(*layers)


def reshape(x, unrolled_shape):
    return x.view(*unrolled_shape, *x.shape[1:])


def check(x, unrolled_shape):
    x_unrolled = reshape(x, unrolled_shape)
    ref = x_unrolled[0]
    all_equal = True
    error_abs = [None for _ in range(unrolled_shape[0])]
    error_mean = [None for _ in range(unrolled_shape[0])]
    error_abs[0] = 0
    error_mean[0] = 0
    for i in range(1, unrolled_shape[0]):
        all_equal = all_equal and torch.allclose(ref, x_unrolled[i])
        diff = torch.abs(ref - x_unrolled[i])
        diff = diff[diff > 0]
        error_abs[i] = diff.sum()
        error_mean[i] = diff.mean()
    return all_equal, max(error_abs), max(error_mean)


class AudioEncoder(nn.Module):
    # 'filter_size': [96, 256, 512, 512*6*6]
    def __init__(self, pooling, pooling_type='AvgPool'):
        super(AudioEncoder, self).__init__()
        assert pooling_type in ['MaxPool', 'AvgPool'], f'Pooling of type {pooling_type} should be MaxPool or AvgPool'
        filters = [1, 32, 64, 128]
        # self.preproc = MelSpectrogram(sample_rate=AUDIO_FRAMERATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS)
        self.preproc = Spectrogram(n_fft=N_FFT, hop_length=HOP_LENGTH)
        self.b1 = make_audio_block(filters[0], filters[1], lrn=False, max_pool=pooling, padding=2, kernel_size=(7, 7),
                                   stride=(2, 1))
        self.b2 = make_audio_block(filters[1], filters[2], lrn=False, max_pool=pooling, padding=0, kernel_size=(7, 3))
        self.b3 = make_audio_block(filters[2], filters[3], lrn=False, padding=0, kernel_size=(7, 3))
        if pooling_type == 'MaxPool':
            self.pooling = nn.AdaptiveMaxPool2d((1, None))
        else:
            self.pooling = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, x):
        verbose = True
        unrolled_shape = x.shape[:2]
        print(f'Unrolled shape: {unrolled_shape}')
        if verbose:
            print(f'Input--> Shape: {x.shape}, device:{x.device}')
        x = self.preproc(x)
        if verbose:
            print(f'FFT --> Shape: {x.shape}')
        x = x.view(-1, 1, *x.shape[2:])
        if verbose:
            equal, abs_max, mean_max = check(x, unrolled_shape)
            print(f'view --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = self.b1(x)
        if verbose:
            equal, abs_max, mean_max = check(x, unrolled_shape)
            print(f'b1 --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = self.b2(x)
        if verbose:
            equal, abs_max, mean_max = check(x, unrolled_shape)
            print(f'b2 --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = self.b3(x)
        if verbose:
            equal, abs_max, mean_max = check(x, unrolled_shape)
            print(f'b3 --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = self.pooling(x)
        if verbose:
            equal, abs_max, mean_max = check(x, unrolled_shape)
            print(f'pooling --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = x.squeeze()
        return x, x.shape


BATCH_SIZE = 16


@torch.no_grad()
def run_test(pooling, pooling_type, device):
    device = torch.device(device)
    model = AudioEncoder(pooling, pooling_type).to(device)
    inp_element = torch.rand(25, 4480).to(device)
    inp = torch.stack([inp_element.clone() for _ in range(BATCH_SIZE)])

    # print(model)

    y, shape = model(inp)
    is_identical, max_abs, max_mean = check(y, [BATCH_SIZE, 25])

    if is_identical:
        print(f"Test: Pooling {pooling}, {pooling_type}. Device {device}, max_abs {max_abs},max_mean {max_mean} OK")
    else:
        print(
            f"Test: Pooling {pooling}, {pooling_type}. Device {device}, max_abs {max_abs},max_mean {max_mean}. Failed")
    print('---------------------------------------')


pooling_tests = [None, (3, 3)]
pooling_types = ['AvgPool']
devices = ['cuda:0', 'cuda:1', 'cpu']

if __name__ == '__main__':

    print(get_sys_info())
    for device in devices:
        for pooling_i in pooling_tests:
            for pooling_type_i in pooling_types:
                run_test(pooling_i, pooling_type_i, device)
  Python VERSION: 3.6.9 (default, Oct  8 2020, 12:12:24) 
[GCC 8.4.0] 
	  pyTorch VERSION: 1.7.0+cu110 
	  CUDA VERSION: 11.0
	  CUDNN VERSION: 8004 
	  Number CUDA Devices: 3 
	
	Active CUDA Device: GPU 0 
	Available devices 3 
	Current cuda device 0 
	
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cuda:0
/home/jfm/.local/lib/python3.6/site-packages/torch/functional.py:516: UserWarning: stft will require the return_complex parameter be explicitly  specified in a future PyTorch release. Use return_complex=False  to preserve the current behavior or return_complex=True to return  a complex output. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:653.)
  normalized, onesided, return_complex)
/home/jfm/.local/lib/python3.6/site-packages/torch/functional.py:516: UserWarning: The function torch.rfft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.rfft. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:590.)
  normalized, onesided, return_complex)
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cuda:0, max_abs 0,max_mean 0 OK
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cuda:0
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 6, 1]), all equal: False, max error: abs 0.030087150633335114,mean 3.579247049856349e-06
pooling --> Shape: torch.Size([400, 128, 1, 1]), all equal: False, max error: abs 0.0028738644905388355,mean 1.6516462437721202e-06
Test: Pooling (3, 3), AvgPool. Device cuda:0, max_abs 0.0028738644905388355,max_mean 1.6516462437721202e-06. Failed
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cuda:1
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cuda:1, max_abs 0,max_mean 0 OK
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cuda:1
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 6, 1]), all equal: False, max error: abs 0.03450433909893036,mean 4.106193046027329e-06
pooling --> Shape: torch.Size([400, 128, 1, 1]), all equal: False, max error: abs 0.0032321936450898647,mean 1.921637021951028e-06
Test: Pooling (3, 3), AvgPool. Device cuda:1, max_abs 0.0032321936450898647,max_mean 1.921637021951028e-06. Failed
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cpu
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cpu, max_abs 0,max_mean 0 OK
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cpu
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 6, 1]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 1]), all equal: True, max error: abs 0,mean 0
Test: Pooling (3, 3), AvgPool. Device cpu, max_abs 0,max_mean 0 OK
---------------------------------------

Process finished with exit code 0

Did you mean average pooling or max pooling? Your script only seems to test average pooling, no?

For average pooling, I could see different batch elements being averaged in a different order, leading to a ~1e-6 error that then gets amplified later in the network.
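
For example (just to illustrate the size of the effect, nothing specific to your model), float32 addition is not associative, so summing the same values in a different order already gives small discrepancies:

import torch

a = torch.rand(100000)

s_fwd = a.sum()               # one accumulation order
s_rev = a.flip(0).sum()       # same values, reversed order
s_ref = a.double().sum()      # higher-precision reference

print((s_fwd - s_rev).item())                 # usually a small non-zero value
print((s_fwd.double() - s_ref).abs().item())  # rounding error of the float32 sum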

Well, it actually happens with the max pooling: the pooling_type flag only controls the last pooling layer, and the error arises before it.
Given that the same machine processes all the elements, I would expect the output to be identical, as it is on the CPU. By the way, the error does not happen with PyTorch 1.5. But yes, that small error is later amplified, so the output ends up substantially different for the same input.

I'm still trying to reduce this to a minimal example.

  Python VERSION: 3.6.9 (default, Oct  8 2020, 12:12:24)
[GCC 8.4.0]
	  pyTorch VERSION: 1.5.0
	  CUDA VERSION: 10.2
	  CUDNN VERSION: 7605
	  Number CUDA Devices: 3
	
	Active CUDA Device: GPU 0
	Available devices 3
	Current cuda device 0
	
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cuda:0
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cuda:0, max_abs 0,max_mean 0 OK
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cuda:0
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 6, 1]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 1]), all equal: True, max error: abs 0,mean 0
Test: Pooling (3, 3), AvgPool. Device cuda:0, max_abs 0,max_mean 0 OK
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cuda:1
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cuda:1, max_abs 0,max_mean 0 OK
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cuda:1
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 6, 1]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 1]), all equal: True, max error: abs 0,mean 0
Test: Pooling (3, 3), AvgPool. Device cuda:1, max_abs 0,max_mean 0 OK
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cpu
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cpu, max_abs 0,max_mean 0 OK
---------------------------------------
Unrolled shape: torch.Size([16, 25])
Input--> Shape: torch.Size([16, 25, 4480]), device:cpu
FFT --> Shape: torch.Size([16, 25, 257, 35])
view --> Shape: torch.Size([400, 1, 257, 35]), all equal: True, max error: abs 0,mean 0
b1 --> Shape: torch.Size([400, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([400, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([400, 128, 6, 1]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([400, 128, 1, 1]), all equal: True, max error: abs 0,mean 0
Test: Pooling (3, 3), AvgPool. Device cpu, max_abs 0,max_mean 0 OK
---------------------------------------
Process finished with exit code 0

It seems the reshaping (ravel/unravel) is one of the causes; it works fine without it.
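
Concretely, the only relevant difference with the previous script is how the input reaches the convolutions; a rough sketch of the two shapes involved (not the exact diff):

import torch

x = torch.rand(16, 25, 257, 35)      # spectrogram output in the first script: [batch, time, freq, frames]

# First script: fold the 25 temporal steps into the batch dimension before the convs
raveled = x.view(-1, 1, 257, 35)     # -> [400, 1, 257, 35]

# This script: one waveform per batch element, so no ravel/unravel is needed
x2 = torch.rand(16, 257, 35)
plain = x2.unsqueeze(1)              # -> [16, 1, 257, 35]

print(raveled.shape, plain.shape)

The full updated script and output: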

import subprocess
import sys

import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram, Spectrogram

N_FFT = 512
N_MELS = 256
HOP_LENGTH = 130
AUDIO_FRAMERATE = 16000


def get_sys_info():
    """Return a string with the Python/PyTorch/CUDA/cuDNN versions and the nvidia-smi device list."""
    result = subprocess.Popen(["nvidia-smi", "--format=csv",
                               "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"],
                              stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    nvidia = result.stdout.readlines().copy()
    nvidia = [str(x) for x in nvidia]
    nvidia = [x[2:-3] + '\r\t' for x in nvidia]
    acum = ''
    for x in nvidia:
        acum = acum + x

    return ('  Python VERSION: {0} \n\t'
            '  pyTorch VERSION: {1} \n\t'
            '  CUDA VERSION: {2}\n\t'
            '  CUDNN VERSION: {3} \n\t'
            '  Number CUDA Devices: {4} \n\t'
            '  Devices: {5}\n\t'
            'Active CUDA Device: GPU {6} \n\t'
            'Available devices {7} \n\t'
            'Current cuda device {8} \n\t'.format(sys.version, torch.__version__, torch.version.cuda,
                                                  torch.backends.cudnn.version(), torch.cuda.device_count(),
                                                  acum, torch.cuda.current_device(), torch.cuda.device_count(),
                                                  torch.cuda.current_device()))


def make_audio_block(filter_in, filters_out, lrn, max_pool=None, padding=0, stride=1, kernel_size=(3, 3)):
    layers = [nn.Conv2d(filter_in, filters_out, kernel_size=kernel_size, padding=padding, stride=stride)]
    layers.append(nn.ReLU(False))
    if max_pool is not None:
        layers.append(nn.MaxPool2d(max_pool))
    return nn.Sequential(*layers)


def reshape(x, unrolled_shape):
    return x.view(*unrolled_shape, *x.shape[1:])


def check(x):
    x_unrolled = x
    ref = x_unrolled[0]
    all_equal = True
    error_abs = [None for _ in range(x.shape[0])]
    error_mean = [None for _ in range(x.shape[0])]
    error_abs[0] = 0
    error_mean[0] = 0
    for i in range(1, x.shape[0]):
        all_equal = all_equal and torch.allclose(ref, x_unrolled[i])
        diff = torch.abs(ref - x_unrolled[i])
        diff = diff[diff > 0]
        error_abs[i] = diff.sum()
        error_mean[i] = diff.mean()
    return all_equal, max(error_abs), max(error_mean)


class AudioEncoder(nn.Module):
    # 'filter_size': [96, 256, 512, 512*6*6]
    def __init__(self, pooling, pooling_type='AvgPool'):
        super(AudioEncoder, self).__init__()
        assert pooling_type in ['MaxPool', 'AvgPool'], f'Pooling of type {pooling_type} should be MaxPool or AvgPool'
        filters = [1, 32, 64, 128]
        # self.preproc = MelSpectrogram(sample_rate=AUDIO_FRAMERATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS)
        self.preproc = Spectrogram(n_fft=N_FFT, hop_length=HOP_LENGTH)
        self.b1 = make_audio_block(filters[0], filters[1], lrn=False, max_pool=pooling, padding=2, kernel_size=(7, 7),
                                   stride=(2, 1))
        self.b2 = make_audio_block(filters[1], filters[2], lrn=False, max_pool=pooling, padding=0, kernel_size=(7, 3))
        self.b3 = make_audio_block(filters[2], filters[3], lrn=False, padding=0, kernel_size=(7, 3))
        if pooling_type == 'MaxPool':
            self.pooling = nn.AdaptiveMaxPool2d((1, None))
        else:
            self.pooling = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, x):
        verbose = True
        if verbose:
            print(f'Input--> Shape: {x.shape}, device:{x.device}')
        x = self.preproc(x).unsqueeze(1)
        if verbose:
            print(f'FFT --> Shape: {x.shape}')
        x = self.b1(x)
        if verbose:
            equal, abs_max, mean_max = check(x)
            print(f'b1 --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = self.b2(x)
        if verbose:
            equal, abs_max, mean_max = check(x)
            print(f'b2 --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = self.b3(x)
        if verbose:
            equal, abs_max, mean_max = check(x)
            print(f'b3 --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = self.pooling(x)
        if verbose:
            equal, abs_max, mean_max = check(x)
            print(f'pooling --> Shape: {x.shape}, all equal: {equal}, max error: abs {abs_max},mean {mean_max}')
        x = x.squeeze()
        return x, x.shape


BATCH_SIZE = 16


@torch.no_grad()
def run_test(pooling, pooling_type, device):
    device = torch.device(device)
    model = AudioEncoder(pooling, pooling_type).to(device)
    inp_element = torch.rand(4480).to(device)
    inp = torch.stack([inp_element.clone() for _ in range(BATCH_SIZE)])

    # print(model)

    y, shape = model(inp)
    is_identical, max_abs, max_mean = check(y)

    if is_identical:
        print(f"Test: Pooling {pooling}, {pooling_type}. Device {device}, max_abs {max_abs},max_mean {max_mean} OK")
    else:
        print(
            f"Test: Pooling {pooling}, {pooling_type}. Device {device}, max_abs {max_abs},max_mean {max_mean}. Failed")
    print('---------------------------------------')


pooling_tests = [None, (3, 3)]
pooling_types = ['AvgPool']
devices = ['cuda:0', 'cuda:1', 'cpu']

if __name__ == '__main__':

    print(get_sys_info())
    for device in devices:
        for pooling_i in pooling_tests:
            for pooling_type_i in pooling_types:
                run_test(pooling_i, pooling_type_i, device)
  Python VERSION: 3.6.9 (default, Oct  8 2020, 12:12:24) 
[GCC 8.4.0] 
	  pyTorch VERSION: 1.7.0+cu110 
	  CUDA VERSION: 11.0
	  CUDNN VERSION: 8004 
	  Number CUDA Devices: 3 
	
	Active CUDA Device: GPU 0 
	Available devices 3 
	Current cuda device 0 
	
Input--> Shape: torch.Size([16, 4480]), device:cuda:0
/home/jfm/.local/lib/python3.6/site-packages/torch/functional.py:516: UserWarning: stft will require the return_complex parameter be explicitly  specified in a future PyTorch release. Use return_complex=False  to preserve the current behavior or return_complex=True to return  a complex output. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:653.)
  normalized, onesided, return_complex)
/home/jfm/.local/lib/python3.6/site-packages/torch/functional.py:516: UserWarning: The function torch.rfft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.rfft. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:590.)
  normalized, onesided, return_complex)
FFT --> Shape: torch.Size([16, 1, 257, 35])
b1 --> Shape: torch.Size([16, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([16, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([16, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([16, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cuda:0, max_abs 0,max_mean 0 OK
---------------------------------------
Input--> Shape: torch.Size([16, 4480]), device:cuda:0
FFT --> Shape: torch.Size([16, 1, 257, 35])
b1 --> Shape: torch.Size([16, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([16, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([16, 128, 6, 1]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([16, 128, 1, 1]), all equal: True, max error: abs 0,mean 0
Test: Pooling (3, 3), AvgPool. Device cuda:0, max_abs 0,max_mean 0 OK
---------------------------------------
Input--> Shape: torch.Size([16, 4480]), device:cuda:1
FFT --> Shape: torch.Size([16, 1, 257, 35])
b1 --> Shape: torch.Size([16, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([16, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([16, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([16, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cuda:1, max_abs 0,max_mean 0 OK
---------------------------------------
Input--> Shape: torch.Size([16, 4480]), device:cuda:1
FFT --> Shape: torch.Size([16, 1, 257, 35])
b1 --> Shape: torch.Size([16, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([16, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([16, 128, 6, 1]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([16, 128, 1, 1]), all equal: True, max error: abs 0,mean 0
Test: Pooling (3, 3), AvgPool. Device cuda:1, max_abs 0,max_mean 0 OK
---------------------------------------
Input--> Shape: torch.Size([16, 4480]), device:cpu
FFT --> Shape: torch.Size([16, 1, 257, 35])
b1 --> Shape: torch.Size([16, 32, 128, 33]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([16, 64, 122, 31]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([16, 128, 116, 29]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([16, 128, 1, 29]), all equal: True, max error: abs 0,mean 0
Test: Pooling None, AvgPool. Device cpu, max_abs 0,max_mean 0 OK
---------------------------------------
Input--> Shape: torch.Size([16, 4480]), device:cpu
FFT --> Shape: torch.Size([16, 1, 257, 35])
b1 --> Shape: torch.Size([16, 32, 42, 11]), all equal: True, max error: abs 0,mean 0
b2 --> Shape: torch.Size([16, 64, 12, 3]), all equal: True, max error: abs 0,mean 0
b3 --> Shape: torch.Size([16, 128, 6, 1]), all equal: True, max error: abs 0,mean 0
pooling --> Shape: torch.Size([16, 128, 1, 1]), all equal: True, max error: abs 0,mean 0
Test: Pooling (3, 3), AvgPool. Device cpu, max_abs 0,max_mean 0 OK
---------------------------------------

Process finished with exit code 0

Given that the same machine processes all the elements, I would expect the output to be identical, as it is on the CPU.

Don't the CPU tests above actually give the same result? Only the CUDA version shows differences, right?
Also, you can try setting OMP_NUM_THREADS=1 (or the corresponding MKL variable if you use that) to avoid non-determinism on the CPU.
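
For example, at the very top of the script (a sketch; torch.set_num_threads(1) is an alternative way to limit intra-op threads):

import os

# These must be set before torch is imported so that OpenMP/MKL pick them up.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import torch

torch.set_num_threads(1)   # also limits PyTorch's own intra-op parallelism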

It seems the reshaping (ravel/unravel) is one of the causes; it works fine without it.

What exactly did you change in the code?

Okay, let me present it again.
This is the simplest case I could find.
The idea is to ravel/unravel temporal 5-dimensional tensors so that the temporal dimension is processed in the batch dimension.
It is basically the following sequence of ops:
view --> conv2d --> view
Only CUDA shows differences.
PyTorch 1.6 (CUDA 10.2) is fine; PyTorch 1.7 (CUDA 11) is not.
So it seems that setting

torch.backends.cudnn.deterministic = True
torch.set_deterministic(True)

solves the issue. However, this wasn't necessary in PyTorch 1.6; it only is in PyTorch 1.7.
Do you know what changed?
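
For reference, the full deterministic setup under PyTorch 1.7, as I understand the docs, would look roughly like this (the two lines above were already enough in my test; the CUBLAS_WORKSPACE_CONFIG variable is what the docs require for some CUDA ops, probably not this conv):

import os

# Required by torch.set_deterministic on CUDA >= 10.2 for some ops (per the 1.7 docs).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

torch.backends.cudnn.benchmark = False      # don't auto-tune algorithms per input shape
torch.backends.cudnn.deterministic = True   # only use deterministic cuDNN kernels
torch.set_deterministic(True)               # error out on known non-deterministic ops (1.7 API)

The toy script and the reports for 1.6 and 1.7: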

import subprocess
import sys

import torch
from torch import nn


def get_sys_info():
    """Return a string with the Python/PyTorch/CUDA/cuDNN versions and the nvidia-smi device list."""
    result = subprocess.Popen(["nvidia-smi", "--format=csv",
                               "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"],
                              stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    nvidia = result.stdout.readlines().copy()
    nvidia = [str(x) for x in nvidia]
    nvidia = [x[2:-3] + '\r\t' for x in nvidia]
    acum = ''
    for x in nvidia:
        acum = acum + x

    return ('  Python VERSION: {0} \n\t'
            '  pyTorch VERSION: {1} \n\t'
            '  CUDA VERSION: {2}\n\t'
            '  CUDNN VERSION: {3} \n\t'
            '  Number CUDA Devices: {4} \n\t'
            '  Devices: {5}\n\t'
            'Active CUDA Device: GPU {6} \n\t'
            'Available devices {7} \n\t'
            'Current cuda device {8} \n\t'.format(sys.version, torch.__version__, torch.version.cuda,
                                                  torch.backends.cudnn.version(), torch.cuda.device_count(),
                                                  acum, torch.cuda.current_device(), torch.cuda.device_count(),
                                                  torch.cuda.current_device()))




def reshape(x, unrolled_shape):
    return x.view(*unrolled_shape, *x.shape[1:])


def check(x_unrolled, unrolled_shape):
    ref = x_unrolled[0]
    all_equal = True
    error_abs = [None for _ in range(unrolled_shape[0])]
    error_mean = [None for _ in range(unrolled_shape[0])]
    error_abs[0] = 0
    error_mean[0] = 0
    for i in range(1, unrolled_shape[0]):
        all_equal = all_equal and torch.allclose(ref, x_unrolled[i])
        diff = torch.abs(ref - x_unrolled[i])
        diff = diff[diff > 0]
        error_abs[i] = diff.sum()
        error_mean[i] = diff.mean()
    return all_equal, max(error_abs), max(error_mean)


class Toy(nn.Module):
    def __init__(self):
        super(Toy, self).__init__()
        self.conv = nn.Conv2d(1, 128, padding=0, kernel_size=(7, 3))

    def forward(self, x):
        unraveled_shape = x.shape[:2]
        print(f'Input--> Shape: {x.shape}, device:{x.device}')

        x = x.view(-1, 1, 12, 3).contiguous()
        undo = x.view(*unraveled_shape, 12, 3).contiguous()
        print(f'Raveled View OP --> Shape: {x.shape}')
        print(f'Unraveled View OP --> Shape: {undo.shape}')

        equal, abs_max, mean_max = check(undo, unraveled_shape)
        print(f'View Results: Are equal: {equal}, Max. Abs. Diff: {abs_max}, Mean Abs. Diff: {mean_max}')

        x = self.conv(x).contiguous()
        undo = x.view(*unraveled_shape, 128, 6, 1).contiguous()
        print(f'Raveled Conv2D OP --> Shape: {x.shape}')
        print(f'UnRaveled Conv2D OP --> Shape: {undo.shape}')

        equal, abs_max, mean_max = check(undo, unraveled_shape)
        print(f'View Results: Are equal: {equal}, Max. Abs. Diff: {abs_max}, Mean Abs. Diff: {mean_max}')
        return x


BATCH_SIZE = 16


@torch.no_grad()
def run_test(device):
    device = torch.device(device)
    model = Toy().to(device)
    inp_element = torch.rand(25, 12, 3).to(device)
    inp = torch.stack([inp_element.clone() for _ in range(BATCH_SIZE)])

    # print(model)

    y = model(inp)
    print('---------------------------------------')


devices = ['cuda:0', 'cuda:1', 'cpu']

if __name__ == '__main__':

    print(get_sys_info())
    for device in devices:
        run_test(device)

REPORT FOR PYTORCH 1.6

 Python VERSION: 3.6.9 (default, Oct  8 2020, 12:12:24)
[GCC 8.4.0]
	  pyTorch VERSION: 1.6.0
	  CUDA VERSION: 10.2
	  CUDNN VERSION: 7605
	  Number CUDA Devices: 3
	
	Active CUDA Device: GPU 0
	Available devices 3
	Current cuda device 0
	
Input--> Shape: torch.Size([16, 25, 12, 3]), device:cuda:0
Raveled View OP --> Shape: torch.Size([400, 1, 12, 3])
Unraveled View OP --> Shape: torch.Size([16, 25, 12, 3])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
Raveled Conv2D OP --> Shape: torch.Size([400, 128, 6, 1])
UnRaveled Conv2D OP --> Shape: torch.Size([16, 25, 128, 6, 1])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
---------------------------------------
Input--> Shape: torch.Size([16, 25, 12, 3]), device:cuda:1
Raveled View OP --> Shape: torch.Size([400, 1, 12, 3])
Unraveled View OP --> Shape: torch.Size([16, 25, 12, 3])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
Raveled Conv2D OP --> Shape: torch.Size([400, 128, 6, 1])
UnRaveled Conv2D OP --> Shape: torch.Size([16, 25, 128, 6, 1])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
---------------------------------------
Input--> Shape: torch.Size([16, 25, 12, 3]), device:cpu
Raveled View OP --> Shape: torch.Size([400, 1, 12, 3])
Unraveled View OP --> Shape: torch.Size([16, 25, 12, 3])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
Raveled Conv2D OP --> Shape: torch.Size([400, 128, 6, 1])
UnRaveled Conv2D OP --> Shape: torch.Size([16, 25, 128, 6, 1])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
---------------------------------------

REPORT FOR PYTORCH 1.7

[GCC 8.4.0] 
	  pyTorch VERSION: 1.7.0+cu110 
	  CUDA VERSION: 11.0
	  CUDNN VERSION: 8004 
	  Number CUDA Devices: 3 
	
	Active CUDA Device: GPU 0 
	Available devices 3 
	Current cuda device 0 
	
Input--> Shape: torch.Size([16, 25, 12, 3]), device:cuda:0
Raveled View OP --> Shape: torch.Size([400, 1, 12, 3])
Unraveled View OP --> Shape: torch.Size([16, 25, 12, 3])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
Raveled Conv2D OP --> Shape: torch.Size([400, 128, 6, 1])
UnRaveled Conv2D OP --> Shape: torch.Size([16, 25, 128, 6, 1])
View Results: Are equal: False, Max. Abs. Diff: 0.00018408219330012798, Mean Abs. Diff: 2.6632262617454217e-08
---------------------------------------
Input--> Shape: torch.Size([16, 25, 12, 3]), device:cuda:1
Raveled View OP --> Shape: torch.Size([400, 1, 12, 3])
Unraveled View OP --> Shape: torch.Size([16, 25, 12, 3])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
Raveled Conv2D OP --> Shape: torch.Size([400, 128, 6, 1])
UnRaveled Conv2D OP --> Shape: torch.Size([16, 25, 128, 6, 1])
View Results: Are equal: False, Max. Abs. Diff: 0.0001880067866295576, Mean Abs. Diff: 2.7989695894348188e-08
---------------------------------------
Input--> Shape: torch.Size([16, 25, 12, 3]), device:cpu
Raveled View OP --> Shape: torch.Size([400, 1, 12, 3])
Unraveled View OP --> Shape: torch.Size([16, 25, 12, 3])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
Raveled Conv2D OP --> Shape: torch.Size([400, 128, 6, 1])
UnRaveled Conv2D OP --> Shape: torch.Size([16, 25, 128, 6, 1])
View Results: Are equal: True, Max. Abs. Diff: 0, Mean Abs. Diff: 0
---------------------------------------

Process finished with exit code 0

It would be very helpful if you could get the cuDNN versions to match.
The fact that setting torch.backends.cudnn.deterministic = True fixes it hints that something changed on the cuDNN side.
In particular, the old version may have been picking a deterministic algorithm by chance, which is no longer the case for the new version?
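
One way to check that hypothesis would be to rerun the toy conv with cuDNN disabled: if the batch elements then match exactly, the difference comes from the algorithm cuDNN 8 selects. A rough sketch:

import torch
from torch import nn

print(torch.backends.cudnn.version())    # confirm which cuDNN build each environment uses

torch.backends.cudnn.enabled = False     # fall back to the native (non-cuDNN) conv kernels

conv = nn.Conv2d(1, 128, kernel_size=(7, 3)).cuda()
x = torch.rand(1, 1, 12, 3, device='cuda').repeat(400, 1, 1, 1)
with torch.no_grad():
    y = conv(x)
print(all(torch.equal(y[0], y[i]) for i in range(1, y.shape[0])))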