I’ve trained a Transformer to perform translation. The final loss before saving the state dict was ~20; after reloading the checkpoint and running inference it is ~82 (close to the starting loss of ~103).
I’ve been banging my head against the wall on this for days. Does anyone have any idea where to start looking for the issue?
Setup details:
torch version: 1.7.1
device: 2x GPU
Here are my code snippets:
Saving the model:
torch.save(model.state_dict(), best_path)
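Training runs on 2 GPUs, so if the model was wrapped in nn.DataParallel when it was saved, every key in the state dict will carry a module. prefix. In case it's useful, here is a quick way to inspect what actually ended up in the checkpoint (plain torch.load, nothing specific to my model):

# Load the checkpoint onto the CPU and look at the first few parameter names.
# Keys starting with "module." mean the state dict was saved from a DataParallel wrapper.
state = torch.load(best_path, map_location="cpu")
print(list(state.keys())[:5])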
Loading the model:
model = Model(config, vocab)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.load_state_dict(torch.load(best_path))
model.to(config.device)
model.eval()
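One sanity check I can run is comparing a cheap parameter fingerprint right before saving and right after loading; if the two numbers match, the weights themselves round-tripped correctly and the problem is elsewhere. (param_checksum is just a throwaway helper for this post, not part of my code.)

def param_checksum(model):
    # Sum of all parameter values: identical before save and after load
    # if the weights really survived the checkpoint round trip.
    return sum(p.detach().double().sum().item() for p in model.parameters())

print("checksum:", param_checksum(model))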
I set the random seed at the start of each session like this:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
set_seed(config.seed)
I use the same pre-tokenised (SPM) data each time. I load it in the following way:
def dataset_fn(df, root, fields, seed=1234, dev_size=1000):
    train, dev, test = np.split(df.sample(frac=1, random_state=seed), [len(df) - 2 * dev_size, len(df) - dev_size])
    train.to_csv(os.path.join(root, "train.csv"), index=False)
    dev.to_csv(os.path.join(root, "dev.csv"), index=False)
    test.to_csv(os.path.join(root, "test.csv"), index=False)
    train, dev, test = TabularDataset.splits(
        path=root,
        train='train.csv',
        validation='dev.csv',
        test='test.csv',
        format='csv',
        fields=fields)
    return train, dev, test
device = config.device
print("Device:{}".format(device))
root = config.data_path
seed = config.seed
data_paths = {'pre_src': os.path.join(root, config.pre_src_path),
              'pre_trg': os.path.join(root, config.pre_trg_path)}
dfs = {}
for key, value in data_paths.items():
    df = pd.read_csv(value, sep='delimiter', index_col=None, header=None, skip_blank_lines=False)
    dfs[key] = df
pre_df = make_df(dfs['pre_src'], dfs['pre_trg'])
print("Data loaded.\n")
TEXT = Field(tokenize=tokenize,
             init_token='<sos>',
             eos_token='<eos>',
             lower=False,
             batch_first=True)
print("Field defined")
data_fields = [('src', TEXT), ('trg', TEXT)]
train_pre_set, dev_pre_set, test_pre_set = dataset_fn(pre_df, root, data_fields, seed=seed, dev_size=1000)
# vocab is read using pickle from a .pkl file, previously generated using .build_vocab on this data and saved to .pkl
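# For reference, read_vocab is essentially just a pickle load of the saved Vocab
# (a sketch of what it does, not copied verbatim from my code):
import pickle

def read_vocab(path):
    # Unpickle the Vocab object that was built with .build_vocab and saved earlier.
    with open(path, 'rb') as f:
        return pickle.load(f)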
vocab = read_vocab(config.vocab_file)
TEXT.vocab = vocab
dataiter_fn = lambda dataset, train: BucketIterator(
    dataset=dataset,
    batch_size=config.batch_size,
    shuffle=train,
    repeat=train,
    sort_key=lambda x: len(x.trg),
    sort_within_batch=False,
    device=device
)
# Create iterators
train_pre_iter = dataiter_fn(train_pre_set, True)
dev_pre_iter = dataiter_fn(dev_pre_set, False)
test_pre_iter = dataiter_fn(test_pre_set, False)
# This is wrapped in a function which returns iters & vocab
#return train_pre_iter, dev_pre_iter, test_pre_iter, vocab
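Since the model weights might be fine and the mismatch could come from the data side instead, another check I can run is a quick vocab fingerprint in both the training session and the inference session (check_vocab is a hypothetical helper written just for this post):

def check_vocab(vocab):
    # Print a cheap fingerprint of the vocab: its size plus a few token-to-index entries.
    # If these differ between training and inference, the indices the model was trained
    # on no longer line up with the inputs it sees now, even though the weights loaded fine.
    print("vocab size:", len(vocab))
    for tok in ['<sos>', '<eos>', '<unk>', '<pad>']:
        print(tok, vocab.stoi.get(tok))

check_vocab(TEXT.vocab)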
Let me know if you need any more info!