batch_first=True gives different results in RNN

I am trying to rewrite some code so that it uses batch_first=True.
Then I realized that the results with batch_first=True are different from those with batch_first=False, even though I am using the same random seed.

Here’s my original code written with batch_first=False

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

import numpy as np

import random
import math
import time

SEED = 1234

# Seed every random source with the same value, in a fixed order:
# Python's `random`, NumPy, PyTorch, and PyTorch's CUDA generator.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)
# Request deterministic cuDNN algorithms.
torch.backends.cudnn.deterministic = True

device = 'cuda:1'

# Preprocessing: convert each image to a tensor, then normalize it with
# mean 0.5 and std 0.5.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.5, ), (0.5, ))
transform = transforms.Compose([to_tensor, normalize])

# Training split, served in shuffled batches of 100.
trainset = MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

# Test split, served in order in batches of 100.
testset = MNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

class Model(nn.Module):
    """GRU classifier that reads its input in sequence-first layout.

    A linear layer maps each 28-feature step to 14 features, dropout is
    applied, a single-layer GRU (``batch_first=False``) consumes the
    sequence, and the final hidden state is projected to 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Linear(in_features=28, out_features=14)
        self.rnn = nn.GRU(input_size=14, hidden_size=14, batch_first=False)
        self.fc = nn.Linear(in_features=14, out_features=10)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return sigmoid-squashed scores computed from the GRU's last hidden state.

        ``x`` must have 28 features in its last dimension; with
        ``batch_first=False`` the GRU treats dim 0 as time and dim 1 as batch.
        """
        embedded = self.emb(x)
        # Dropout is random in training mode and a no-op in eval mode.
        dropped = self.dropout(embedded)
        # Only the final hidden state is used; the per-step outputs are ignored.
        _, last_hidden = self.rnn(dropped)
        # Drop the leading dimension of the hidden state before the classifier.
        scores = self.fc(last_hidden.squeeze(0))
        # NOTE(review): the script passes this output to nn.CrossEntropyLoss,
        # which expects raw logits; the sigmoid is kept as posted.
        return torch.sigmoid(scores)

# Build the model, optimizer and loss, then run 5 epochs of training,
# each followed by a pass over the test set.
model = Model()
model = model.to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

for e in range(5):
    train_total_loss = 0
    model.train()  # enables dropout
    for i, (x,y) in enumerate(trainloader):
        x = x.to(device)
        y = y.to(device)
        # Drop the size-1 dim at index 1 (presumably the image channel), then
        # swap the first two dims so the GRU gets sequence-first input.
        x = x.squeeze(1).transpose(0,1)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        train_total_loss += loss.item()
        # '\r' keeps the per-batch progress on a single console line.
        print(f'{i}/{len(trainloader)}: loss = {loss.item()}', end='\r')
        

    test_total_loss = 0
    model.eval()  # disables dropout
    # Evaluation never calls backward(), so skip building the autograd graph.
    # This saves memory and time without changing the computed losses.
    with torch.no_grad():
        for i, (x,y) in enumerate(testloader):
            x = x.to(device)
            y = y.to(device)
            x = x.squeeze(1).transpose(0,1)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            test_total_loss += loss.item()
            print(f'{i}/{len(testloader)}: loss = {loss.item()}', end='\r')
    # Report the mean per-batch loss for both splits.
    print(f'epoch {e+1}: train loss={train_total_loss/len(trainloader):.4f}\ttest loss={test_total_loss/len(testloader):.4f}')

output >>>
epoch 1: train loss=2.1462	test loss=1.9462
epoch 2: train loss=1.9133	test loss=1.8628
epoch 3: train loss=1.8353	test loss=1.7833
epoch 4: train loss=1.7718	test loss=1.7382
epoch 5: train loss=1.7444	test loss=1.7191

Now I change my model to batch_first=True

class Model(nn.Module):
    """GRU classifier that reads its input in batch-first layout.

    A linear layer maps each 28-feature step to 14 features, dropout is
    applied, a single-layer GRU (``batch_first=True``) consumes the
    sequence, and the final hidden state is projected to 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Linear(in_features=28, out_features=14)
        self.rnn = nn.GRU(input_size=14, hidden_size=14, batch_first=True)
        self.fc = nn.Linear(in_features=14, out_features=10)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return sigmoid-squashed scores computed from the GRU's last hidden state.

        ``x`` must have 28 features in its last dimension; with
        ``batch_first=True`` the GRU treats dim 0 as batch and dim 1 as time.
        """
        embedded = self.emb(x)
        # Dropout is random in training mode and a no-op in eval mode.
        dropped = self.dropout(embedded)
        # Only the final hidden state is used; the per-step outputs are ignored.
        _, last_hidden = self.rnn(dropped)
        # Drop the leading dimension of the hidden state before the classifier.
        scores = self.fc(last_hidden.squeeze(0))
        # NOTE(review): the script passes this output to nn.CrossEntropyLoss,
        # which expects raw logits; the sigmoid is kept as posted.
        return torch.sigmoid(scores)

# Build the model, optimizer and loss, then run 5 epochs of training,
# each followed by a pass over the test set.
model = Model()
model = model.to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

