Model loaded for inference performs like a random init

I’ve trained a Transformer to perform translation. The final loss before saving the state dict was ~20; after reloading and running inference it is ~82 (close to the starting loss, which was 103).
I’ve been banging my head against the wall for days on this one. Does anyone have any idea where to start looking for the issue?

Setup details:
torch version: 1.7.1
device: 2x GPU
My code snippets are below:

Saving the model:

torch.save(model.state_dict(), best_path)

Loading the model:

model = Model(config, vocab)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.load_state_dict(torch.load(best_path))
model.to(config.device)
model.eval()
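
(For what it’s worth, one quick check that isn’t part of my original code — just a sketch — is to compare the checkpoint keys against the model’s keys and load with strict=False so any missing/unexpected keys can be inspected:)

# Sketch: compare checkpoint keys with the model's keys to spot a "module." prefix mismatch
state = torch.load(best_path, map_location='cpu')
print("first checkpoint key:", next(iter(state)))
print("first model key:     ", next(iter(model.state_dict())))

# strict=False returns the mismatches instead of raising, so they can be inspected
result = model.load_state_dict(state, strict=False)
print("missing keys:   ", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)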

I set the random seed at the start of each session like this:

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(config.seed)

I use the same pre-tokenised (SPM) data each time. I load it in the following way:


def dataset_fn(df, root, fields, seed=1234, dev_size=1000):
    train, dev, test = np.split(df.sample(frac=1, random_state=seed), [len(df) - 2 * dev_size, len(df) - dev_size])
    train.to_csv(os.path.join(root, "train.csv"), index=False)
    dev.to_csv(os.path.join(root, "dev.csv"), index=False)
    test.to_csv(os.path.join(root, "test.csv"), index=False)
    train, dev, test = TabularDataset.splits(
        path=root,
        train='train.csv',
        validation='dev.csv',
        test='test.csv',
        format='csv',
        fields=fields)
    return train, dev, test

device = config.device
print("Device:{}".format(device))
root = config.data_path
seed = config.seed
data_paths = {'pre_src': os.path.join(root, config.pre_src_path),
              'pre_trg': os.path.join(root, config.pre_trg_path)}
dfs = {}
for key, value in data_paths.items():
    df = pd.read_csv(value, sep='delimiter', index_col=None, header=None, skip_blank_lines=False)
    dfs[key] = df
pre_df = make_df(dfs['pre_src'], dfs['pre_trg'])
print("Data loaded.\n")

TEXT = Field(tokenize=tokenize,
             init_token='<sos>',
             eos_token='<eos>',
             lower=False,
             batch_first=True)
print("Field defined")
data_fields = [('src', TEXT), ('trg', TEXT)]

train_pre_set, dev_pre_set, test_pre_set = dataset_fn(pre_df, root, data_fields, seed=seed, dev_size=1000)

# vocab is read using pickle from a .pkl file, previously generated using .build_vocab on this data and saved to .pkl
vocab = read_vocab(config.vocab_file)
TEXT.vocab = vocab

dataiter_fn = lambda dataset, train: BucketIterator(
    dataset=dataset,
    batch_size=config.batch_size,
    shuffle=train,
    repeat=train,
    sort_key=lambda x: len(x.trg),
    sort_within_batch=False,
    device=device
)
# Create iterators
train_pre_iter = dataiter_fn(train_pre_set, True)
dev_pre_iter = dataiter_fn(dev_pre_set, False)
test_pre_iter = dataiter_fn(test_pre_set, False)
# This is wrapped in a function which returns iters & vocab
#return train_pre_iter, dev_pre_iter, test_pre_iter, vocab

Let me know if you need any more info!

These issues are often caused by a difference in the model usage or the data loading.
To isolate the root cause further, you could store the outputs of the trained model using a static input (e.g. torch.ones) after calling model.eval() and before saving the state_dict().
Afterwards, load the model in your inference script and compare the reference outputs to a new run with the same static inputs. If the outputs differ, the difference comes from the model itself, and you can debug further what might be causing it (missing keys in the state_dict, etc.).

On the other hand, if the outputs are equal (up to floating point precision), you could compare the data loading pipelines in your training and inference scripts and check whether the same preprocessing was applied (e.g. normalization).
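
Something along these lines would work as a rough sketch (the dummy input shape and dtype are placeholders; adapt them to your model’s actual forward signature, e.g. if it expects both src and trg):

# after training, before saving the state_dict
model.eval()
dummy_input = torch.ones(1, 10, dtype=torch.long, device=config.device)  # placeholder token ids
with torch.no_grad():
    ref_out = model(dummy_input)
torch.save(ref_out.cpu(), 'reference_output.pt')
torch.save(model.state_dict(), best_path)

# in the inference script, after load_state_dict() and model.eval()
with torch.no_grad():
    new_out = model(dummy_input)
ref_out = torch.load('reference_output.pt')
print(torch.allclose(new_out.cpu(), ref_out, atol=1e-6))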

Thanks @ptrblck. I was about to do this after reading this suggestion of yours elsewhere, when it miraculously started working (touch wood).

The only thing I changed was the way I save and load the model (I’m using DataParallel):

model = MyModel(config, vocab)

model.to(config.device)
print(config.device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.module.load_state_dict(torch.load(load_path))
model.eval()

And save:

torch.save(model.module.state_dict(), best_path)

In other words, I’m now saving the state dict of the underlying model (model.module.state_dict()) rather than the state dict of the DataParallel wrapper.
Does it make sense that this could have been causing the issue?

It could explain the issue, but you should get an error (or at least missing/unexpected key messages) when loading e.g. the plain model.state_dict() into the nn.DataParallel model.
The reason is that nn.DataParallel(model).state_dict() prefixes each parameter key with module., which then creates a key mismatch in the model.load_state_dict() call (the same applies to the reversed workflow).
However, if no errors were raised, I’m unsure what might have been the root cause of the issue.
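
For illustration, a minimal standalone sketch of the key names with and without the wrapper:

import torch.nn as nn

plain_model = nn.Linear(4, 2)
print(list(plain_model.state_dict().keys()))     # ['weight', 'bias']

dp_model = nn.DataParallel(plain_model)
print(list(dp_model.state_dict().keys()))        # ['module.weight', 'module.bias']

# Loading one state_dict into the other model with the default strict=True therefore
# fails with missing/unexpected key errors, while dp_model.module.load_state_dict(...)
# bypasses the prefix entirely.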

Hi @st-vincent1, I’m having a similar issue with the loaded model. Were you able to find what’s causing the problem?

It shouldn’t have raised a warning, because I saved model.state_dict() while the model was wrapped in nn.DataParallel(), and then loaded it by first wrapping an initialised model in DataParallel and loading the state dict into that. But that didn’t work, hence my original post. Once I changed to saving the state dict of the unwrapped model explicitly (model.module.state_dict()) and loading it with model.module.load_state_dict(), it all worked.

@pattiJane , see above how I worked it out.