Getting error when evaluating JODIE model

Hello, I am trying to reproduce the code of Kumar’s JODIE model available on the following github GitHub - srijankr/jodie: A PyTorch implementation of ACM SIGKDD 2019 paper "Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks" with pytorch version 1.7.1. Everything goes fine during training but for model evaluation I get the following error RuntimeError: Function ‘MseLossBackward’ returned nan values in its 0th output. the code used is the following:

def evaluate(data, embedding_dim, state_change, train_proportion, model, optimizer, user_embeddings_dystat, item_embeddings_dystat, user_embeddings_timeseries, item_embeddings_timeseries):

torch.autograd.set_detect_anomaly(True)
df = data.to_numpy()

[user2id, user_sequence_id, user_timediffs_sequence, user_previous_itemid_sequence,
item2id, item_sequence_id, item_timediffs_sequence,
timestamp_sequence, feature_sequence, y_true] = load_network(df)
num_interactions = len(user_sequence_id)
num_users = len(user2id)
num_items = len(item2id) + 1
num_features = len(feature_sequence[0])
true_labels_ratio = len(y_true) / (1.0 + sum(y_true))
print(“*** Network statistics: \n %d users\n %d items\n %d interactions\n %d %d true labels ***\n\n” % (num_users, num_items, num_interactions, sum(y_true), len(y_true)))

train_end_idx = validation_start_idx = int(num_interactions * train_proportion)
test_start_idx = int(num_interactions * (train_proportion + 0.2))
test_end_idx = int(num_interactions * (train_proportion + 0.4))

timespan = timestamp_sequence[-1] - timestamp_sequence[0]
tbatch_timespan = timespan / 500

model = JODIE(num_features, num_users, num_items, embedding_dim)
weight = torch.Tensor([1, true_labels_ratio])
crossEntropyLoss = nn.CrossEntropyLoss(weight = weight)
MSELoss = nn.MSELoss()

learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr = learning_rate, weight_decay = 1e-5)

set_embeddings_training_end(user_embeddings_dystat, item_embeddings_dystat, user_embeddings_timeseries, item_embeddings_timeseries, user_sequence_id, item_sequence_id, train_end_idx)

item_embeddings = item_embeddings_dystat[:, :embedding_dim]
item_embeddings = item_embeddings.clone()
item_embeddings_static = item_embeddings_dystat[:, embedding_dim:]
item_embeddings_static = item_embeddings_static.clone()

user_embeddings = user_embeddings_dystat[:, :embedding_dim]
user_embeddings = user_embeddings.clone()
user_embeddings_static = user_embeddings_dystat[:, embedding_dim:]
user_embeddings_static = user_embeddings_static.clone()

validation_predicted_y = []
test_predicted_y = []
validation_true_y = []
test_true_y = []

tbatch_start_time = None
loss = 0

with trange(train_end_idx, test_end_idx) as progress_bar:
for j in progress_bar:
progress_bar.set_description(“%dth interaction for validation and testing” % j)

  	userid = user_sequence_id[j]
  	itemid = item_sequence_id[j]
  	feature = feature_sequence[j]
  	user_timediff = user_timediffs_sequence[j]
  	item_timediff = item_timediffs_sequence[j]
  	timestamp = timestamp_sequence[j]

  	
  	if not tbatch_start_time:

  		tbatch_start_time = timestamp

  	itemid_previous = user_previous_itemid_sequence[j]
  	user_embedding_input = user_embeddings[torch.LongTensor([userid])]
  	user_embedding_static_input = user_embeddings_static[torch.LongTensor([userid])]
  	item_embedding_input = item_embeddings[torch.LongTensor([itemid])]
  	item_embedding_static_input = item_embeddings_static[torch.LongTensor([itemid])]
  	feature_tensor = Variable(torch.Tensor([feature])).unsqueeze(0)
  	user_timediffs_tensor = Variable(torch.Tensor([user_timediff])).unsqueeze(0)
  	item_timediffs_tensor = Variable(torch.Tensor([item_timediff])).unsqueeze(0)
  	item_embedding_previous = item_embeddings[torch.LongTensor([itemid_previous])]

  	user_projected_embedding = model.forward(user_embedding_input, item_embedding_previous, timediffs = user_timediffs_tensor, features = feature_tensor, select = "project")
  	user_item_embedding = torch.cat([user_projected_embedding, item_embedding_previous, item_embeddings_static[torch.LongTensor([itemid_previous])], user_embedding_static_input], dim = 1)

  	predicted_item_embedding = model.predicted_item_embedding(user_item_embedding)

  	loss = loss + MSELoss(predicted_item_embedding, torch.cat([item_embedding_input, item_embedding_static_input], dim = 1).clone())

  	user_embedding_output = model.forward(user_embedding_input, item_embedding_input, timediffs = user_timediffs_tensor, features = feature_tensor.reshape(1,-1), select = "user_update")
  	item_embedding_output = model.forward(user_embedding_input, item_embedding_input, timediffs = item_timediffs_tensor, features = feature_tensor.reshape(1,-1), select = "item_update")

  	item_embeddings[itemid, :] = item_embedding_output.squeeze(0)
  	user_embeddings[userid, :] = user_embedding_output.squeeze(0)
  	user_embeddings_timeseries[j, :] = user_embedding_output.squeeze(0)
  	item_embeddings_timeseries[j, :] = item_embedding_output.squeeze(0)

  	loss = loss + MSELoss(item_embedding_output, item_embedding_input.clone())
  	loss = loss + MSELoss(user_embedding_output, user_embedding_input.clone())


  	if state_change:
  		loss = loss + calculate_state_prediction_loss(model, [j], user_embeddings_timeseries, y_true, crossEntropyLoss)

  	if timestamp - tbatch_start_time > tbatch_timespan:
  		print("ok")
  		tbatch_start_time = timestamp
  		loss.backward()
  		optimizer.step()
  		optimizer.zero_grad()

  		loss = 0
  		item_embeddings.detach_()
  		user_embeddings.detach_()
  		item_embeddings_timeseries.detach_()
  		user_embeddings_timeseries.detach_()

  	prob = model.predict_label(user_embedding_output)

  	if j < test_start_idx:
  		validation_predicted_y.extend(prob.data.cpu().numpy())
  		validation_true_y.extend([y_true[j]])
  	else:
  		test_predicted_y.extend(prob.data.cpu().numpy())
  		test_true_y.extend([y_true[j]])

validation_predicted_y = np.array(validation_predicted_y)
test_predicted_y = np.array(test_predicted_y)

performance_dict = dict()
auc = roc_auc_score(validation_true_y, validation_predicted_y[:, 1])
performance_dict[“validation”] = [auc]

auc = roc_auc_score(test_true_y, test_predicted_y[:, 1])
performance_dict[“test”] = [auc]

return performance_dict

Can you help me please