for e in range(5):
    train_total_loss = 0
    model.train()  # enables dropout
    for i, (x,y) in enumerate(trainloader):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        # Only the size-1 dim at index 1 is dropped; no transpose is needed
        # because the GRU is batch-first here.
        y_pred = model(x.squeeze(1))
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        train_total_loss += loss.item()
        # '\r' keeps the per-batch progress on a single console line.
        print(f'{i}/{len(trainloader)}: loss = {loss.item()}', end='\r')
        

    test_total_loss = 0
    model.eval()  # disables dropout
    # Evaluation never calls backward(), so skip building the autograd graph.
    # This saves memory and time without changing the computed losses.
    with torch.no_grad():
        for i, (x,y) in enumerate(testloader):
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x.squeeze(1))
            loss = criterion(y_pred, y)
            test_total_loss += loss.item()
            print(f'{i}/{len(testloader)}: loss = {loss.item()}', end='\r')
    # Report the mean per-batch loss for both splits.
    print(f'epoch {e+1}: train loss={train_total_loss/len(trainloader):.4f}\ttest loss={test_total_loss/len(testloader):.4f}')

output>>>
epoch 1: train loss=2.1360	test loss=1.9413
epoch 2: train loss=1.9016	test loss=1.8367
epoch 3: train loss=1.8216	test loss=1.7775
epoch 4: train loss=1.7840	test loss=1.7494
epoch 5: train loss=1.7612	test loss=1.7350

The result is different.
What is wrong here?

I cannot reproduce the difference for the nn.GRU module in isolation:

# Check nn.GRU alone: run the same data through a batch-first GRU and a
# sequence-first GRU that share identical weights, then compare the results.
device = 'cuda'
# Random input of shape (5, 6, 14), fed to the batch_first=True module.
x0 = torch.randn(5, 6, 14, device=device)
model0 = nn.GRU(14, 14, batch_first=True).to(device)
out0 = model0(x0)

# Same data with the first two dims swapped, for the batch_first=False module.
x1 = x0.permute(1, 0, 2)
model1 = nn.GRU(14, 14, batch_first=False).to(device)
# Copy model0's weights so the layout is the only difference between the two.
model1.load_state_dict(model0.state_dict())
out1 = model1(x1)

# Max absolute difference between the first returned tensors, after permuting
# out0[0] into the same layout as out1[0].
print((out0[0].permute(1, 0, 2) - out1[0]).abs().max())
> tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
print((out0[1] - out1[1]).abs().max())
> tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)

So the difference is most likely coming from another part of your code.
I would recommend slimming down the code, e.g. by removing the data loading, to get rid of potential sources of randomness.
Also, since your model is using dropout, you have to be very careful to seed the code in the appropriate place and make sure the same calls to the pseudo-random number generator are performed in both scripts. Thus I would just remove it or call model.eval().

Thanks for your insight. Indeed, the problem comes from the dropout layer.
Even when I make sure the seed for both scripts is the same, the dropout outputs differ:

import torch
import torch.nn as nn
import numpy as np
import random
SEED = 1234
# Seed Python, NumPy and PyTorch (CPU and CUDA) once, before any random draw.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# The same all-ones data in two layouts: (2, 3, 4) and (3, 2, 4).
x_batch_first = torch.ones(2,3,4)
x_batch_second = torch.ones(3,2,4)

dropout_layer = nn.Dropout(p=0.5)

# Two consecutive calls to the same dropout layer, with no reseeding in between.
output1 = dropout_layer(x_batch_first)
output2 = dropout_layer(x_batch_second)
# Swap output2's first two dims so both tensors are (2, 3, 4) before comparing.
torch.equal(output1, output2.transpose(0,1))
output >>> False

When I print out both results, it gives me:

output1
tensor([[[0., 2., 0., 0.],
         [2., 0., 2., 0.],
         [0., 0., 2., 2.]],

        [[2., 0., 2., 2.],
         [0., 0., 2., 2.],
         [0., 2., 2., 0.]]])

output2
tensor([[[2., 0., 2., 0.],
         [2., 2., 0., 2.]],

        [[0., 2., 2., 2.],
         [2., 0., 2., 0.]],

        [[2., 2., 0., 2.],
         [2., 0., 0., 0.]]])

Is there any way to make sure the dropout layer still behaves the same with batch_first=True and batch_first=False?

The output of your current script is expected, since each call into dropout_layer will sample a new random mask.
The initial seed makes sure that sequential calls into the random number generator will yield the same “random” numbers, but of course it’s not forcing dropout layers to sample a constant mask.

If you want to force the dropout layers to behave exactly the same, you could either seed the code right before the call (which would get rid of the randomness and thus I would assume your model to perform badly) via:

# Demonstrate that reseeding immediately before each dropout call yields
# equal outputs when both inputs have the same shape.
x_batch_first = torch.ones(2,3,4)
x_batch_second = torch.ones(3,2,4)

dropout_layer = nn.Dropout(p=0.5)

# Reset the generator to the same state before each call.
torch.manual_seed(2809)
output1 = dropout_layer(x_batch_first)
torch.manual_seed(2809)
# Transpose first so the input shape matches x_batch_first, i.e. (2, 3, 4).
output2 = dropout_layer(x_batch_second.transpose(0, 1))
torch.equal(output1, output2)
> True

or maybe sample the dropout masks manually outside of the training script and just apply them.

Note that even seeding the layer would give different results, if the tensors have different shapes, which is why I’ve transposed the second input before the dropout call.

1 Like