Inconsistency During Inference

I have trained several models and would like to compare their performance on a single image. The problem is that running the same model (let’s call it A) repeatedly produces different results. I have set the model to evaluation mode (i.e. A = myModel.eval()) and I am running inference inside “with torch.no_grad()”, yet every time I apply the model to my test image I get a different result. (I have several models, but here I am demonstrating the problem with only one of them!)


A = myModel.eval()

modelList = [A] # the list of models contains more models but I am just showing one here!

# looping over test images (here only one image)
for img in train_loader:
    with torch.no_grad():
        for models in modelList:
            models.to(device)     # nn.Module.to() moves the module in place
            img = img.to(device)  # Tensor.to() returns a new tensor, so reassign it
            modelResult = models(img)
            print(modelResult.mean())

The print statement outputs a different value every time I run the loop, but I expect the same result every time, since the model weights and the image are the same!

I’d appreciate your help,
Thank you

Could you check the transformations applied to the images in the Dataset and make sure no random transforms are used?
If that’s already the case, could you check the reproducibility docs and e.g. make sure that cudnn picks deterministic algorithms, in case you are seeing relative errors of ~1e-6.
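A minimal sketch of the usual reproducibility settings, assuming a PyTorch release recent enough to provide torch.use_deterministic_algorithms (older releases exposed a similar switch as torch.set_deterministic). The helper name make_deterministic is hypothetical. On CUDA, some cuBLAS ops additionally require the CUBLAS_WORKSPACE_CONFIG environment variable (e.g. :4096:8) to be set before they run deterministically.

```python
import random

import numpy as np
import torch
import torch.nn as nn


def make_deterministic(seed: int = 0) -> None:
    """Seed the common RNGs and ask PyTorch/cuDNN for deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                    # also seeds all CUDA devices
    torch.backends.cudnn.benchmark = False     # no auto-tuning that may pick non-deterministic algos
    torch.backends.cudnn.deterministic = True
    # Raises an error if an op without a deterministic implementation is used.
    torch.use_deterministic_algorithms(True)


make_deterministic(0)
conv = nn.Conv2d(2, 4, kernel_size=3, padding=1).eval()
x = torch.randn(1, 2, 58, 72)                  # same shape as the input in this thread
with torch.no_grad():
    out1 = conv(x)
    out2 = conv(x)
print(torch.equal(out1, out2))
```

Note that these flags only remove tiny run-to-run differences caused by kernel selection; they cannot hide randomness that the model code itself introduces.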

Regarding your first point, I am using the following line of code to create my DataLoader, and I am not using any transforms.

train_loader = torch.utils.data.DataLoader(outList, batch_size=1, shuffle=False, num_workers=0)

Please note that there is ONLY ONE image inside outList. As for the reproducibility docs, I do not think they apply here, because my errors are much bigger than the ~1e-6 relative error you mentioned!

To give you some perspective (and to make sure my question is clear): the for loop above runs ONLY ONCE and then exits, since there is only ONE test image in the DataLoader. Moreover, the trained/loaded weight, bias, and batch values stay constant across runs (as expected). It is just the output of the model that contains different pixel values after each run. I am running the code snippet above inside a Jupyter notebook, so every time I hit Shift+Enter to run the cell, it generates a different set of results for the same single input image! Below you can see a small slice of my output after feeding the same image to my trained model three times (i.e. just hitting Shift+Enter three times to run the loop above):

First run:

tensor([[ 0.6058,  0.2320,  0.0092],
        [ 0.1484, -0.2899, -0.5605],
        [ 0.0013, -0.4538, -0.7382]], device='cuda:0')

Second run:

tensor([[-0.1410, -0.4058, -0.3253],
        [-0.6349, -0.8924, -0.7446],
        [-0.9760, -1.2406, -1.0597]], device='cuda:0')

Third run:

tensor([[ 0.3299,  0.0792,  0.0058],
        [-0.1715, -0.3860, -0.4168],
        [-0.5801, -0.7737, -0.7866]], device='cuda:0')

So the differences between runs are significantly larger than ~1e-6. Another equally concerning issue is that values change sign between runs! Please let me know if I need to provide more details.

I forgot to mention that I am training and testing my model on two different machines. Following the reproducibility docs, I used both “torch.set_deterministic(d=True)” and “torch.backends.cudnn.benchmark = False” to force PyTorch to pick deterministic algorithms, but it does not seem to make any difference! To be honest, I am not sure whether I used these commands correctly, as I just put them into my code right before the for loop!

Assuming you have made sure no random transformations or other operations are applied, could you post the model definition as well as the input shapes, which would reproduce the different outputs?
Since the differences are large, they are not caused by non-deterministic algorithms; it seems some random operations are still active in your script, which we would need to debug.
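One hedged way to localize such randomness is to record every registered submodule's output with forward hooks, run the same input twice, and list the modules whose outputs differ. The helper names record_outputs and find_divergence, and the toy Buggy model, are hypothetical. Layers created inside forward() are not registered submodules, so hooks never see them; the first flagged module (or the final output) points at the region where they hide.

```python
import torch
import torch.nn as nn


def record_outputs(model: nn.Module, x: torch.Tensor) -> dict:
    """Run one forward pass and store each submodule's tensor output by name."""
    outputs, handles = {}, []
    for name, module in model.named_modules():
        if name == "":  # skip the root module; its output is stored below
            continue

        def hook(mod, inp, out, name=name):
            # Non-tensor outputs (e.g. tuples) are skipped in this sketch.
            if isinstance(out, torch.Tensor):
                outputs[name] = out.detach().clone()

        handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            outputs["<final>"] = model(x).detach().clone()
    finally:
        for h in handles:
            h.remove()
    return outputs


def find_divergence(model: nn.Module, x: torch.Tensor, atol: float = 1e-5) -> list:
    """Names (in execution order) of registered submodules whose outputs
    differ between two forward passes on the same input."""
    model.eval()
    first, second = record_outputs(model, x), record_outputs(model, x)
    return [name for name in first
            if name in second and not torch.allclose(first[name], second[name], atol=atol)]


class Buggy(nn.Module):
    """Toy model that builds a layer inside forward()."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return nn.Linear(4, 4)(x)  # new random weights on every call


x = torch.randn(2, 4)
print(find_divergence(Buggy(), x))          # ['<final>']
print(find_divergence(nn.Linear(4, 4), x))  # []
```

Here the registered fc layer is stable, and only the final output differs, which points at code that runs after fc inside forward().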

I am sure there are no random transformations involved, since I am not applying any transformations to my input data (i.e. img). It is a simple numpy.ndarray of shape [2, 58, 72] that is passed to torch.utils.data.DataLoader without any transformations. Also, printing out a small slice of img in the loop above always yields the same pixel values, no matter how many times I run the loop/cell (as expected). I am not sure if it is important, but the dataset used for training the model is a set of numpy.ndarrays; I did not convert them to PyTorch tensors!

I agree that this has to be the case, and I would like to draw your attention to the following: printing out the autograd status of the model's layers via “.named_parameters()” and then “.requires_grad” shows “requires_grad=True”. Given that the whole block is inside “with torch.no_grad():”, I expected to get “requires_grad=False”! Is this expected behaviour? Could it be that gradient tracking is still active?

The input used for training has the shape of [2, 58, 72] where 2, 58 and 72 are number of channels, height, and width respectively.


import numpy
import torch
import torch.nn as nn
import torch.optim
import torchvision


class Input(nn.Module):
    def __init__(self, in_ch = [2, 64, 64], out_ch = [64, 64, 64], kernel = [(4, 5), (3, 5), (3 ,5)], \
                         pad = [(4 ,0), (1, 0), (1, 0)], stride = [(2, 1), (1, 1), (1, 1)]):
        
        self.input_ch   = in_ch
        super().__init__()

        # Adjusting the dimensions of the low resolution input!
        self.convList   = nn.ModuleList([nn.Conv2d(in_ch[i], out_ch[i], kernel_size =  kernel[i], \
                            stride = stride[i], padding = pad[i]) for i in range(len(kernel))])
        self.batchList  = nn.ModuleList([nn.BatchNorm2d(out_ch[j]) for j in range(len(in_ch))])
        # Independent Section!
        self.active      = nn.LeakyReLU()
    def forward(self, lowOut):
        ### The following "bicubic" interpolation will be used at the very end!!!
        bicubicInterpol  = nn.functional.interpolate(lowOut, size = (144, 180), mode = "bicubic")
        ###
        for count in range(len(self.input_ch)):
            lowOut = self.convList[count](lowOut.cuda("cuda:0"))
            lowOut = self.batchList[count](lowOut.cuda("cuda:0"))
            lowOut = self.active(lowOut)
        
        return lowOut, bicubicInterpol

class Upscale(nn.Module):
    """This function upsamples an input tensor to the spatial dimensions of the target tensor using a 
        combination of bilinear interpolation and convolution."""
    def __init__(self, targetFeature = (64, 144, 180)):
        super().__init__()
        
        self.firstUpsample = nn.Conv2d(64, 64, 3, padding = 1)
        self.targetFeature = targetFeature
    def forward(self, inputT):
        inputT_ch = inputT.shape[1]
        inputT_H  = inputT.shape[2]
        inputT_W  = inputT.shape[3]
        targetT_ch= self.targetFeature[0]
        targetT_H = self.targetFeature[1]
        targetT_W = self.targetFeature[2]
        firstInterpol  = nn.functional.interpolate(inputT, size = (targetT_H, targetT_W), mode = "bilinear")
        upsampleOne    = self.firstUpsample(firstInterpol)
        return upsampleOne


class Residual(nn.Module):
    def __init__(self, sameChannel = 128): # default initialization to 128 (from the first layer of the Encoding)
        super().__init__()
        self.resConv = nn.Conv2d(sameChannel, sameChannel, 3, padding = 1).cuda('cuda:0')
    def forward(self, inputT):
        inputT_ch = inputT.shape[1]
        inputT_H  = inputT.shape[2]
        inputT_W  = inputT.shape[3]

        out       = self.resConv(inputT)
        out       = self.resConv(out)
        out       = self.resConv(out)

        addition = out + inputT

        return addition


class Encoding(nn.Module):
    def __init__(self, channel = [64, 128, 192, 256, 320, 384, 448], stride = [(2, 1), (1, 3), \
                                (2, 1), (3, 3), (2, 2), (2, 2)]):
        super().__init__()
        self.channel    = channel
        self.enConvList = nn.ModuleList([nn.Conv2d(channel[i], channel[i + 1], 3, stride = stride[i], \
                          padding = 1) for i in range(len(channel) - 1)])
        self.batchList  = nn.ModuleList([nn.BatchNorm2d(channel[j  + 1]) for j in range(len(channel) - 1)])
        self.active     = nn.LeakyReLU()
        
    def forward(self, out):
        concatList = []
        for i in range(len(self.channel) - 1):
            concatList.append(out)
            ### strided convolution block
            out      = self.enConvList[i](out)
            ### batch normalization block
            out      = self.batchList[i](out)
            ### activation block
            out      = self.active(out)
            ### residual block
            residual = Residual(out.shape[1]) # 1 to get the channel number
            out      = residual(out)
        return out, concatList


class Decoding(nn.Module):
    def __init__(self, upChannel_in  = [448, 384, 320, 256, 192, 128],\
                 upChannel_out       = [384, 320, 256, 192, 128, 64] ,\
                 deconvChannel2_in   = [768, 640, 512, 384, 256, 128],\
                 deconvChannel2_out  = [384, 320, 256, 192, 128, 64] ,\
                 upKernel  = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 5), (3, 3)], \
                 upPadding = [1, 1, 1, 1, (1, 2), 1], finalIn = 64, finalOut = 2):
        
        super().__init__()
        self.residual   = Residual()
        self.channelLength = upChannel_in
        self.deConvList = nn.ModuleList([nn.Conv2d(deconvChannel2_in[i], deconvChannel2_out[i], 3, padding = 1) \
                                         for i in range(len(upChannel_in))])
        self.batch      = nn.ModuleList([nn.BatchNorm2d(deconvChannel2_out[k]) for k in range(len(deconvChannel2_out))])
        self.active     = nn.LeakyReLU()
    
        ### Upsampling block preparation ###
        self.upConv     = nn.ModuleList([nn.Conv2d(upChannel_in[i], upChannel_out[i], kernel_size = upKernel[i], \
                            padding = upPadding[i], padding_mode = "replicate") for i in range(len(upChannel_out))])
        self.upBatch    = nn.ModuleList([nn.BatchNorm2d(upChannel_out[j]) for j in range(len(upChannel_out))])
        self.upActive   = nn.LeakyReLU()
        
        ### Last deconvolution layer ###
        self.lastConv    = nn.Conv2d(finalIn, finalOut, 3, padding = 1, padding_mode = "replicate")
    
    def interpolation(self, inputT, targetT):
        """This function upsamples an input tensor to the spatial dimensions of the target tensor using a 
        combination of bilinear interpolation and convolution."""
        inputT_ch = inputT.shape[1]
        inputT_H  = inputT.shape[2]
        inputT_W  = inputT.shape[3]
        
        targetT_ch= targetT.shape[1]
        targetT_H = targetT.shape[2]
        targetT_W = targetT.shape[3]
             
        interpol  = nn.functional.interpolate(inputT, size = (targetT_H, targetT_W), mode = "bilinear")
        return interpol
    
    def forward(self, out, enList):
        enList     = enList[: : -1] # reversing the order of elements inside the list.
        
        for l in range(len(self.channelLength)):
            
            ### Upscaling = interpolation + convolution ###
            out    = self.interpolation(out, enList[l])
            out    = self.upConv[l](out)
            out    = self.upBatch[l](out)
            out    = self.upActive(out)
            ### Concatenation ###
            out    = torch.cat([out, enList[l]], 1) # concatenation on channel number!
            ### Decoding ###
            out    = self.deConvList[l](out)
            out    = self.batch[l](out)
            out    = self.active(out)
            ### Residual block
            residual = Residual(out.shape[1]) # 1 to get the channel number
            out      = residual(out)
            
        ### This is the last convolution layer!
        out    = self.lastConv(out)
        return out


class DeepTest(nn.Module):
    def __init__(self):
        super().__init__()
        self.input       = Input()
        self.upscale     = Upscale()
        self.encoder     = Encoding()
        self.decoder     = Decoding()
        

    def forward(self, lowRes):
        out, bicub  = self.input(lowRes)
        out         = self.upscale(out)
        out, enlist = self.encoder(out)
        out         = self.decoder(out, enlist)
        finalOut    = out + bicub
        return finalOut

And I use the following code snippet to initiate the training process:


import datetime
from torch import optim

device    = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

torch.cuda.empty_cache()

n_epochs  = 100
loss_fn   = torch.nn.MSELoss()
model     = DeepTest().to(device)
lrate     = 1e-3
optimizer = optim.Adam(model.parameters(), lr = lrate, weight_decay = 1e-4)

train_loader = torch.utils.data.DataLoader(outList, batch_size=20, shuffle=True, num_workers=0)
for epoch in range(1, n_epochs + 1):
    loss_train = 0.0
    for img, label in train_loader:
        img    = img.to(device)
        label  = label.to(device)
        output = model(img) # only img is fed to the model; label is the target for the loss
        loss   = loss_fn(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_train += loss.item()

    if epoch == 1 or epoch % 10 == 0:
        print('{} Epoch {}, Training loss {}'.format(\
            datetime.datetime.now(), epoch,\
            loss_train / len(train_loader)))

That’s expected: the with torch.no_grad() block won’t change any attributes of the model parameters; instead, it makes sure that new operations performed inside this block won’t be tracked by autograd. Also, unless you call backward() on the output (or on a loss tensor) and then optimizer.step(), the parameters won’t be changed; a forward pass alone does not update them.
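A small sketch illustrating this with a toy nn.Linear (not the model from this thread): the parameters keep requires_grad=True inside the block, while the tensors produced inside it do not track gradients.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 1)
x = torch.randn(2, 3)

with torch.no_grad():
    out = layer(x)

# Parameters keep their flag: no_grad() does not modify the module.
print(all(p.requires_grad for p in layer.parameters()))  # True
# ...but the result of an op executed inside the block is not tracked.
print(out.requires_grad, out.grad_fn)                     # False None

# Outside the block, the same call records a graph again.
print(layer(x).requires_grad)                             # True
```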

Thanks for the code.
In Encoding you are creating new, randomly initialized, Residual modules inside the forward method, which will cause the difference in the outputs.
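A toy sketch (the BadBlock class is hypothetical, not the thread's Residual class) of why this makes the output change on every call: a layer constructed in forward() gets fresh random weights each time and is never registered as a parameter of the parent module, so the optimizer never sees it either.

```python
import torch
import torch.nn as nn


class BadBlock(nn.Module):
    """Mimics the pattern in Encoding: a layer built inside forward()."""
    def forward(self, x):
        # A fresh, randomly initialized conv is created on every call.
        conv = nn.Conv2d(x.shape[1], x.shape[1], kernel_size=3, padding=1)
        return conv(x) + x


block = BadBlock().eval()
x = torch.randn(1, 8, 16, 16)
with torch.no_grad():
    a, b = block(x), block(x)

print(torch.allclose(a, b))           # False: different weights on each call
print(len(list(block.parameters())))  # 0: nothing is registered for training
```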

First, thank you very much for your help. Second, based on your response, the initialisation of the “Residual” module in the “forward” method of the “Decoding” block is also part of the problem, as I am using it not only in the “Encoding” block but in the “Decoding” block too! Consequently, the weights and biases in the convolution layers of the “Residual” block are NEVER trained properly, because they are randomly initialised every single time the “Residual” block is created in forward. Is my understanding correct?
Finally, I need help fixing the “Residual” block issue! Should I accept your response as the solution and open a new thread with a different title, or can we continue in this thread?

Yes, your understanding is correct. The parameters of the newly created Residual modules won’t be passed to the optimizer and will never be updated. Even if the optimizer (or you, manually) somehow updated them, you would create a new module in the next forward pass, which would again use random parameters.
To fix this issue, initialize these modules in __init__ (as is done with the other layers) and only use them in forward.
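A minimal sketch of that fix, assuming three separate conv/BatchNorm pairs per residual block (reusing a single conv three times, as the original Residual class does, is also possible). This Residual is a hypothetical rewrite, not the poster's code.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Residual block whose layers are created once, in __init__."""
    def __init__(self, channels: int, n_layers: int = 3):
        super().__init__()
        # Registered submodules: visible to .parameters(), .to(device) and state_dict().
        self.convs = nn.ModuleList(
            [nn.Conv2d(channels, channels, 3, padding=1) for _ in range(n_layers)])
        self.bns = nn.ModuleList([nn.BatchNorm2d(channels) for _ in range(n_layers)])
        self.act = nn.LeakyReLU()

    def forward(self, x):
        out = x
        for conv, bn in zip(self.convs, self.bns):
            out = self.act(bn(conv(out)))
        return out + x  # skip connection


# e.g. one block per encoder stage, built in the parent's __init__
stages = nn.ModuleList([Residual(c) for c in (128, 192, 256, 320, 384, 448)])

block = Residual(8).eval()
x = torch.randn(1, 8, 16, 16)
with torch.no_grad():
    print(torch.allclose(block(x), block(x)))      # True
print(sum(p.numel() for p in block.parameters()))  # 1800
```

The count is 3 × (8·8·9 + 8) conv parameters plus 3 × (8 + 8) BatchNorm parameters, all of which the optimizer now receives through model.parameters().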

If the follow-up question is related to this one, let’s keep it in this thread.

Since I am going to ask about the same model, I will continue in this thread. Below you can see how I have fixed the “Residual” layer (I have only provided the parts of my code that I changed). However, I still have doubts about whether all parameters in the “Residual” layers are actually being trained properly; you can see more details below.

This is my fixed code (note that I commented out the “Residual” class completely so I am not showing it here).


class Encoding(nn.Module):
    def __init__(self, channel = [64, 128, 192, 256, 320, 384, 448], \
                 stride = [(2, 1), (1, 3), (2, 1), (3, 3), (2, 2), (2, 2)], \
                 sameChannel_In = [128, 192, 256, 320, 384, 448]):

        super().__init__()
        self.channel    = channel
        self.enConvList = nn.ModuleList([nn.Conv2d(channel[i], channel[i + 1], 3, stride = stride[i], \
                          padding = 1) for i in range(len(channel) - 1)])
        self.batchList  = nn.ModuleList([nn.BatchNorm2d(channel[j  + 1]) for j in range(len(channel) - 1)])
        self.active     = nn.LeakyReLU()
        
        ### Residual layers
        self.residualEncoding = nn.ModuleList([nn.Conv2d(sameChannel_In[i], sameChannel_In[i], 
                                kernel_size = 3, stride = 1, padding = 1) 
                                for i in range(len(sameChannel_In))])
        self.residualBatch    = nn.ModuleList([nn.BatchNorm2d(sameChannel_In[i]) \
                                for i in range(len(sameChannel_In))]) 
    def forward(self, out):
        concatList = []
        for i in range(len(self.channel) - 1):
            concatList.append(out)
            ### strided convolution block
            out      = self.enConvList[i](out)
            ### batch normalization block
            out      = self.batchList[i](out)
            ### activation block
            out      = self.active(out)
            ### Residual layers
            inputToResidual = out
            ## first layer
            out  = self.residualEncoding[i](out)
            out  = self.residualBatch[i](out)
            out  = self.active(out)
            ## second layer
            out  = self.residualEncoding[i](out)
            out  = self.residualBatch[i](out)
            out  = self.active(out)
            ## third layer
            out  = self.residualEncoding[i](out)
            out  = self.residualBatch[i](out)
            out  = self.active(out)
            ## Addition of the input of the residual block to the output of the residual block!
            out  = inputToResidual + out
        return out, concatList


class Decoding(nn.Module):
    def __init__(self, upChannel_in  = [448, 384, 320, 256, 192, 128],\
                 upChannel_out       = [384, 320, 256, 192, 128, 64] ,\
                 deconvChannel2_in   = [768, 640, 512, 384, 256, 128],\
                 deconvChannel2_out  = [384, 320, 256, 192, 128, 64] ,\
                 upKernel  = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 5), (3, 3)], \
                 upPadding = [1, 1, 1, 1, (1, 2), 1], finalIn = 64, finalOut = 2,\
                 sameChannel_Out= [384, 320, 256, 192, 128, 64]):
        
        super().__init__()
        self.channelLength = upChannel_in
        self.deConvList = nn.ModuleList([nn.Conv2d(deconvChannel2_in[i], \
                          deconvChannel2_out[i], 3, padding = 1) \
                          for i in range(len(upChannel_in))])
        self.batch      = nn.ModuleList([nn.BatchNorm2d(deconvChannel2_out[k]) \
                          for k in range(len(deconvChannel2_out))])
        self.active     = nn.LeakyReLU()

        ### Upsampling block preparation ###
        self.upConv     = nn.ModuleList([nn.Conv2d(upChannel_in[i], \
                          upChannel_out[i], kernel_size = upKernel[i], padding = upPadding[i], \
                          padding_mode = "replicate") for i in range(len(upChannel_out))])
        self.upBatch    = nn.ModuleList([nn.BatchNorm2d(upChannel_out[j]) \
                          for j in range(len(upChannel_out))])

        self.upActive   = nn.LeakyReLU()

        ### Last deconvolution layer ###
        self.lastConv    = nn.Conv2d(finalIn, finalOut, 3, padding = 1, padding_mode = "replicate")

        ### Residual layers
        self.residualDecoding = nn.ModuleList([nn.Conv2d(sameChannel_Out[i], \
                                sameChannel_Out[i], kernel_size = 3, stride = 1, padding = 1) \
                                for i in range(len(sameChannel_Out))])
        self.residualBatch    = nn.ModuleList([nn.BatchNorm2d(sameChannel_Out[i]) \
                                for i in range(len(sameChannel_Out))])

    def interpolation(self, inputT, targetT):
        """This function upsamples an input tensor to the spatial dimensions of the target tensor using a 
        combination of bilinear interpolation and convolution."""
        inputT_ch = inputT.shape[1]
        inputT_H  = inputT.shape[2]
        inputT_W  = inputT.shape[3]
        
        targetT_ch= targetT.shape[1]
        targetT_H = targetT.shape[2]
        targetT_W = targetT.shape[3]
             
        interpol  = nn.functional.interpolate(inputT, size = (targetT_H, targetT_W), mode = "bilinear")
        return interpol
    
    def forward(self, out, enList):
        enList     = enList[: : -1] # reversing the order of elements inside the list.
        
        for l in range(len(self.channelLength)):
            
            ### Upscaling = interpolation + convolution ###
            out    = self.interpolation(out, enList[l])
            out    = self.upConv[l](out)
            out    = self.upBatch[l](out)
            out    = self.upActive(out)
            ### Concatenation ###
            out    = torch.cat([out, enList[l]], 1) # concatenation on channel number!
            ### Decoding ###
            out    = self.deConvList[l](out)
            out    = self.batch[l](out)
            out    = self.active(out)
            ### Residual layers
            inputToResidual = out
            ## first layer
            out  = self.residualDecoding[l](out)
            out  = self.residualBatch[l](out)
            out  = self.active(out)
            ## second layer
            out  = self.residualDecoding[l](out)
            out  = self.residualBatch[l](out)
            out  = self.active(out)
            ## third layer
            out  = self.residualDecoding[l](out)
            out  = self.residualBatch[l](out)
            out  = self.active(out)
            ## Addition of the input of the residual block to the output of the residual block!
            out  = inputToResidual + out
            
        ### This is the last convolution layer!
        out      = self.lastConv(out)
        return out

This time, as you can see, I initialise the residual layers in the __init__ method. The point I would like to raise is that I initialise a list of six different residual layers in each of the Encoding and Decoding blocks, BUT in the forward methods I use each member of the residual list THREE times before adding the input of the residual block to its output. My question is: when I want to use a layer three times (as above), can I initialise it once and apply it several times (here three) in the forward method, or do I have to initialise a separate layer for every use? In other words, should I have created three lists of residual layers under different names, each with six layers, and used each list in place of one of the residual layers above? Something like the following code snippet? (Please note the changed list names.)


### Residual layers
inputToResidual = out
## first layer
out  = self.residualDecodingOne[l](out)
out  = self.residualBatchOne[l](out)
out  = self.active(out)
## second layer
out  = self.residualDecodingTwo[l](out)
out  = self.residualBatchTwo[l](out)
out  = self.active(out)
## third layer
out  = self.residualDecodingThree[l](out)
out  = self.residualBatchThree[l](out)
out  = self.active(out)
## Addition of the input of the residual block to the output of the residual block!
out  = inputToResidual + out

This question occurred to me after printing out the number and names of the model's trainable parameters (see below). Only the trainable parameters of the residual layers of the Decoding block are shown; the first column is the number of parameters and the second column is the parameter name.

1327104 decoder.residualDecoding.0.weight
384 decoder.residualDecoding.0.bias
921600 decoder.residualDecoding.1.weight
320 decoder.residualDecoding.1.bias
589824 decoder.residualDecoding.2.weight
256 decoder.residualDecoding.2.bias
331776 decoder.residualDecoding.3.weight
192 decoder.residualDecoding.3.bias
147456 decoder.residualDecoding.4.weight
128 decoder.residualDecoding.4.bias
36864 decoder.residualDecoding.5.weight
64 decoder.residualDecoding.5.bias
384 decoder.residualBatch.0.weight
384 decoder.residualBatch.0.bias
320 decoder.residualBatch.1.weight
320 decoder.residualBatch.1.bias
256 decoder.residualBatch.2.weight
256 decoder.residualBatch.2.bias
192 decoder.residualBatch.3.weight
192 decoder.residualBatch.3.bias
128 decoder.residualBatch.4.weight
128 decoder.residualBatch.4.bias
64 decoder.residualBatch.5.weight
64 decoder.residualBatch.5.bias

I am not surprised that the model lists six sets of weights and biases, because I declared a residual list with six layers, but I am confused about what happens to the weights and biases of the layers that are repeated in the forward function. If the model actually treats them as trainable parameters and trains them, why did they not appear when I printed the list of trainable parameters above? The same question applies to the BatchNorm2d layers too!

Both approaches are valid; it depends on whether you want to reuse the parameters (and layers).
If you initialize a layer once, it’ll have one set of parameters, which will be optimized. In the forward method you can reuse this layer as many times as you want, as seen in this simple example:

import torch
import torch.nn as nn

lin = nn.Linear(1, 1, bias=False)
optimizer = torch.optim.SGD(lin.parameters(), lr=1.)

x = torch.randn(1, 1)

# single step
out = lin(x)
out.backward()
print(lin.weight.grad)
optimizer.step()
lin.zero_grad()

# reuse the layer
out = lin(x)
for _ in range(10):
    out = lin(out)
out.backward()
print(lin.weight.grad)
optimizer.step()
lin.zero_grad()
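To connect this with the parameter listing above, here is a sketch comparing one conv applied three times (one set of weights) with three separate convs (three sets). The 64-channel case matches the decoder.residualDecoding.5 entries (36864 weights + 64 biases); the class and helper names are hypothetical.

```python
import torch
import torch.nn as nn


def n_params(m: nn.Module) -> int:
    """Count trainable parameters, as in the listing above."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


class Shared(nn.Module):
    """One conv applied three times: a single set of weights."""
    def __init__(self, c: int):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, padding=1)

    def forward(self, x):
        for _ in range(3):
            x = self.conv(x)
        return x


class Separate(nn.Module):
    """Three distinct convs: three independent sets of weights."""
    def __init__(self, c: int):
        super().__init__()
        self.convs = nn.ModuleList([nn.Conv2d(c, c, 3, padding=1) for _ in range(3)])

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        return x


print(n_params(Shared(64)))    # 36928
print(n_params(Separate(64)))  # 110784

x = torch.randn(1, 64, 8, 8)
print(Shared(64)(x).shape == Separate(64)(x).shape)  # True
```

The reused layer appears only once in parameters(); its gradient simply accumulates contributions from all three uses, just like lin.weight.grad in the example above.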

This reuse of a layer is sometimes called “parameter sharing”. E.g. you could initialize specific layers in an encoder and “share” them with the decoder, which means you either reuse them directly in the decoder path or reshape the parameters somehow so that they can be used in another part of the model.
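As a hedged illustration of sharing across an encoder and a decoder, here is a toy tied-weight autoencoder built from nn.Linear (a simplification chosen for brevity; the same idea applies to conv layers through their functional forms). The decoder reuses the transposed encoder weight via torch.nn.functional.linear, so only the decoder bias is an extra parameter.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedAutoencoder(nn.Module):
    def __init__(self, n_in: int = 16, n_hidden: int = 4):
        super().__init__()
        self.enc = nn.Linear(n_in, n_hidden)             # weight shape: [n_hidden, n_in]
        self.dec_bias = nn.Parameter(torch.zeros(n_in))  # the decoder's only own parameter

    def forward(self, x):
        h = torch.relu(self.enc(x))
        # The decoder shares the encoder weight, transposed to [n_in, n_hidden].
        return F.linear(h, self.enc.weight.t(), self.dec_bias)


model = TiedAutoencoder()
x = torch.randn(2, 16)
loss = (model(x) - x).pow(2).mean()
loss.backward()
print(sum(p.numel() for p in model.parameters()))  # 84 = 4*16 + 4 + 16
print(model.enc.weight.grad.shape)                 # torch.Size([4, 16])
```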

I’m not familiar with your use case so don’t know which approach would be the right one for your model.
However, be a bit careful when reusing batchnorm layers, as they track running statistics by default, and I’ve seen some users report that reusing these layers often fails because of this.
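A small sketch of why this matters: in training mode, every call to a BatchNorm layer updates the same running_mean/running_var buffers, so reusing one layer three times blends the statistics of three differently distributed inputs. In eval mode, all three uses then share one set of averaged statistics. The toy inputs below are arbitrary.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)          # new modules start in training mode
x = torch.randn(8, 4, 5, 5)

out = x
for _ in range(3):              # one BatchNorm reused three times in a single forward pass
    out = bn(out * 2 + 1)       # each call sees a differently scaled/shifted input

print(bn.num_batches_tracked.item())  # 3: the running stats were updated on every call
print(bn.running_mean)                # a blend of the three inputs' channel means
```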

After fixing the Residual issue, I trained and ran the model again, and this time the results are consistent during inference. I appreciate your help. Thank you very much!