Could you please explain how I should use this function with my torch graph data?
def optimize_model():
    """Run one DQN optimization step over a sampled minibatch.

    NOTE(review): this version assumes every Transition field is a plain
    tensor.  If the states are graph objects (e.g. torch_geometric ``Data``),
    ``torch.cat`` cannot concatenate them — presumably this is the failure
    being reported; confirm the element type stored in the replay memory.
    """
    # Not enough experience collected yet — skip this update.
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))
    #print("Batch is here:", batch)
    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    # NOTE(review): torch.cat expects a sequence of tensors; if next_state
    # holds graph Data objects this call raises — verify.
    non_final_next_states = torch.cat(([s for s in batch.next_state if s is not None]))
    # NOTE(review): torch.cat returns a tensor, yet state_batch is later
    # accessed as .x / .edge_index / ... as if it were still a graph object —
    # these two lines are mutually inconsistent.
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    # NOTE(review): .gather(1, action_batch) needs an int64 index tensor of
    # shape (batch, 1) — confirm the shape of the stored actions.
    state_action_values = policy_net(state_batch.x, state_batch.edge_index, state_batch.edge_type, state_batch.edge_attr).gather(1, action_batch)
I also tried to resolve the issue by splitting the graph data into its separate components (please check below), but I still get an AssertionError!
def optimize_model():
if len(memory) < BATCH_SIZE:
return
transitions = memory.sample(BATCH_SIZE)
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
# detailed explanation). This converts batch-array of Transitions
# to Transition of batch-arrays.
batch = Transition(*zip(*transitions))
#print("Batch is here:", batch)
# Compute a mask of non-final states and concatenate the batch elements
# (a final state would've been the one after which simulation ended)
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
non_final_next_states_x_s = torch.cat(([s.x for s in batch.next_state if s is not None]))
non_final_next_states_edge_index_s = torch.cat(([s.edge_index for s in batch.next_state if s is not None]))
non_final_next_states_edge_type_s = torch.cat(([s.edge_type for s in batch.next_state if s is not None]))
non_final_next_states_edge_attr_s = torch.cat(([s.edge_attr for s in batch.next_state if s is not None]))
x_s = torch.cat(([s.x for s in batch.state]))
edge_index_s = torch.cat(([s.edge_index for s in batch.state]))
edge_type_s = torch.cat(([s.edge_type for s in batch.state]))
edge_attr_s = torch.cat(([s.edge_attr for s in batch.state]))
action_batch = torch.tensor(batch.action)
reward_batch = torch.tensor(batch.reward)
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(x_s, edge_index_s, edge_type_s, edge_attr_s).gather(1, action_batch)
Thanks for reporting this. Can I ask you a bit more context, e.g. a minimally reproducible example, the exact error message and what you are trying to achieve?
From a high level perspective you are trying to use a replay buffer that consumes tensors with a data structure that is not a tensor.
If you’d like, torchrl provides replay buffers that do not assume any particular data type. Have a look at the demo for some pointers.