Can anyone point me to the PyTorch equivalent of the tf.scatter_update(data, indices, batch) function?
Based on the docs I think torch.scatter should work, unless I misunderstand what the _update part means in TF1.
The scatter_update function is defined in the TensorFlow repository. How do you think each input of scatter_update maps onto the inputs of torch.scatter?
I would assume ref, indices, updates correspond to input, (dim,) index, src in torch.scatter, where in PyTorch you would additionally specify the dimension via dim.
Do you think that dim should be zero, based on the definition of the scatter_update function?
I am getting the following error message when I use torch.scatter as a substitute for tf.scatter_update:
data: tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]), indexes:tensor([0, 1, 2]), ...src: [ 0.42924792 -0.9031867 -0.06928237]
351 print(f"data: {data}, indexes:{indices}, ...src: {batch[:effective_batch_size]}")
--> 352 data= data.scatter(dim=1,index=indices, src=torch.tensor(batch[:effective_batch_size]))
353 print(data)
354 #scatter_update(data, indices, batch[:effective_batch_size])
RuntimeError: Index tensor must have the same number of dimensions as self tensor
Any suggestion?
I don't fully understand your use case, so could you post the desired result tensor?
If you want to assign src to all rows of data, you could index it directly:
data[:, idx] = src
or just expand or repeat src.
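As a rough sketch of both suggestions, using the shapes from the printed error (a [10, 3] data tensor and 1-D idx and src of length 3), the snippet below compares direct indexing with torch.scatter. The names assigned, index_2d, src_2d and scattered are only illustrative. The RuntimeError most likely comes from passing a 1-D index to a 2-D tensor, so the scatter variant first repeats index and src to two dimensions.

```python
import torch

data = torch.zeros(10, 3)
idx = torch.tensor([0, 1, 2])
src = torch.tensor([0.42924792, -0.9031867, -0.06928237])

# Option 1: direct indexing. src has shape [3] and broadcasts across all 10 rows.
assigned = data.clone()
assigned[:, idx] = src

# Option 2: torch.scatter along dim=1. index and src must have the same number
# of dimensions as data, so the 1-D tensors are repeated to shape [10, 3].
index_2d = idx.unsqueeze(0).repeat(data.size(0), 1)
src_2d = src.unsqueeze(0).repeat(data.size(0), 1)
scattered = data.scatter(dim=1, index=index_2d, src=src_2d)

# Both approaches write src into the selected columns of every row.
print(torch.equal(assigned, scattered))
```

Note that this fills every row, which may not be what tf.scatter_update does in your code, so the desired result tensor would still help.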
Here is the original TensorFlow code and its printed output:
for key in transitions._asdict().keys():
    data = getattr(self._data, key)
    batch = getattr(transitions, key)
    tf.print("data is ", data)  # added prints here
    tf.print("key is ", key)
    tf.print("indices are ", indices, "effective batch size: ", effective_batch_size)
    tf.print("batch has ", batch)
    tf.scatter_update(data, indices, batch[:effective_batch_size])
    tf.print(" data after update becomes ", data)
The outputs:
data is [[0 0 0]
[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]
[0 0 0]]
key is s1
indices are [0] effective batch size: 1
batch has [[-0.989759803 0.142742947 0.788506687]]
data after update becomes [[-0.989759803 0.142742947 0.788506687]
[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]
[0 0 0]]
data is [[0 0 0]
[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]
[0 0 0]]
key is s2
indices are [0] effective batch size: 1
batch has [[-0.995679677 0.0928544 1.00487685]]
data after update becomes [[-0.995679677 0.0928544 1.00487685]
[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]
[0 0 0]]
data is [[0]
[0]
[0]
...
[0]
[0]
[0]]
key is a1
indices are [0] effective batch size: 1
batch has [[0.72875309]]
data after update becomes [[0.72875309]
[0]
[0]
...
[0]
[0]
[0]]
data is [[0]
[0]
[0]
...
[0]
[0]
[0]]
key is a2
indices are [0] effective batch size: 1
batch has [[0.613955]]
data after update becomes [[0.613955]
[0]
[0]
...
[0]
[0]
[0]]
data is [0 0 0 ... 0 0 0]
key is discount
indices are [0] effective batch size: 1
batch has [1]
data after update becomes [1 0 0 ... 0 0 0]
data is [0 0 0 ... 0 0 0]
key is reward
indices are [0] effective batch size: 1
batch has [-9.05287075]
data after update becomes [-9.05287075 0 0 ... 0 0 0]
My code in PyTorch, based on your suggestion:
for key in transitions._asdict().keys():
    data = getattr(self._data, key)
    batch = getattr(transitions, key)
    print(f"data: {data}, indexes:{indices}, ...src: {batch[:effective_batch_size]}")
    data = data.scatter(dim=1, index=indices, src=torch.tensor(batch[:effective_batch_size]))
    print(data)
Where am I making a mistake here?
Based on your example, it seems you want to use the index in dim0 of data and could directly assign batch, as seen here:
import torch

data = torch.zeros(7, 3)
batch = torch.tensor([-0.989759803, 0.142742947, 0.788506687])
# Overwrite row 0 of data with batch
data[0, :] = batch
print(data)
# tensor([[-0.9898, 0.1427, 0.7885],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000]])
Let me know if I misunderstand your examples.
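To tie this back to your loop, here is a minimal sketch of a row-wise update that should reproduce the tf.scatter_update output you printed for both the 2-D buffers (s1) and the 1-D ones (reward). The helper name scatter_update_ is hypothetical, and I am assuming that indices is an integer tensor and that batch arrives as a NumPy array, which is why it is converted to the dtype of data first.

```python
import numpy as np
import torch

def scatter_update_(data, indices, updates):
    """Hypothetical helper: write updates[i] into data[indices[i]] along dim 0, in place."""
    # Assumption: updates may be a NumPy array (often float64), so match data's dtype.
    updates = torch.as_tensor(updates, dtype=data.dtype, device=data.device)
    data[indices] = updates
    return data

indices = torch.tensor([0])

# 2-D buffer, as for the key s1
batch_s1 = np.array([[-0.989759803, 0.142742947, 0.788506687]])
s1 = scatter_update_(torch.zeros(7, 3), indices, batch_s1)

# 1-D buffer, as for the key reward
reward = scatter_update_(torch.zeros(7), indices, np.array([-9.05287075]))

# The same 2-D update with scatter_ along dim=0: the index has to be repeated
# so that it has one entry per element of the update, here shape [1, 3].
index_2d = indices.unsqueeze(1).repeat(1, 3)
via_scatter = torch.zeros(7, 3).scatter_(
    0, index_2d, torch.as_tensor(batch_s1, dtype=torch.float32)
)

print(s1)
print(reward)
print(torch.equal(s1, via_scatter))
```

So dim=0 looks like the right choice if you want to stay with scatter, but plain indexing (or data.index_copy_(0, indices, updates) for unique indices) seems simpler, since it does not need the index to be repeated to the shape of the update.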