Hello.
During training, I attempt to manually modify some parameters under a special condition, but I find that they automatically revert to their previous states. How can I prevent this from happening in PyTorch?
The code:
def _train_epoch(self, train_data, aux_train_data, epoch_idx, show_progress=False):
    r"""Train the model for one epoch.

    Args:
        train_data (DataLoader): The train data.
        aux_train_data (DataLoader): The auxiliary train data.
        epoch_idx (int): The current epoch id.
        show_progress (bool): Show the progress of the training epoch. Defaults to ``False``.

    Returns:
        float/tuple: The sum of the loss returned by all batches in this epoch. If the loss of each batch
        contains multiple parts and the model returns these parts instead of their sum, this method returns a
        tuple with the per-part sums over the epoch.
    """
    self.model.train()
    total_loss = None
    iter_data = (
        tqdm(
            train_data,
            total=len(train_data),
            ncols=100,
            desc=set_color(f"Train {epoch_idx:>5}", "pink"),
        )
        if show_progress
        else train_data
    )
    if not self.config["single_spec"] and train_data.shuffle:
        train_data.sampler.set_epoch(epoch_idx)
    scaler = amp.GradScaler(enabled=self.enable_scaler)
    batchcount = 0
    dataloader_iterator = iter(aux_train_data)
    for batch_index, interaction in enumerate(iter_data):
        # Draw an auxiliary batch, restarting the auxiliary loader when it is exhausted.
        try:
            interaction_aux = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(aux_train_data)
            interaction_aux = next(dataloader_iterator)
        interaction = interaction.to(self.device)
        interaction_aux = interaction_aux.to(self.device)
        self.optimizer.zero_grad()
        sync_loss = 0
        if not self.config["single_spec"]:
            self.set_reduce_hook()
            sync_loss = self.sync_grad_loss()
        with torch.autocast(device_type=self.device.type, enabled=self.enable_amp):
            losses = self.model.calculate_loss(interaction, interaction_aux)
        if isinstance(losses, tuple):
            loss = sum(losses)
            loss_tuple = tuple(per_loss.item() for per_loss in losses)
            total_loss = (
                loss_tuple
                if total_loss is None
                else tuple(map(sum, zip(total_loss, loss_tuple)))
            )
        else:
            loss = losses
            total_loss = (
                losses.item() if total_loss is None else total_loss + losses.item()
            )
        self._check_nan(loss)
        scaler.scale(loss + sync_loss).backward(retain_graph=True)
        if self.clip_grad_norm:
            clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
        scaler.step(self.optimizer)
        scaler.update()
        if self.gpu_available and show_progress:
            iter_data.set_postfix_str(
                set_color("GPU RAM: " + get_gpu_usage(self.device), "yellow")
            )
    self.model.feat_mat_anneal()
    # Prune the adjacency mask once every 10 epochs (after the first 10 epochs).
    if self.model.lotteryflag and epoch_idx % 10 == 0 and epoch_idx > 10:
        print("prune!")
        self.model.update_norm_adj()
    return total_loss
def update_norm_adj(self, prune_percent=0.05):
    with torch.no_grad():
        # Collect the non-zero mask entries and find the magnitude threshold
        # below which the lowest `prune_percent` fraction of them is pruned.
        adj_mask_tensor = self.adj_mask_train.flatten()
        print("adj_mask_tensor.shape: ", adj_mask_tensor.shape)
        nonzero = torch.abs(adj_mask_tensor) > 0
        adj_mask = adj_mask_tensor[nonzero]
        print("adj_mask.shape: ", adj_mask.shape)
        adj_total = adj_mask.shape[0]
        adj_y, adj_i = torch.sort(adj_mask.abs())
        adj_thre_index = int(adj_total * prune_percent)
        adj_thre = adj_y[adj_thre_index]
        print("adj_thre", adj_thre)

        # Binarize the mask in place: 1 where the magnitude survives the threshold, 0 elsewhere.
        abs_values = torch.abs(self.adj_mask_train)
        mask = abs_values >= adj_thre
        self.adj_mask_train.data[mask] = 1
        self.adj_mask_train.data[~mask] = 0

        if self.adj_mat.is_sparse:
            dense_tensor = self.adj_mat.to_dense()
        else:
            dense_tensor = self.adj_mat

        zeros_count = (self.adj_mask_train == 0).sum().item()
        ones_count = (self.adj_mask_train == 1).sum().item()
        print(f"Number of zeros: {zeros_count}")
        print(f"Number of ones: {ones_count}")

        # Apply the pruned mask to the adjacency matrix and re-normalize it.
        adj = torch.mul(dense_tensor, self.adj_mask_train)
        self.norm_adj_matrix = self.torch_normalize_adj(adj)
--------
Specifically, I aim to perform pruning once every fixed number of epochs, which involves modifying a learnable parameter:
self.adj_mask_train = torch.nn.Parameter(self.generate_adj_mask(self.generate_daj_mat()), requires_grad=True)
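Because adj_mask_train is registered as an nn.Parameter with requires_grad=True, I assume it shows up in self.model.parameters() and therefore in the optimizer's parameter groups, so scaler.step(self.optimizer) keeps updating it on every batch. A small sanity check I can run (a sketch only, reusing the trainer's self.model and self.optimizer attributes from the code above):

mask_param = self.model.adj_mask_train
in_model = any(p is mask_param for p in self.model.parameters())
in_optimizer = any(
    p is mask_param
    for group in self.optimizer.param_groups
    for p in group["params"]
)
print(f"adj_mask_train in model.parameters(): {in_model}")
print(f"adj_mask_train in optimizer param_groups: {in_optimizer}")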
However, I am encountering an issue: although pruning appears successful and I can indeed see the modified values via self.model.adj_mask_train, once the next epoch reaches this part of the code:
scaler.step(self.optimizer)
scaler.update()
I find that the parameter adj_mask_train in the model has reverted to its pre-pruning state (the initial state without pruning). What I need is for adj_mask_train to be gradually iterated toward the state I want, but as it stands, every modification I make to the parameter is undone by the next epoch.
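To show what I mean outside of my trainer, here is a minimal standalone sketch (hypothetical toy code, not my actual model) of the kind of behavior I am seeing: I overwrite part of a Parameter inside torch.no_grad(), and the very next optimizer.step() moves those entries away from the values I assigned, because the optimizer (Adam here) still updates them from their gradients and its stored momentum. I am not certain this is exactly the same mechanism as in my trainer, but it shows that a manual in-place edit does not persist while the optimizer keeps stepping the parameter:

import torch

# Hypothetical toy example, not my real model.
param = torch.nn.Parameter(torch.randn(4), requires_grad=True)
optimizer = torch.optim.Adam([param], lr=0.1)

# A few normal training steps so the optimizer accumulates momentum.
for step in range(3):
    optimizer.zero_grad()
    loss = (param ** 2).sum()
    loss.backward()
    optimizer.step()

# Manually "prune" two entries, the way update_norm_adj edits adj_mask_train.
with torch.no_grad():
    param.data[:2] = 0.0
print("after manual edit:           ", param.data)

# One more training step: the pruned entries are updated again
# and no longer keep the values I assigned.
optimizer.zero_grad()
loss = (param ** 2).sum()
loss.backward()
optimizer.step()
print("after next optimizer.step(): ", param.data)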
My problem may be similar to Set nn.Parameter during training, but I couldn't find an answer there. It seems that the last update added by that questioner points to the possible cause of this problem?