Update different weights in the model for different prediction results

Suppose a model has two parts, part A and part B.

What I want to achieve is:

  • If the model prediction is a, then update the weights of part A in the model.

  • If the model prediction is b, then update the weights of part B in the model.

My question is:

  1. How can this be implemented in PyTorch? (Especially with batch predictions.)
  2. From the perspective of deep learning, is such an update scheme practical?

Do you mean something like this?

import torch
import torch.nn as nn

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.A = nn.ModuleDict([['lin', nn.Linear(3, 3)], 
                             ['dropout', nn.Dropout(0.5)]])
    self.B = nn.Linear(3, 3)
  def forward(self, x):
    out = self.B(x)
    out = self.A['lin'](out)
    out = self.A['dropout'](out)
    return out

net = Net()
optimizer1 = torch.optim.SGD(net.A.parameters(), lr=0.01)
optimizer2 = torch.optim.SGD(net.B.parameters(), lr=0.01)
for i in range(2):
  optimizer1.zero_grad()
  optimizer2.zero_grad()
  input = torch.randn(3, 3)
  loss = (net(input)).sum()
  loss.backward()
  if loss < 1: # replace with net(input) == prediction
    optimizer1.step() # updates only part A's parameters
  else:
    optimizer2.step() # updates only part B's parameters

Or with a single optimizer:

optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for i in range(2):
  optimizer.zero_grad()
  input = torch.randn(3, 3)
  loss = (net(input)).sum()
  print(loss)
  loss.backward()
  # at this point gradients for both parts are already populated, so toggling
  # requires_grad now would only take effect on the next backward pass;
  # instead, drop the gradients of the part we do not want to update
  if loss < 1: # replace with net(input) == prediction
    for p in net.B.parameters(): # prediction "a": update part A only
      p.grad = None
  else:
    for p in net.A.parameters(): # prediction "b": update part B only
      p.grad = None
  # print(list(net.parameters()))
  optimizer.step()

Thanks for your reply! The code you wrote is very helpful.

But I have a question: will the above code still run correctly with batch gradient descent (BGD) or mini-batch gradient descent (MBGD)?

Within a batch, some samples will be predicted as a and others as b.
I don’t think optimizer.step() can run correctly in that case.

But if we make predictions for a whole batch, then optimizer.step() updates the parameters only once. For example, if our predictions were something like [a, b, a, a, b] and we then call optimizer.step(), the parameters of our neural network are updated once. If we want an update after every prediction the model makes, we have to call optimizer.step() after every prediction: the model predicts [a], we call optimizer.step() on some parameters; the model predicts [b], we call optimizer.step() on a different set of parameters.
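Concretely, a minimal sketch of that per-prediction stepping could look like this, reusing net, optimizer1 (part A) and optimizer2 (part B) from the first snippet above; the argmax check is only a stand-in for whatever rule separates prediction a from prediction b:

batch = torch.randn(5, 3) # pretend mini-batch of 5 samples
for sample in batch: # one optimizer.step() per prediction
  optimizer1.zero_grad()
  optimizer2.zero_grad()
  out = net(sample.unsqueeze(0)) # forward a single sample
  loss = out.sum()
  loss.backward()
  if out.argmax() == 0: # stand-in for "prediction == a"
    optimizer1.step() # update only part A
  else:
    optimizer2.step() # update only part B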

So… does that mean ‘update different weights in the model for different prediction results’ can only be implemented by using SGD as the optimizer (one sample per step)? Since a batch of predictions like [a, b, a, a, b] would have to update different parts of the model (part A or part B) for its different elements.

Is my understanding right?

Let us consider an example. Suppose we have a last-word prediction task:

# I ate an apple
# 0  1  2   3

# I went to park
# 0  4   5   6

# I slept all day
# 0   7    8   9

So our input to the model will look like this:

input = torch.LongTensor([[0, 1, 2],
                          [0, 4, 5],
                          [0, 7, 8]])

and the target will look like this:

target = torch.LongTensor([[3],
                           [6],
                           [9]])
# means if model sees 'I', 'ate', 'an' then it should predict 'apple'
# for 'I', 'went', 'to' predict 'park' and so on

Now we create our dataset:

tensor_dataset = torch.utils.data.TensorDataset(input, target)
list(tensor_dataset)

[(tensor([0, 1, 2]), tensor([3])),
(tensor([0, 4, 5]), tensor([6])),
(tensor([0, 7, 8]), tensor([9]))]

Let us use a batch size of 2, so the first batch will contain the first two sentences and the second batch will contain the third sentence.

dataset = torch.utils.data.DataLoader(tensor_dataset, batch_size=2)

This is what our batches look like:

for i, (input, target) in enumerate(dataset):
  print('i', i, '\ninput', input, '\ntarget', target)
i 0 
input tensor([[0, 1, 2],
        [0, 4, 5]]) 
target tensor([[3],
        [6]])
i 1 
input tensor([[0, 7, 8]]) 
target tensor([[9]])

Now we create our model:

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.embed = nn.Embedding(10, 20, sparse=True) # our vocabulary has 10 words, we set the embedding size to 20
    # sparse=True makes embed.weight.grad a sparse tensor that only holds rows for the words that appear in the batch
    self.transformer_encoder_layer = nn.TransformerEncoderLayer(20, 2)
    self.transformer_encoder = nn.TransformerEncoder(self.transformer_encoder_layer, 1)
    self.lin = nn.Linear(3*20, 10) # maps the flattened sequence to logits over all 10 words
    self.softmax = nn.Softmax(dim=-1) # just to print probabilities
  def forward(self, input):
    embedded_input = self.embed(input)
    print(input.size(0), input.size(1))
    print('embedded_input.shape', embedded_input.shape)
    out = self.transformer_encoder(embedded_input)
    out = out.reshape(out.size(0), out.size(1)*out.size(2))
    out = self.lin(out)
    print('out.shape', out.shape)
    print(self.softmax(out))
    return out

Now we do the training:

net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
loss_plot = []

for num_epochs in range(10):
  for i, (input, target) in enumerate(dataset):
    optimizer.zero_grad()
    loss = loss_fn(net(input), target.view(-1))
    loss_plot.append(loss.item()) # store the value, not the graph
    loss.backward()
    for name, param in net.named_parameters():
      if 'embed' in name:
        print(name, param.grad)
    # print(list(net.parameters()))
    optimizer.step()

The output of the print statements looks something like this:

2 3
embedded_input.shape torch.Size([2, 3, 20])
out.shape torch.Size([2, 10])
tensor([[0.0136, 0.0095, 0.0165, 0.7450, 0.0182, 0.0222, 0.0389, 0.0225, 0.0256,
         0.0880],
        [0.0246, 0.0211, 0.0162, 0.0714, 0.0065, 0.0115, 0.7557, 0.0202, 0.0166,
         0.0562]], grad_fn=<SoftmaxBackward>)
embed.weight tensor(indices=tensor([[0, 1, 2, 0, 4, 5]]),
       values=tensor([[-1.2932e-02,  1.0104e-03, -3.9477e-03,  3.8568e-03,
                       -9.8572e-03,  1.3827e-03, -1.4141e-02,  7.4199e-03,
                        9.7238e-04, -6.7569e-03,  6.8323e-03,  8.7705e-03,
                       -6.8232e-04,  6.6810e-03,  8.1223e-03, -1.0538e-03,
                       -1.1698e-02,  8.6740e-03,  1.7381e-03, -5.2466e-03],
                      [ 6.2502e-05, -3.4446e-03, -8.7784e-03, -2.4090e-03,
                        3.3907e-03,  1.0026e-04,  4.1565e-03, -4.5490e-03,
                        5.8510e-03, -5.8728e-03, -7.7380e-03, -4.1700e-03,
                       -3.1775e-03,  6.2709e-03,  4.3878e-03,  1.1901e-03,
                       -9.7035e-04,  4.6694e-04,  1.4419e-03,  7.2143e-03],
                      [ 2.1315e-03, -2.1100e-03, -4.7602e-03, -3.3467e-03,
                       -3.1354e-03,  6.1653e-03, -1.1522e-02,  2.0889e-03,
                       -4.1310e-03,  1.0974e-03, -3.4026e-03, -3.2140e-03,
                       -6.5089e-03,  3.2289e-03,  8.5952e-04,  3.0189e-03,
                        5.5930e-03,  7.7494e-04,  4.9238e-03,  7.5434e-03],
                      [ 1.6089e-03, -1.5056e-02,  1.1222e-02,  3.6189e-03,
                        4.0124e-03, -6.9139e-03, -1.7996e-03,  2.3435e-03,
                        2.9756e-03,  4.3685e-03, -1.4790e-02,  2.6396e-03,
                       -4.5901e-03, -4.5700e-03, -5.2482e-03,  3.9613e-03,
                        4.8062e-03, -4.5896e-03,  1.2451e-02, -7.3047e-03],
                      [-7.6837e-03,  7.0184e-03,  5.9852e-03, -2.7696e-03,
                       -1.2842e-02, -2.8121e-03, -3.3888e-03,  7.1801e-04,
                       -1.1498e-02,  1.0984e-03, -1.7498e-03,  3.7961e-03,
                        1.8652e-03,  7.1568e-03, -4.6547e-03,  3.0268e-03,
                        9.9045e-03, -3.4706e-03, -3.0367e-03,  5.5819e-03],
                      [-1.8114e-02, -6.5709e-03,  5.3046e-03,  5.8249e-03,
                        9.4378e-04,  1.2052e-02,  9.1222e-03,  4.8452e-04,
                       -7.5793e-03, -3.2206e-03,  6.3282e-03, -1.7706e-02,
                        1.6900e-02, -2.1886e-03,  6.7194e-04, -4.2445e-03,
                       -3.8267e-04,  7.8978e-03, -9.4713e-03, -7.2790e-03]]),
       size=(10, 20), nnz=6, layout=torch.sparse_coo)
1 3
embedded_input.shape torch.Size([1, 3, 20])
out.shape torch.Size([1, 10])
tensor([[0.0180, 0.0087, 0.0069, 0.0273, 0.0210, 0.0133, 0.0216, 0.0091, 0.0049,
         0.8693]], grad_fn=<SoftmaxBackward>)
embed.weight tensor(indices=tensor([[0, 7, 8]]),
       values=tensor([[ 0.0210,  0.0100, -0.0089, -0.0045,  0.0082, -0.0014,
                        0.0132, -0.0120, -0.0027, -0.0058,  0.0039, -0.0031,
                       -0.0098, -0.0070, -0.0011,  0.0032, -0.0043,  0.0151,
                        0.0008, -0.0038],
                      [ 0.0050,  0.0010, -0.0012, -0.0027, -0.0032,  0.0006,
                        0.0040, -0.0026, -0.0033,  0.0074, -0.0001, -0.0050,
                        0.0011,  0.0009,  0.0014,  0.0017, -0.0023, -0.0074,
                        0.0033,  0.0004],
                      [ 0.0003,  0.0092,  0.0041,  0.0042,  0.0023, -0.0032,
                        0.0008, -0.0046, -0.0023,  0.0003, -0.0032,  0.0084,
                       -0.0021, -0.0017, -0.0119,  0.0003, -0.0076, -0.0019,
                        0.0012,  0.0016]]),
       size=(10, 20), nnz=3, layout=torch.sparse_coo)

Here we update the embeddings of only the words that are actually passed to the model, not of all words. With a batch size of 2, when we pass

I ate an 
I went to

the embeddings of these words get updated, while the embeddings of ‘slept’, ‘all’, ‘day’, ‘park’ and ‘apple’ do not get updated.

When we pass

I slept all

as input, the embeddings of ‘ate’, ‘an’, ‘went’, ‘to’, ‘day’, ‘park’ and ‘apple’ do not get updated.
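As a quick check (not in the original post, just a sketch reusing net, optimizer, dataset and loss_fn from above), we can snapshot the embedding matrix before a step and compare it afterwards, to confirm that only the rows of the words in the current batch change:

before = net.embed.weight.detach().clone() # snapshot before the update

optimizer.zero_grad()
input, target = next(iter(dataset)) # take the first batch
loss = loss_fn(net(input), target.view(-1))
loss.backward()
optimizer.step()

after = net.embed.weight.detach()
changed_rows = (before != after).any(dim=1).nonzero().flatten()
print(changed_rows) # expected: only the word ids that appear in input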

If we want to update different embeddings based on what the model predicted, that is, if the model predicts

I ate an park

then we want to update the embeddings of one set of words, and if the model predicted

I ate an apple

then we want to update the embeddings of a different set of words. For that, we would have to change this sparse tensor’s requires_grad for some indices after the model makes its prediction, which I currently do not know how to do.
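One possible workaround, sketched below under the assumption that the embedding is created with sparse=False so that net.embed.weight.grad is an ordinary dense tensor (and with a purely hypothetical freeze rule): instead of toggling requires_grad per index, zero out the gradient rows of the words that should not be updated, after loss.backward() and before optimizer.step().

optimizer.zero_grad()
input, target = next(iter(dataset))
output = net(input)
loss = loss_fn(output, target.view(-1))
loss.backward()

# hypothetical rule: freeze the embeddings of whatever words the model predicted
freeze_ids = output.argmax(dim=-1) # e.g. tensor([3, 6])

net.embed.weight.grad[freeze_ids] = 0.0 # drop the gradients of those rows
optimizer.step() # only the remaining rows are updated

With sparse=True the gradient is a sparse tensor, so the unwanted rows would have to be filtered out of its values instead, which is a bit more involved.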


Thank you so much!