I am wondering if anyone can refer me to the equivalent of the `tf.scatter_update(data, indices, batch)` function in PyTorch?

Based on the docs, I think `torch.scatter` should work, unless I misunderstand what the `_update` part means in TF1.

The `scatter_update` function is defined in the TensorFlow repository. How do you think each input of `scatter_update` maps to the inputs of `torch.scatter`?

I would assume `ref, indices, updates` corresponds to `input, (dim,) index, src` in `torch.scatter`, where in PyTorch you additionally specify the dimension via `dim`.

Do you think `dim` should be `0`, based on the definition of the `scatter_update` function?
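A minimal sketch of the argument mapping, assuming `ref` is indexed along its first dimension as `scatter_update` does: in PyTorch the closest counterparts are advanced-index assignment and `Tensor.index_copy_` with `dim=0`, and neither needs an index tensor shaped like `updates` (unlike `torch.scatter`). The tensor sizes and values are made up for illustration.

```python
import torch

# Hypothetical buffer: 5 rows of 3 features, all zeros.
ref = torch.zeros(5, 3)
indices = torch.tensor([0, 3])             # rows to overwrite (dim 0)
updates = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])  # one row per index

# Option 1: advanced indexing along dim 0, written in place.
a = ref.clone()
a[indices] = updates

# Option 2: index_copy_ along dim 0, also in place.
b = ref.clone()
b.index_copy_(0, indices, updates)

print(torch.equal(a, b))
```

Both variants overwrite whole rows, which is why `dim` plays the role of `0` here.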

I am getting the following error message when I use `torch.scatter` as a substitute for `tf.scatter_update`:

```
data: tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]), indexes:tensor([0, 1, 2]), ...src: [ 0.42924792 -0.9031867 -0.06928237]
351 print(f"data: {data}, indexes:{indices}, ...src: {batch[:effective_batch_size]}")
--> 352 data= data.scatter(dim=1,index=indices, src=torch.tensor(batch[:effective_batch_size]))
353 print(data)
354 #scatter_update(data, indices, batch[:effective_batch_size])
RuntimeError: Index tensor must have the same number of dimensions as self tensor
```

Any suggestion?
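The error comes from `torch.scatter` requiring `index` to have the same number of dimensions as `data` (and `src`), while `indices` here is 1-D and `data` is 2-D. A minimal sketch of making `scatter` itself work along `dim=0`, assuming `batch` holds one 3-feature row per index (the printed `src` above is 1-D, so it may need reshaping first). `scatter` also expects `src` to have the same dtype as `data`, so a float64 batch (the default for NumPy arrays) is converted. The values are hypothetical.

```python
import torch

data = torch.zeros(10, 3)
indices = torch.tensor([0, 1, 2])
# Hypothetical float64 batch with one row per index (torch.as_tensor accepts NumPy arrays the same way).
batch = torch.arange(9, dtype=torch.float64).reshape(3, 3)

# Match data's dtype, and broadcast each row index across all columns so that
# index has the same shape as src: index[i][j] == indices[i].
src = torch.as_tensor(batch, dtype=data.dtype)
index = indices.unsqueeze(1).expand(-1, data.size(1))

# With dim=0, data[index[i][j]][j] = src[i][j], i.e. row indices[i] receives src[i].
data = data.scatter(0, index, src)
print(data[:4])
```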

I don't fully understand your use case, so could you post the desired result tensor?

If you want to assign `src` to all rows of `data`, you could directly index it:

```
data[:, idx] = src
```

or just `expand` or `repeat` `src`.
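A small sketch of the indexing and `expand`/`repeat` options, assuming `idx` selects columns and `src` holds one value per selected column; the shapes and values are hypothetical. `expand` returns a view that shares memory with `src`, while `repeat` allocates a copy.

```python
import torch

data = torch.zeros(4, 5)
idx = torch.tensor([1, 3])       # hypothetical column indices
src = torch.tensor([7.0, 9.0])   # one value per selected column

# Writes src into columns idx of every row; src broadcasts over dim 0.
data[:, idx] = src

# Alternatively, build a tensor whose rows are all src.
rows_view = src.expand(4, -1)    # shape (4, 2), view without copying
rows_copy = src.repeat(4, 1)     # shape (4, 2), materialized copy
print(data)
```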

Here are the original TensorFlow lines and their printed output:

```
for key in transitions._asdict().keys():
    data = getattr(self._data, key)
    batch = getattr(transitions, key)
    tf.print("data is ",data) #added prints here
    tf.print("key is ",key)
    tf.print("indices are ",indices, "effective batch size: ",effective_batch_size)
    tf.print("batch has ", batch)
    tf.scatter_update(data, indices, batch[:effective_batch_size])
    tf.print(" data after update becomes ",data)
```

The outputs:

```
data is [[0 0 0]
[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]
[0 0 0]]
key is s1
indices are [0] effective batch size: 1
batch has [[-0.989759803 0.142742947 0.788506687]]
data after update becomes [[-0.989759803 0.142742947 0.788506687]
[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]
[0 0 0]]
data is [[0 0 0]
[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]
[0 0 0]]
key is s2
indices are [0] effective batch size: 1
batch has [[-0.995679677 0.0928544 1.00487685]]
data after update becomes [[-0.995679677 0.0928544 1.00487685]
[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]
[0 0 0]]
data is [[0]
[0]
[0]
...
[0]
[0]
[0]]
key is a1
indices are [0] effective batch size: 1
batch has [[0.72875309]]
data after update becomes [[0.72875309]
[0]
[0]
...
[0]
[0]
[0]]
data is [[0]
[0]
[0]
...
[0]
[0]
[0]]
key is a2
indices are [0] effective batch size: 1
batch has [[0.613955]]
data after update becomes [[0.613955]
[0]
[0]
...
[0]
[0]
[0]]
data is [0 0 0 ... 0 0 0]
key is discount
indices are [0] effective batch size: 1
batch has [1]
data after update becomes [1 0 0 ... 0 0 0]
data is [0 0 0 ... 0 0 0]
key is reward
indices are [0] effective batch size: 1
batch has [-9.05287075]
data after update becomes [-9.05287075 0 0 ... 0 0 0]
```

My PyTorch code based on your suggestion:

```
for key in transitions._asdict().keys():
    data = getattr(self._data, key)
    batch = getattr(transitions, key)
    print(f"data: {data}, indexes:{indices}, ...src: {batch[:effective_batch_size]}")
    data= data.scatter(dim=1,index=indices, src=torch.tensor(batch[:effective_batch_size]))
    print(data)
```

Where am I making a mistake here?

Based on your example, it seems you want to use the indices in `dim0` of `data`, so you could directly assign `batch` as seen here:

```
import torch

data = torch.zeros(7, 3)
batch = torch.tensor([-0.989759803, 0.142742947, 0.788506687])
data[0, :] = batch
print(data)
# tensor([[-0.9898, 0.1427, 0.7885],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000]])
```

Let me know if I misunderstand your examples.
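Applied to the loop posted above, two further details matter besides the index shape: `dim=1` targets columns, while the TF output shows rows (`dim0`) being overwritten, and `data = data.scatter(...)` returns a new tensor and only rebinds the local name `data`, so `self._data` is never updated (whereas `tf.scatter_update` mutates the variable). A minimal sketch, assuming `self._data` and `transitions` are namedtuples whose fields are tensors or array-likes; `Transition`, `storage`, `capacity` and `add` are hypothetical stand-ins, and the values are copied from the TF output.

```python
import collections

import torch

# Hypothetical stand-ins for the poster's namedtuples (field names taken from the TF output).
Transition = collections.namedtuple(
    "Transition", ["s1", "s2", "a1", "a2", "discount", "reward"])

capacity = 10  # assumed buffer size; the TF printout elides it with "..."
storage = Transition(
    s1=torch.zeros(capacity, 3), s2=torch.zeros(capacity, 3),
    a1=torch.zeros(capacity, 1), a2=torch.zeros(capacity, 1),
    discount=torch.zeros(capacity), reward=torch.zeros(capacity),
)


def add(storage, transitions, indices, effective_batch_size):
    """Hypothetical helper mirroring the posted loop, writing rows in place."""
    for key in transitions._asdict().keys():
        data = getattr(storage, key)
        batch = torch.as_tensor(getattr(transitions, key), dtype=data.dtype)
        # In-place write along dim 0: `data` is the same tensor object held by
        # `storage`, so the update persists after the loop.
        data[indices] = batch[:effective_batch_size]


transitions = Transition(
    s1=[[-0.989759803, 0.142742947, 0.788506687]],
    s2=[[-0.995679677, 0.0928544, 1.00487685]],
    a1=[[0.72875309]], a2=[[0.613955]],
    discount=[1.0], reward=[-9.05287075],
)
add(storage, transitions, torch.tensor([0]), effective_batch_size=1)
print(storage.s1[:2])
print(storage.reward[:3])
```

The same `data[indices] = ...` line covers both the 2-D buffers (`s1`, `a1`, ...) and the 1-D ones (`discount`, `reward`), matching what the TF output shows.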