tf.scatter_update in PyTorch

I am wondering if anyone can point me to the equivalent of the tf.scatter_update(data, indices, batch) function in PyTorch?

Based on the docs, I think torch.scatter should work, unless I misunderstand what the _update part means in TF1.

The scatter_update function is defined here in the TensorFlow repository. How do you think each of the inputs of scatter_update would map to the inputs of the torch.scatter function?

I would assume ref, indices, updates correspond to input, (dim,) index, src in torch.scatter, where in PyTorch you would additionally specify the dimension via dim.
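
For example, a rough sketch of that mapping (the shapes here are made up for illustration; note that torch.scatter expects index to have the same number of dimensions as the input, so each row index is repeated once per column):

    import torch

    # Assumed mapping: ref -> input, indices -> index, updates -> src
    data = torch.zeros(4, 3)               # plays the role of ref
    src = torch.tensor([[1., 2., 3.],
                        [4., 5., 6.]])     # plays the role of updates
    index = torch.tensor([[0, 0, 0],
                          [2, 2, 2]])      # target rows 0 and 2, one entry per column
    out = data.scatter(dim=0, index=index, src=src)
    print(out)
    # tensor([[1., 2., 3.],
    #         [0., 0., 0.],
    #         [4., 5., 6.],
    #         [0., 0., 0.]])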

Do you think dim should be zero, based on the definition of the scatter_update function?

I am getting the following error message when I use torch.scatter as a substitute for tf.scatter_update:

data: tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]), indexes:tensor([0, 1, 2]), ...src: [ 0.42924792 -0.9031867  -0.06928237]

    351         print(f"data: {data}, indexes:{indices}, ...src: {batch[:effective_batch_size]}")
--> 352         data= data.scatter(dim=1,index=indices, src=torch.tensor(batch[:effective_batch_size]))
    353         print(data)
    354         #scatter_update(data, indices, batch[:effective_batch_size])

RuntimeError: Index tensor must have the same number of dimensions as self tensor

Any suggestions?
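
The error is raised because torch.scatter expects the index tensor to have the same number of dimensions as the tensor you scatter into, while your indices is 1-D and data is 2-D. A minimal reproduction with the shapes from your print:

    import torch

    # 1-D index with a 2-D self tensor triggers the dimension mismatch
    data = torch.zeros(10, 3)
    indices = torch.tensor([0, 1, 2])
    src = torch.tensor([0.4292, -0.9032, -0.0693])
    data.scatter(dim=1, index=indices, src=src)
    # RuntimeError: Index tensor must have the same number of dimensions as self tensor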

I don’t fully understand your use case, so could you post the desired result tensor?
If you want to assign src to all rows of data you could directly index it:

data[:, idx] = src

or just expand or repeat src.
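
For example, a small sketch (assuming idx picks the columns and src should be written into those columns of every row; broadcasting handles the expansion):

    import torch

    data = torch.zeros(10, 3)
    idx = torch.tensor([0, 1, 2])
    src = torch.tensor([0.4292, -0.9032, -0.0693])
    data[:, idx] = src  # src broadcasts over all 10 rows
    print(data[0])      # tensor([ 0.4292, -0.9032, -0.0693])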

Here is the original TensorFlow code and the printed outputs:

    for key in transitions._asdict().keys():
      data = getattr(self._data, key)
      batch = getattr(transitions, key)
      tf.print("data is ",data) #added prints here
      tf.print("key is ",key)
      tf.print("indices are ",indices, "effective batch size: ",effective_batch_size)
      tf.print("batch has ", batch)
      tf.scatter_update(data, indices, batch[:effective_batch_size])
      tf.print(" data after update becomes ",data)

The outputs:


data is  [[0 0 0]
 [0 0 0]
 [0 0 0]
 ...
 [0 0 0]
 [0 0 0]
 [0 0 0]]
key is  s1
indices are  [0] effective batch size:  1
batch has  [[-0.989759803 0.142742947 0.788506687]]
 data after update becomes  [[-0.989759803 0.142742947 0.788506687]
 [0 0 0]
 [0 0 0]
 ...
 [0 0 0]
 [0 0 0]
 [0 0 0]]
data is  [[0 0 0]
 [0 0 0]
 [0 0 0]
 ...
 [0 0 0]
 [0 0 0]
 [0 0 0]]
key is  s2
indices are  [0] effective batch size:  1
batch has  [[-0.995679677 0.0928544 1.00487685]]
 data after update becomes  [[-0.995679677 0.0928544 1.00487685]
 [0 0 0]
 [0 0 0]
 ...
 [0 0 0]
 [0 0 0]
 [0 0 0]]
data is  [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]]
key is  a1
indices are  [0] effective batch size:  1
batch has  [[0.72875309]]
 data after update becomes  [[0.72875309]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]]
data is  [[0]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]]
key is  a2
indices are  [0] effective batch size:  1
batch has  [[0.613955]]
 data after update becomes  [[0.613955]
 [0]
 [0]
 ...
 [0]
 [0]
 [0]]
data is  [0 0 0 ... 0 0 0]
key is  discount
indices are  [0] effective batch size:  1
batch has  [1]
 data after update becomes  [1 0 0 ... 0 0 0]
data is  [0 0 0 ... 0 0 0]
key is  reward
indices are  [0] effective batch size:  1
batch has  [-9.05287075]
 data after update becomes  [-9.05287075 0 0 ... 0 0 0]

My code in PyTorch, based on your suggestion:

    for key in transitions._asdict().keys():
        data = getattr(self._data, key)
        batch = getattr(transitions, key)
        print(f"data: {data}, indexes:{indices}, ...src: {batch[:effective_batch_size]}")
        data= data.scatter(dim=1,index=indices, src=torch.tensor(batch[:effective_batch_size]))
        print(data)

I am wondering where I am making a mistake here?

Based on your example it seems you want to use the index in dim0 of data, so you could directly assign batch, as seen here:

data = torch.zeros(7, 3)
batch = torch.tensor([-0.989759803, 0.142742947, 0.788506687])
data[0, :] = batch
print(data)
# tensor([[-0.9898,  0.1427,  0.7885],
#         [ 0.0000,  0.0000,  0.0000],
#         [ 0.0000,  0.0000,  0.0000],
#         [ 0.0000,  0.0000,  0.0000],
#         [ 0.0000,  0.0000,  0.0000],
#         [ 0.0000,  0.0000,  0.0000],
#         [ 0.0000,  0.0000,  0.0000]])
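
If indices can contain more than one row, the same approach should generalize; a sketch assuming indices is a 1-D LongTensor of row positions:

    import torch

    data = torch.zeros(7, 3)
    indices = torch.tensor([0, 2])
    batch = torch.tensor([[-0.9898, 0.1427, 0.7885],
                          [-0.9957, 0.0929, 1.0049]])
    data[indices] = batch  # replaces rows 0 and 2 in-place, mirroring
                           # tf.scatter_update's row-wise semantics
    # equivalently: data.index_copy_(0, indices, batch)
    print(data)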

Let me know if I misunderstand your examples.