Initializing weights before an SGD update

This is a niche question given my esoteric training strategy. I'm wondering if there is a straightforward way to initialize a parameter before every SGD update when you already have a really good guess of what the parameter should be.

To briefly provide some context, I have an alternating minimization scheme where, for a given batch, I apply multiple SGD updates alternating between two optimizers, as follows:

    # see pytorch-lightning for module organization
    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       second_order_closure, on_tpu=False,
                       using_native_amp=False, using_lbfgs=False):
        # perform multiple steps with SGD to optimize eta
        if optimizer_i == 0:
            for _ in range(self.hparams.steps_per_batch):
                loss = second_order_closure()
                optimizer.step(second_order_closure)
                optimizer.zero_grad()

        # update all of the other parameters once
        # eta is optimized
        if optimizer_i == 1:
            for _ in range(self.hparams.steps_per_batch):
                loss = second_order_closure()
                optimizer.step(second_order_closure)
                optimizer.zero_grad()

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.model.train()
        counts = batch
        self.model.reset(batch)
        loss = self.model(counts)
        assert torch.isnan(loss).item() is False
        if len(self.trainer.lr_schedulers) >= 1:
            lr = self.trainer.lr_schedulers[0]['scheduler'].get_last_lr()[0]
            current_lr = lr
        else:
            current_lr = self.hparams.learning_rate
        tensorboard_logs = {
            'train_loss': loss, 'elbo': -loss, 'lr': current_lr
        }
        # log the learning rate
        return {'loss': loss, 'log': tensorboard_logs}

Here's the actual question: I have a function that can provide a good guess of what the parameter should be. All I want to do is initialize this parameter with that guess so that I can refine the estimate with subsequent SGD updates. I'm trying to do this as follows:

    def reset(self, x): 
        hx = my_awesome_guess(x)
        self.eta.data.copy_(hx)

However, I'm noticing that this sort of operation essentially freezes eta, and I can no longer apply SGD updates to it.

I would think this initialization strategy should be possible, right?

Could you explain what “freezes” means in this context? Is this parameter not getting any valid gradients?
Since the usage of the .data attribute is not recommended, you could wrap the code in a with torch.no_grad() block and remove the .data part. However, this might not get rid of the “freezing” issue.
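For example, something along these lines (just a sketch, assuming my_awesome_guess returns a tensor with the same shape as eta):

    def reset(self, x):
        hx = my_awesome_guess(x)
        with torch.no_grad():
            # in-place copy; eta remains the same leaf Parameter
            self.eta.copy_(hx)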

Thanks @ptrblck for following up. Freezing here means that there are zero gradient updates on eta whenever optimizer.step() is called.

The training runs fine, and I am able to alter the values of eta for each batch. However, my per-batch initialization scheme seems to disable any gradient descent steps on eta. I have tried using a torch.no_grad() block, but that appears to freeze eta as well.

UPDATE 1: Here is a better illustration. If I put print statements inside the optimizer loops, I get the following output:

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       second_order_closure, on_tpu=False,
                       using_native_amp=False, using_lbfgs=False):
        # perform multiple steps with SGD to optimize eta
        if optimizer_i == 0:
            for _ in range(self.hparams.steps_per_batch):
                loss = second_order_closure()
                print('current_epoch', current_epoch,
                      'batch', batch_nb, 'optimizer', optimizer_i, loss)
                optimizer.step(second_order_closure)
                optimizer.zero_grad()

        # update all of the other parameters once
        # eta is optimized
        if optimizer_i == 1:
            for _ in range(self.hparams.steps_per_batch):
                loss = second_order_closure()
                print('current_epoch', current_epoch,
                      'batch', batch_nb, 'optimizer', optimizer_i, loss)
                optimizer.step(second_order_closure)
                optimizer.zero_grad()

        loss_ = second_order_closure().item()
        self.logger.experiment.add_scalar(
            'train_loss', loss_, self.global_step)

Output (for 10 steps per batch per optimizer)

    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 0 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.6508, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.6360, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.6211, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.6053, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.5885, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.5705, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.5514, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.5310, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.5092, device='cuda:0')
    current_epoch 0 batch 0 optimizer 1 tensor(113.4860, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')
    current_epoch 0 batch 1 optimizer 0 tensor(113.5134, device='cuda:0')

Part of my confusion stems from my lack of understanding of PyTorch's internals.
How does initialization work? It seems like, if I can initialize my weights before training, there shouldn't be any major obstacle preventing me from re-initializing my weights midway through a run (and ensuring that my parameters are still differentiable).
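For what it's worth, a minimal standalone sketch (no Lightning involved) suggests that re-initializing a parameter in place between steps does not, by itself, stop SGD from updating it:

    import torch

    # toy check: re-initialize a parameter in place between SGD steps
    eta = torch.nn.Parameter(torch.zeros(3))
    optimizer = torch.optim.SGD([eta], lr=0.1)

    for step in range(3):
        with torch.no_grad():
            eta.copy_(torch.randn(3))   # "good guess" re-initialization
        optimizer.zero_grad()
        loss = (eta ** 2).sum()         # dummy objective
        loss.backward()
        optimizer.step()                # eta still moves after the reset
        print(step, eta)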

UPDATE 2: It turns out that gradients are being calculated for eta even after I reset it, but the values of eta are still not changing for some reason.

UPDATE 3: OK, this is weird. If I print out the eta values / gradients before and after the optimizer step, it looks like eta is actually being updated, but for some reason it gets reverted to its previous state in the following step. See the output below.

        if optimizer_i == 0:
            for _ in range(self.hparams.steps_per_batch):
                loss = second_order_closure()
                print('current_epoch', current_epoch,
                      'batch', batch_nb, 'optimizer', optimizer_i, loss.detach())
                print('eta', (self.model.eta.data**2).detach().mean(),
                      (self.model.eta.grad.data**2).detach().mean(), self.model.eta.requires_grad)
                optimizer.step(second_order_closure)
                print('eta', (self.model.eta.data**2).detach().mean(),
                      (self.model.eta.grad.data**2).detach().mean(), self.model.eta.requires_grad)

Output

    current_epoch 0 batch 0 optimizer 0 tensor(114.1528, device='cuda:0')
    eta tensor(0.4546, device='cuda:0') tensor(0.0002, device='cuda:0') True
    eta tensor(0.6156, device='cuda:0') tensor(0.0004, device='cuda:0') True
    current_epoch 0 batch 0 optimizer 0 tensor(114.1528, device='cuda:0')
    eta tensor(0.4546, device='cuda:0') tensor(0.0007, device='cuda:0') True
    eta tensor(0.6089, device='cuda:0') tensor(0.0011, device='cuda:0') True

I'm beginning to wonder if there is some hidden behavior in the closure function in pytorch_lightning. I'm going to cross-post to their issue tracker just to make sure.

Final UPDATE: I think I was able to fix the problem. It boiled down to better understanding the pytorch-lightning semantics.

First, the second_order_closure function is defined by wrapping the training_step and the backward step (see this line: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_loop.py#L597). Since my training_step was calling self.model.reset(batch), every evaluation of the closure re-ran the reset and overwrote eta with the fresh guess, undoing the update that had just been applied (see the sketch below).
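Roughly, this is what the closure was effectively doing on each call (just a sketch of the idea, not the actual Lightning code):

    def second_order_closure():
        # training_step() calls self.model.reset(batch), so eta is
        # overwritten with the fresh guess on every closure evaluation
        output = self.training_step(batch, batch_idx, optimizer_idx)
        loss = output['loss']
        loss.backward()
        return loss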

Second, in order to perform an initialization step before the forward pass is evaluated, the on_train_batch_start callback needs to be defined (see the callback docs here: https://pytorch-lightning.readthedocs.io/en/latest/callbacks.html). My callback looked like this:

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        counts = batch.to(self.device)
        self.model.reset(counts)

Finally, this may seem like blasphemy (based on notes regarding in-place operations that I still don't fully understand), but I just manipulated the data in eta directly within my reset function, defined as follows:

    def reset(self, x):
        hx = my_awesome_guess(x)
        self.eta.data = hx.data

Not sure how kosher it is, but from my tests, it looks like all of this works.
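For anyone worried about the .data assignment: a quick toy check (not my actual model) suggests that rebinding a parameter's .data does not detach it from its optimizer, so subsequent SGD steps still update it:

    import torch

    # toy check: rebinding .data keeps the parameter attached to the optimizer
    eta = torch.nn.Parameter(torch.zeros(3))
    optimizer = torch.optim.SGD([eta], lr=0.1)

    eta.data = torch.randn(3)        # same trick as in reset() above
    optimizer.zero_grad()
    loss = (eta ** 2).sum()
    loss.backward()
    optimizer.step()
    print(eta)                       # values have moved away from the guess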