Simplified code:
def update_norm_adj(self, prune_percent=0.05):
    # Prune adj_mask, executed every 20 epochs
    with torch.no_grad():
        adj_mask_tensor = self.adj_mask_train.flatten()
        nonzero = torch.abs(adj_mask_tensor) > 0
        adj_mask = adj_mask_tensor[nonzero]
        adj_total = adj_mask.shape[0]
        adj_y, adj_i = torch.sort(adj_mask.abs())
        adj_thre_index = int(adj_total * prune_percent)
        adj_thre = adj_y[adj_thre_index]
        abs_values = torch.abs(self.adj_mask_train)
        mask = abs_values >= adj_thre
        self.adj_mask_train.data[mask] = 1
        self.adj_mask_train.data[~mask] = 0
        self.pruneflag = True
        print("prune success!")
def forward(self):
    if self.lotteryflag and self.pruneflag:
        # In each forward pass, the adjacency matrix is recalculated
        dense_tensor = self.adj_mat
        adj = torch.mul(dense_tensor, self.adj_mask_train)
        self.norm_adj_matrix = self.torch_normalize_adj(adj)
    all_embeddings = self.get_ego_embeddings()
    feat_mat = self.dropout_sp_mat(self.feat_mat)
    embeddings_list = [all_embeddings]
    for layer_idx in range(self.n_layers):
        all_embeddings = torch.mm(self.norm_adj_matrix, all_embeddings)
        embeddings_list.append(all_embeddings)
    lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
    lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
    user_all_embeddings, item_all_embeddings = torch.split(
        lightgcn_all_embeddings, [self.n_users, self.n_items]
    )
    return user_all_embeddings, item_all_embeddings
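The rest of forward() is the standard LightGCN propagation: multiply the normalized adjacency into the embedding matrix n_layers times and average all layer outputs. A tiny self-contained sketch of just that part, with arbitrary toy shapes:

import torch

# Toy illustration of the propagation / layer-averaging part of forward()
n_nodes, emb_dim, n_layers = 4, 8, 3
norm_adj = torch.rand(n_nodes, n_nodes)        # stand-in for the normalized adjacency
all_emb = torch.rand(n_nodes, emb_dim)         # stand-in for the ego embeddings

emb_list = [all_emb]
for _ in range(n_layers):
    all_emb = torch.mm(norm_adj, all_emb)      # propagate one hop
    emb_list.append(all_emb)

final_emb = torch.stack(emb_list, dim=1).mean(dim=1)   # average over the K+1 layer outputs
print(final_emb.shape)                                  # torch.Size([4, 8])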
def torch_normalize_adj(self, adj):
    adj = adj + torch.eye(adj.shape[0]).cuda()
    rowsum = adj.sum(1)
    d_inv_sqrt = torch.pow(rowsum, -0.5).flatten()
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = torch.diag(d_inv_sqrt).cuda()
    result = adj.mm(d_mat_inv_sqrt).t().mm(d_mat_inv_sqrt)
    return result
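torch_normalize_adj is the usual symmetric normalization D^-1/2 (A + I) D^-1/2; since the masked adjacency is symmetric, the transpose in the second-to-last line does not change the result. A CPU-only sketch of the same computation (no .cuda(), toy 3-node graph):

import torch

# CPU sketch of the normalization above, on a tiny symmetric graph
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
A_hat = A + torch.eye(A.shape[0])              # add self-loops
deg = A_hat.sum(1)
d_inv_sqrt = deg.pow(-0.5)
d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
D_inv_sqrt = torch.diag(d_inv_sqrt)
norm_A = D_inv_sqrt @ A_hat @ D_inv_sqrt       # D^-1/2 (A + I) D^-1/2
print(norm_A)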
def _train_epoch(self, train_data, aux_train_data, epoch_idx, show_progress=False):
    # Train the model for one epoch
    self.model.train()
    total_loss = None
    iter_data = (
        tqdm(
            train_data,
            total=len(train_data),
            ncols=100,
            desc=set_color(f"Train {epoch_idx:>5}", "pink"),
        )
        if show_progress
        else train_data
    )
    if not self.config["single_spec"] and train_data.shuffle:
        train_data.sampler.set_epoch(epoch_idx)
    scaler = amp.GradScaler(enabled=self.enable_scaler)
    batchcount = 0
    dataloader_iterator = iter(aux_train_data)
    for batch_index, interaction in enumerate(iter_data):
        try:
            interaction_aux = next(dataloader_iterator)
        except StopIteration:
            # create a new iterator when the auxiliary loader is exhausted
            dataloader_iterator = iter(aux_train_data)
            interaction_aux = next(dataloader_iterator)
        interaction = interaction.to(self.device)
        interaction_aux = interaction_aux.to(self.device)
        self.optimizer.zero_grad()
        sync_loss = 0
        if not self.config["single_spec"]:
            self.set_reduce_hook()
            sync_loss = self.sync_grad_loss()
        with torch.autocast(device_type=self.device.type, enabled=self.enable_amp):
            losses = self.model.calculate_loss(interaction, interaction_aux)
        if isinstance(losses, tuple):
            loss = sum(losses)
            loss_tuple = tuple(per_loss.item() for per_loss in losses)
            total_loss = (
                loss_tuple
                if total_loss is None
                else tuple(map(sum, zip(total_loss, loss_tuple)))
            )
        else:
            loss = losses
            total_loss = (
                losses.item() if total_loss is None else total_loss + losses.item()
            )
        self._check_nan(loss)
        # scaler.scale(loss + sync_loss).backward(retain_graph=True)
        scaler.scale(loss + sync_loss).backward()
        if self.clip_grad_norm:
            clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
        scaler.step(self.optimizer)
        scaler.update()
        if self.gpu_available and show_progress:
            iter_data.set_postfix_str(
                set_color("GPU RAM: " + get_gpu_usage(self.device), "yellow")
            )
    if self.model.lotteryflag and epoch_idx % 10 == 0 and epoch_idx > 0:
        print("prune!")
        self.model.update_norm_adj()
        self.model.zero_grad()
    return total_loss
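The try/except around next(dataloader_iterator) is the common pattern for cycling a shorter auxiliary DataLoader alongside the main one. Stripped of the training logic, it looks like this (toy datasets and sizes, just for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for the two loaders; the auxiliary one is deliberately shorter
main_loader = DataLoader(TensorDataset(torch.arange(10)), batch_size=2)
aux_loader = DataLoader(TensorDataset(torch.arange(4)), batch_size=2)

aux_iter = iter(aux_loader)
for batch in main_loader:
    try:
        aux_batch = next(aux_iter)
    except StopIteration:
        aux_iter = iter(aux_loader)   # restart the shorter loader once it runs out
        aux_batch = next(aux_iter)
    # ... use batch and aux_batch together ...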
(More code can be found in this related question: "Why do manually modified parameters in training process revert to their original values after optimizer.step() is executed?")