Let us consider an example. Suppose we have a last-word prediction task:

```
# I  ate    an   apple
# 0  1      2    3
# I  went   to   park
# 0  4      5    6
# I  slept  all  day
# 0  7      8    9
```
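Here, the number under each word is its token index. Concretely, this assumes a word-to-index mapping along these lines (the `vocab` dict below is just for illustration; it is not used by the rest of the code):

```
vocab = {'I': 0, 'ate': 1, 'an': 2, 'apple': 3,
         'went': 4, 'to': 5, 'park': 6,
         'slept': 7, 'all': 8, 'day': 9}
```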

So, the input to our model will look like this:

```
import torch

input = torch.LongTensor([[0, 1, 2],
                          [0, 4, 5],
                          [0, 7, 8]])
```

and the target would look like this:

```
target = torch.LongTensor([[3],
                           [6],
                           [9]])
# i.e. if the model sees 'I', 'ate', 'an' it should predict 'apple',
# for 'I', 'went', 'to' it should predict 'park', and so on
```

Now we create our dataset:

```
tensor_dataset = torch.utils.data.TensorDataset(input, target)
list(tensor_dataset)
```

```
[(tensor([0, 1, 2]), tensor([3])),
 (tensor([0, 4, 5]), tensor([6])),
 (tensor([0, 7, 8]), tensor([9]))]
```

Let us consider a batch size of 2: the first batch will contain the first two sentences, and the second batch will contain the third sentence.

```
dataloader = torch.utils.data.DataLoader(tensor_dataset, batch_size=2)
```

This is what our batches look like:

```
for i, (input, target) in enumerate(dataloader):
    print('i', i, '\ninput', input, '\ntarget', target)
```

```
i 0
input tensor([[0, 1, 2],
              [0, 4, 5]])
target tensor([[3],
               [6]])
i 1
input tensor([[0, 7, 8]])
target tensor([[9]])
```
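Note that the second batch has only one example, since our 3 examples do not divide evenly into batches of 2. If we needed every batch to have exactly the same size, DataLoader's standard `drop_last` flag (not used in this post) would discard the short final batch:

```
# optional: drop the final incomplete batch so every batch has exactly 2 examples
dataloader = torch.utils.data.DataLoader(tensor_dataset, batch_size=2, drop_last=True)
```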

Now, we create our model:

```
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # our vocabulary has 10 words; we set the embedding size to 20
        # sparse=True makes the gradient of this weight a sparse tensor, so that
        # embeddings of words that are not in a batch do not get updated
        self.embed = nn.Embedding(10, 20, sparse=True)
        # note: by default nn.TransformerEncoder expects (seq_len, batch, d_model),
        # so batch and sequence are swapped here; this does not affect the
        # sparse-gradient behaviour we are demonstrating
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(20, 2)
        self.transformer_encoder = nn.TransformerEncoder(self.transformer_encoder_layer, 1)
        self.lin = nn.Linear(3 * 20, 10)  # one score for each of the 10 words
        self.softmax = nn.Softmax(dim=-1)  # only used to print probabilities

    def forward(self, input):
        embedded_input = self.embed(input)
        print(input.size(0), input.size(1))
        print('embedded_input.shape', embedded_input.shape)
        out = self.transformer_encoder(embedded_input)
        # flatten the 3 contextualised word vectors into a single feature vector
        out = out.reshape(out.size(0), out.size(1) * out.size(2))
        out = self.lin(out)
        print('out.shape', out.shape)
        print(self.softmax(out))
        return out  # return raw logits; CrossEntropyLoss applies log-softmax itself
```

Now, we run our training loop:

```
net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  # SGD can handle sparse gradients
loss_fn = nn.CrossEntropyLoss()
loss_plot = []
for epoch in range(10):
    for i, (input, target) in enumerate(dataloader):
        optimizer.zero_grad()
        loss = loss_fn(net(input), target.view(-1))
        loss_plot.append(loss.item())  # store a plain float, not the graph-holding tensor
        loss.backward()
        # inspect the gradient of the embedding matrix: it is a sparse tensor
        for name, param in net.named_parameters():
            if 'embed' in name:
                print(name, param.grad)
        # print(list(net.parameters()))
        optimizer.step()
```
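As an aside, only a few optimizers can consume sparse gradients: SGD (used here) works, and so do Adagrad and SparseAdam, while a plain Adam step fails on sparse gradients. If we wanted an Adam-style update anyway, one option (a sketch of my own, not part of the original code) is to split the parameters into a sparse group and a dense group:

```
# sketch: SparseAdam handles only the sparse embedding gradient, while a
# regular optimizer covers the dense transformer/linear parameters
sparse_params = [p for n, p in net.named_parameters() if 'embed' in n]
dense_params = [p for n, p in net.named_parameters() if 'embed' not in n]
optimizer_sparse = torch.optim.SparseAdam(sparse_params, lr=0.01)
optimizer_dense = torch.optim.Adam(dense_params, lr=0.01)
# the training loop would then call zero_grad() and step() on both optimizers
```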

The output of the print statements in the training loop looks something like this:

```
2 3
embedded_input.shape torch.Size([2, 3, 20])
out.shape torch.Size([2, 10])
tensor([[0.0136, 0.0095, 0.0165, 0.7450, 0.0182, 0.0222, 0.0389, 0.0225, 0.0256,
         0.0880],
        [0.0246, 0.0211, 0.0162, 0.0714, 0.0065, 0.0115, 0.7557, 0.0202, 0.0166,
         0.0562]], grad_fn=<SoftmaxBackward>)
embed.weight tensor(indices=tensor([[0, 1, 2, 0, 4, 5]]),
       values=tensor([[-1.2932e-02,  1.0104e-03, -3.9477e-03,  3.8568e-03,
                       -9.8572e-03,  1.3827e-03, -1.4141e-02,  7.4199e-03,
                        9.7238e-04, -6.7569e-03,  6.8323e-03,  8.7705e-03,
                       -6.8232e-04,  6.6810e-03,  8.1223e-03, -1.0538e-03,
                       -1.1698e-02,  8.6740e-03,  1.7381e-03, -5.2466e-03],
                      [ 6.2502e-05, -3.4446e-03, -8.7784e-03, -2.4090e-03,
                        3.3907e-03,  1.0026e-04,  4.1565e-03, -4.5490e-03,
                        5.8510e-03, -5.8728e-03, -7.7380e-03, -4.1700e-03,
                       -3.1775e-03,  6.2709e-03,  4.3878e-03,  1.1901e-03,
                       -9.7035e-04,  4.6694e-04,  1.4419e-03,  7.2143e-03],
                      [ 2.1315e-03, -2.1100e-03, -4.7602e-03, -3.3467e-03,
                       -3.1354e-03,  6.1653e-03, -1.1522e-02,  2.0889e-03,
                       -4.1310e-03,  1.0974e-03, -3.4026e-03, -3.2140e-03,
                       -6.5089e-03,  3.2289e-03,  8.5952e-04,  3.0189e-03,
                        5.5930e-03,  7.7494e-04,  4.9238e-03,  7.5434e-03],
                      [ 1.6089e-03, -1.5056e-02,  1.1222e-02,  3.6189e-03,
                        4.0124e-03, -6.9139e-03, -1.7996e-03,  2.3435e-03,
                        2.9756e-03,  4.3685e-03, -1.4790e-02,  2.6396e-03,
                       -4.5901e-03, -4.5700e-03, -5.2482e-03,  3.9613e-03,
                        4.8062e-03, -4.5896e-03,  1.2451e-02, -7.3047e-03],
                      [-7.6837e-03,  7.0184e-03,  5.9852e-03, -2.7696e-03,
                       -1.2842e-02, -2.8121e-03, -3.3888e-03,  7.1801e-04,
                       -1.1498e-02,  1.0984e-03, -1.7498e-03,  3.7961e-03,
                        1.8652e-03,  7.1568e-03, -4.6547e-03,  3.0268e-03,
                        9.9045e-03, -3.4706e-03, -3.0367e-03,  5.5819e-03],
                      [-1.8114e-02, -6.5709e-03,  5.3046e-03,  5.8249e-03,
                        9.4378e-04,  1.2052e-02,  9.1222e-03,  4.8452e-04,
                       -7.5793e-03, -3.2206e-03,  6.3282e-03, -1.7706e-02,
                        1.6900e-02, -2.1886e-03,  6.7194e-04, -4.2445e-03,
                       -3.8267e-04,  7.8978e-03, -9.4713e-03, -7.2790e-03]]),
       size=(10, 20), nnz=6, layout=torch.sparse_coo)
1 3
embedded_input.shape torch.Size([1, 3, 20])
out.shape torch.Size([1, 10])
tensor([[0.0180, 0.0087, 0.0069, 0.0273, 0.0210, 0.0133, 0.0216, 0.0091, 0.0049,
         0.8693]], grad_fn=<SoftmaxBackward>)
embed.weight tensor(indices=tensor([[0, 7, 8]]),
       values=tensor([[ 0.0210,  0.0100, -0.0089, -0.0045,  0.0082, -0.0014,
                        0.0132, -0.0120, -0.0027, -0.0058,  0.0039, -0.0031,
                       -0.0098, -0.0070, -0.0011,  0.0032, -0.0043,  0.0151,
                        0.0008, -0.0038],
                      [ 0.0050,  0.0010, -0.0012, -0.0027, -0.0032,  0.0006,
                        0.0040, -0.0026, -0.0033,  0.0074, -0.0001, -0.0050,
                        0.0011,  0.0009,  0.0014,  0.0017, -0.0023, -0.0074,
                        0.0033,  0.0004],
                      [ 0.0003,  0.0092,  0.0041,  0.0042,  0.0023, -0.0032,
                        0.0008, -0.0046, -0.0023,  0.0003, -0.0032,  0.0084,
                       -0.0021, -0.0017, -0.0119,  0.0003, -0.0076, -0.0019,
                        0.0012,  0.0016]]),
       size=(10, 20), nnz=3, layout=torch.sparse_coo)
```

Here, we update the embeddings of only those words that are passed to the model, not all of them. We have a batch size of 2, so when we pass

```
I ate an
I went to
```

the embeddings of these words get updated, while the embeddings of ‘slept’, ‘all’, ‘day’, ‘park’, ‘apple’ do not.

When we pass

```
I slept all
```

as input, the embeddings of ‘ate’, ‘an’, ‘went’, ‘to’, ‘day’, ‘park’, ‘apple’ do not get updated.
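We can verify this with a small check (my own sketch, not in the original post): snapshot the embedding matrix, run one update on the third sentence, and see which rows changed:

```
# clone the embedding weights, take one optimizer step on 'I slept all'
# (indices 0, 7, 8), then compare the matrix row by row
before = net.embed.weight.detach().clone()
optimizer.zero_grad()
loss = loss_fn(net(torch.LongTensor([[0, 7, 8]])), torch.LongTensor([9]))
loss.backward()
optimizer.step()
changed = (net.embed.weight.detach() != before).any(dim=1)
print(changed)  # expected: True only at rows 0, 7, 8
```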

If we wanted to update different embeddings based on what the model predicted, that is, if the model predicts

```
I ate an park
```

then we would want to update the embeddings of one set of words, and if the model predicted

```
I ate an apple
```

then we would want to update the embeddings of a different set of words. To do that, we would have to change requires_grad for only some indices of this sparse tensor after the model's prediction, which currently I do not know how to do.
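For what it is worth, requires_grad in PyTorch applies to a whole tensor, not to individual rows, so it cannot be toggled per index. One possible workaround (a rough sketch of my own, not a tested recipe) is to filter the sparse gradient between backward() and step(), keeping only the rows of the words we decide to update after seeing the prediction:

```
# hypothetical: after loss.backward(), keep gradient rows only for the words
# in rows_to_update (a set chosen by looking at the model's prediction)
rows_to_update = {0, 7}  # placeholder choice
grad = net.embed.weight.grad.coalesce()  # sparse COO gradient
indices, values = grad.indices(), grad.values()
keep = torch.tensor([i.item() in rows_to_update for i in indices[0]])
net.embed.weight.grad = torch.sparse_coo_tensor(
    indices[:, keep], values[keep], grad.size())
optimizer.step()  # only the kept rows get updated now
```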