Graph Isomorphism Network (GIN) with edge weights

I am trying to build a GIN model for graph classification that uses edge weights via torch_geometric's `GINEConv`, but I keep running into problems with how to feed the weights in. My data is shaped like this:
Data(x=[7, 4], edge_index=[2, 42], edge_attr=[42, 1], y=[1], num_nodes=7)
The graphs are undirected, with 4 node features and 1 edge feature.
I tried making the dimensions of `edge_attr` and `x` match, and I also tried passing `edge_dim` and transforming `edge_attr` before feeding it to each convolution, but the model either doesn't work or performs badly.
Here is my code:

class GINE(torch.nn.Module):
    """Graph Isomorphism Network with edge features (GINE) for graph classification.

    Three stacked ``GINEConv`` layers each produce node embeddings. Each
    layer's embeddings are sum-pooled per graph, the three pooled vectors are
    concatenated, and the result is fed to a two-layer MLP classifier.

    Args:
        num_node_features: Number of input features per node (``x.size(-1)``).
        dim_h: Hidden width of every convolution.
        num_classes: Number of output classes.
        edge_dim: Number of features per edge (``edge_attr.size(-1)``). When
            set, ``GINEConv`` learns a linear projection of the edge features
            to the width of the node features they are added to, so
            ``edge_attr`` must NOT be projected by hand in ``forward``.
        epsilone: Used as the dropout probability of the classifier head.
            NOTE(review): the name suggests the GIN ``eps`` was intended; the
            GIN ``eps`` is already learned per layer via ``train_eps=True``.
    """

    def __init__(self, num_node_features, dim_h, num_classes, edge_dim=1,
                 epsilone=1e-4):
        super().__init__()
        self.conv1 = self._make_conv(num_node_features, dim_h, edge_dim)
        self.conv2 = self._make_conv(dim_h, dim_h, edge_dim)
        self.conv3 = self._make_conv(dim_h, dim_h, edge_dim)
        # The readout concatenates the pooled output of all three layers.
        self.lin1 = Linear(dim_h * 3, dim_h * 3)
        self.lin2 = Linear(dim_h * 3, num_classes)
        self.epsilone = epsilone

    @staticmethod
    def _make_conv(in_dim, dim_h, edge_dim):
        """Build one GINEConv whose MLP maps ``in_dim`` -> ``dim_h`` features."""
        # Keep a Linear as the first module of the MLP: GINEConv infers the
        # node-feature width (the target width of its edge projection) from
        # it when ``edge_dim`` is given.
        mlp = Sequential(
            Linear(in_dim, dim_h), BatchNorm(dim_h), ReLU(),
            Linear(dim_h, dim_h), ReLU(),
        )
        return GINEConv(mlp, edge_dim=edge_dim, train_eps=True)

    def forward(self, x, edge_index, batch, edge_attr):
        """Classify a batch of graphs.

        Args:
            x: Node feature matrix, expected ``[num_nodes, num_node_features]``.
            edge_index: Graph connectivity, ``[2, num_edges]``.
            batch: Graph index of every node (e.g. ``data.batch`` from a
                PyG ``DataLoader``).
            edge_attr: Edge features, ``[num_edges, edge_dim]``. A 1-D
                ``[num_edges]`` weight vector is also accepted.

        Returns:
            Tuple ``(logits, log_probs)``, each ``[num_graphs, num_classes]``.
        """
        # The edge projection inside GINEConv is a Linear layer, which needs a
        # 2-D floating-point input: accept flat and integer weights too.
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        edge_attr = edge_attr.to(self.lin1.weight.dtype)

        # Node embeddings. The raw edge_attr goes to every layer; each
        # GINEConv projects it to its own input width internally.
        h1 = self.conv1(x, edge_index, edge_attr)
        h2 = self.conv2(h1, edge_index, edge_attr)
        h3 = self.conv3(h2, edge_index, edge_attr)

        # Graph-level readout: sum-pool each layer's embeddings, then concatenate.
        h = torch.cat(
            [global_add_pool(z, batch) for z in (h1, h2, h3)], dim=1
        )

        # Classifier head. F.dropout defaults to training=True, so the
        # module's mode is passed explicitly to disable dropout in eval().
        h = self.lin1(h).relu()
        h = F.dropout(h, p=self.epsilone, training=self.training)
        h = self.lin2(h)
        return h, F.log_softmax(h, dim=1)