Replay memory for Graph Data! TypeError: expected Tensor as element 0 in argument 0, but got Data

I run into this error when calling optimize_model() in my DQN training loop, where the states are graph Data objects (torch_geometric.data.Data) rather than plain tensors.

Could you please explain how I should use this function with graph data?

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    #print("Batch is here:", batch)
    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)


    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    # This line raises the TypeError: batch.next_state holds Data objects, not tensors
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)


    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch.x, state_batch.edge_index, state_batch.edge_type, state_batch.edge_attr).gather(1, action_batch)
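
For reference, the exact error is raised by torch.cat when it receives Data objects instead of tensors. A minimal illustration (the tiny graph below is made up, only to reproduce the message):

import torch
from torch_geometric.data import Data

# A made-up two-edge graph, just to trigger the error
d = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))

# torch.cat only accepts sequences of tensors, so this fails with:
# TypeError: expected Tensor as element 0 in argument 0, but got Data
torch.cat([d, d])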

I also tried to resolve the issue by splitting the graph data into its components (please see below), but I still get an AssertionError!

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    #print("Batch is here:", batch)
    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)


    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)    

    non_final_next_states_x_s = torch.cat([s.x for s in batch.next_state if s is not None])
    non_final_next_states_edge_index_s = torch.cat([s.edge_index for s in batch.next_state if s is not None])
    non_final_next_states_edge_type_s = torch.cat([s.edge_type for s in batch.next_state if s is not None])
    non_final_next_states_edge_attr_s = torch.cat([s.edge_attr for s in batch.next_state if s is not None])

    x_s = torch.cat([s.x for s in batch.state])
    edge_index_s = torch.cat([s.edge_index for s in batch.state])
    edge_type_s = torch.cat([s.edge_type for s in batch.state])
    edge_attr_s = torch.cat([s.edge_attr for s in batch.state])

    action_batch = torch.tensor(batch.action)
    reward_batch = torch.tensor(batch.reward)
    

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(x_s, edge_index_s, edge_type_s, edge_attr_s).gather(1, action_batch)
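
I am also wondering whether I should batch the Data objects with torch_geometric's Batch.from_data_list instead of concatenating the fields by hand, since as far as I understand it also offsets the node indices inside edge_index when graphs are stacked. A rough sketch of what I mean (it keeps the same policy_net call as above; I am not sure it is the right approach):

from torch_geometric.data import Batch

state_batch = Batch.from_data_list(list(batch.state))
non_final_next_states = Batch.from_data_list([s for s in batch.next_state if s is not None])

# The batched object exposes the same attributes as a single Data object,
# with edge_index already shifted for each graph in the batch
state_action_values = policy_net(state_batch.x, state_batch.edge_index, state_batch.edge_type, state_batch.edge_attr).gather(1, action_batch)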

Thanks for reporting this. Could I ask for a bit more context, e.g. a minimal reproducible example, the exact error message, and what you are trying to achieve?
From a high-level perspective, you are trying to use a replay buffer that expects tensors with a data structure that is not a tensor.
If you'd like, torchrl provides replay buffers that do not assume any particular data type. Have a look at the demo for some pointers.
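
For instance, a rough sketch of what that could look like, using a ListStorage and an identity collate_fn so the sampled transitions come back unchanged (the buffer size and the dummy transitions are only illustrative):

from torchrl.data import ListStorage, ReplayBuffer

# A buffer backed by a plain Python list, so it can store arbitrary objects,
# e.g. (state, action, next_state, reward) tuples whose states are Data objects
buffer = ReplayBuffer(
    storage=ListStorage(max_size=10_000),
    collate_fn=lambda samples: samples,  # return the sampled items as-is
)

for i in range(16):
    buffer.add(("state", i, "next_state", float(i)))  # dummy transitions

transitions = buffer.sample(4)  # a list of 4 stored transitions, untouched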

Happy to help further once we get some more info