What exactly does torch.backends.cudnn.deterministic = True do?

I was doing an unsupervised experiment where some parts of the conv weights are masked so that each output position sees only specific input pixels. While debugging the model I found that the output positions were depending on input pixels that they shouldn't (found by backpropagating gradients). I made sure all the masking is correct. Finally, after hours of debugging, once torch.backends.cudnn.deterministic = True was set, all the output pixels were looking only at the correct input positions. How is this possible? Does cuDNN use optimizations that may alter the gradients during backprop?
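For context, these are the determinism-related switches I am aware of (a sketch; in the repro below I only toggle the first one, and torch.use_deterministic_algorithms only exists on newer PyTorch releases):

torch.backends.cudnn.deterministic = True   # restrict cuDNN to deterministic algorithms
torch.backends.cudnn.benchmark = False      # disable the auto-tuner that may pick a different algorithm per input shape
# newer releases also have a global switch that raises an error on
# non-deterministic ops instead of silently allowing them:
# torch.use_deterministic_algorithms(True)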

Could you please post a code snippet to reproduce this issue?

import torch
import torch.nn as nn
import torch.nn.functional as F

use_cuda = True
torch.backends.cudnn.deterministic = True

torch.manual_seed(0)
hidden_dims = 32
num_layers = 12
num_classes = 4
num_val = 1000
batch_size = 128
device = "cuda" if use_cuda else "cpu"

def calc_pad(image_size, kernel_size, stride=1):
    # "same" padding: for stride=1 this reduces to (kernel_size - 1) // 2
    pad_size = ((stride * (image_size - 1)) + kernel_size - image_size) // 2
    return (pad_size, pad_size)

class masked_CNN_typeA(nn.Conv2d):
    """Conv2d whose weights are multiplied by a causal mask on every forward pass.
    Mask type "A" also hides the centre pixel; type "B" keeps it."""
    def __init__(self, mask_type, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, tensor_device="cpu"):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups, bias=bias)
        assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1, "provide odd values for kernel h and w"
        self.mask_type = mask_type.lower()
        self.tensor_device = tensor_device
        self.register_buffer("mask", self.get_mask())

    def get_mask(self):
        k_h, k_w = self.kernel_size
        centre_h = k_h // 2
        centre_w = k_w // 2
        mask = torch.ones((self.out_channels, self.in_channels, k_h, k_w), dtype=torch.float32,
                          device=self.tensor_device)
        # zero out everything after the centre position (raster order)
        mask[:, :, centre_h + 1:, :] = 0
        mask[:, :, centre_h, centre_w + 1:] = 0
        if self.mask_type == "a":
            mask[:, :, centre_h, centre_w] = 0
        return mask

    def forward(self, inp):
        # apply the mask to the weights on every call
        weight = self.mask * self.weight
        return F.conv2d(inp, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

class Pixel_CNN_typeA(nn.Module):
    def __init__(self, hidden_dims, num_classes, num_layers, input_shape=(3,28,28)):
        super().__init__()
        self.hidden_dims = hidden_dims
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.input_shape = input_shape
        self.net = nn.ModuleList(self.build_network())
        
    def build_network(self):
        layers = []
        # 7x7 type-A masked conv (hides the centre pixel) followed by a 1x1 expansion
        initial_conv = nn.Sequential(masked_CNN_typeA("A", self.input_shape[0], self.input_shape[0], kernel_size=(7, 7),
                                                      padding=calc_pad(28, 7), bias=True, tensor_device=device),
                                     nn.ReLU(),
                                     nn.Conv2d(self.input_shape[0], self.hidden_dims*2, kernel_size=(1, 1),
                                               bias=True),
                                     nn.ReLU()
                                     )
        layers.append(initial_conv)
        # residual blocks: 1x1 reduce -> 3x3 type-B masked conv -> 1x1 expand
        for i in range(self.num_layers):
            layers.append(nn.Sequential(nn.Conv2d(self.hidden_dims*2, self.hidden_dims, kernel_size=(1, 1),
                                                  bias=True),
                                        nn.ReLU(),
                                        masked_CNN_typeA("B", self.hidden_dims, self.hidden_dims, kernel_size=(3, 3),
                                                         padding=calc_pad(28, 3), bias=True),
                                        nn.ReLU(),
                                        nn.Conv2d(self.hidden_dims, self.hidden_dims*2, kernel_size=(1, 1),
                                                  bias=True),
                                        nn.ReLU()
                                        ))
        # 1x1 head producing num_classes logits per input channel
        final_conv = nn.Sequential(nn.Conv2d(self.hidden_dims*2, self.hidden_dims, kernel_size=(1, 1),
                                             bias=True),
                                   nn.Conv2d(self.hidden_dims, self.input_shape[0]*self.num_classes, kernel_size=(1, 1),
                                             bias=True)
                                   )
        layers.append(final_conv)
        return layers
    
    def forward(self, inp, mode="train"):
        out = self.net[0](inp)
        residual = out
        # residual connections across the masked blocks
        # (note: the final 1x1 head in self.net[-1] is not applied here)
        for i in range(self.num_layers):
            out = self.net[i + 1](out)
            out = out + residual
            residual = out
        return out

model_1 = Pixel_CNN_typeA(hidden_dims, num_classes, num_layers)
if use_cuda:
    model_1 = model_1.cuda()

# Probe the receptive field: backprop from output position (7, 7) only and
# inspect which input pixels in row 7 receive gradient.
inw = torch.ones(1, 3, 28, 28, dtype=torch.float32, device=device, requires_grad=True)
temp_prob = model_1(inw)
temp_prob_slice = temp_prob[:, :, 7, 7]
temp_prob_slice.backward(torch.full_like(temp_prob_slice, 10.0))
print(inw.grad[:, 0, 7, :])

So I backpropagate only through the output slice (:, :, 7, 7), which should only send gradient to input pixels that come before position (7, 7). When torch.backends.cudnn.deterministic = True is set the model behaves correctly, but when it is not set the output also depends on future pixels.
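To compare the two settings side by side I use roughly the helper below (a sketch built around model_1 above; I set the flag right before each probe because I am not sure whether the algorithm choice gets cached, and I weight the backward pass with ones instead of tens, which does not change which positions receive gradient):

def future_grad(model, deterministic):
    # Backprop from output position (7, 7) only and return the gradient of the
    # input pixels to the right of column 7 in row 7; these should all be zero.
    torch.backends.cudnn.deterministic = deterministic
    x = torch.ones(1, 3, 28, 28, device="cuda", requires_grad=True)
    model(x)[:, :, 7, 7].sum().backward()
    return x.grad[0, 0, 7, 8:]

print(future_grad(model_1, True).abs().max())   # exactly 0 for me
print(future_grad(model_1, False).abs().max())  # tiny non-zero values for me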
Results with torch.backends.cudnn.deterministic = True:
tensor([[ 1.2851e-05, 2.1000e-04, 5.1798e-03, 1.3240e-01, 1.4257e+00,
-1.1997e+00, 2.2698e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00]], device='cuda:0')

Results without setting it:
tensor([[ 1.2851e-05, 2.1000e-04, 5.1798e-03, 1.3240e-01, 1.4257e+00,
-1.1997e+00, 2.2698e-02, -3.8616e-10, 8.7847e-12, -1.1786e-19,
3.3309e-21, 1.3479e-29, -1.7327e-31, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00]], device='cuda:0')

The extra values that show up are tiny, but is cuDNN supposed to do this?
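If it helps to judge the magnitudes, this is the kind of tolerance check I could fall back to instead of expecting exact zeros (leaks_into_future and tol are my own hypothetical helper and threshold, not anything from cuDNN):

def leaks_into_future(grad_row, col, tol=1e-6):
    # treat anything below tol as numerical noise rather than a real
    # dependency on future pixels
    return bool((grad_row[col + 1:].abs() > tol).any())

print(leaks_into_future(inw.grad[0, 0, 7, :], col=7))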