import torch
from typing import Callable, Union
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size


class GINConv(MessagePassing):
    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
                 activation="softplus", **kwargs):
        super(GINConv, self).__init__(aggr='add', **kwargs)
        self.nn = nn  # MLP applied after aggregation
        self.initial_eps = eps
        ...

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        ...

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        ...

    def __repr__(self):
        ...
As shown in the code above, after exporting this module to ONNX, the inference results of the exported model are inconsistent with those of the original model. I am not sure whether some operators in the PyG package simply do not support ONNX export. The line that appears to cause the problem is: out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
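For reference, this is roughly how I observe the mismatch, by running the same inputs through the PyTorch module and the exported ONNX model (a minimal sketch: the shapes, feature sizes, file name, and constructor arguments are placeholders rather than my exact setup):

import numpy as np
import torch
import onnxruntime as ort

# Placeholder graph: 4 nodes with 16 features, 6 directed edges, 8-dim edge features.
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 3],
                           [1, 0, 2, 1, 2, 0]], dtype=torch.long)
edge_attr = torch.randn(6, 8)

# Constructor arguments are placeholders for my real configuration.
model = GINConv(nn=torch.nn.Linear(16, 16))
model.eval()

with torch.no_grad():
    torch_out = model(x, edge_index, edge_attr)

torch.onnx.export(
    model, (x, edge_index, edge_attr), "gin_conv.onnx",
    input_names=["x", "edge_index", "edge_attr"],
    output_names=["out"],
    opset_version=13,
)

sess = ort.InferenceSession("gin_conv.onnx")
onnx_out = sess.run(None, {"x": x.numpy(),
                           "edge_index": edge_index.numpy(),
                           "edge_attr": edge_attr.numpy()})[0]

# I would expect these to agree within floating-point tolerance, but they do not.
print(np.abs(torch_out.numpy() - onnx_out).max())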
From my investigation, the propagate() call above ends up invoking a function called scatter_add_ to aggregate the messages. Is it possible that this function is not supported when exporting to ONNX?
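To narrow it down, I also tried exporting a tiny standalone module that only performs the scatter-add aggregation step, outside of PyG (again just a sketch: the module name, shapes, and opset version are things I made up for the test, not anything from the real model):

import torch

class ScatterAddOnly(torch.nn.Module):
    # Sums the per-edge messages x_j into their destination nodes,
    # i.e. the aggregation that propagate() performs for aggr='add'.
    def forward(self, x_j: torch.Tensor, target_index: torch.Tensor,
                num_nodes: int = 4) -> torch.Tensor:
        out = torch.zeros(num_nodes, x_j.size(-1), dtype=x_j.dtype)
        index = target_index.view(-1, 1).expand_as(x_j)
        return out.scatter_add_(0, index, x_j)

x_j = torch.randn(6, 16)                        # one message per edge
target_index = torch.tensor([1, 0, 2, 1, 2, 0])  # destination node of each edge

torch.onnx.export(
    ScatterAddOnly(), (x_j, target_index), "scatter_add_only.onnx",
    input_names=["x_j", "target_index"], output_names=["out"],
    opset_version=16,  # a newer opset, in case the operator mapping depends on it
)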
The following is the ONNX graph produced for this part of the model: