Hello. I am trying to export the GraphSAGE model from train_sample.py to ONNX format. I tried using torch.onnx.export, with no success. The GraphSAGE model's forward function (in train_sample.py) takes two inputs, the input graph as a list of DGL blocks and the input features:
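```python
def forward(self, blocks, x):
    # blocks: one DGL block per layer; x: input node features
    h = x
    for l, (layer, block) in enumerate(zip(self.layers, blocks)):
        h = layer(block, h)
        # activation and dropout after every layer except the last
        if l != len(self.layers) - 1:
            h = self.activation(h)
            h = self.dropout(h)
    return h
```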
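Assuming each layer is DGL's SAGEConv with the 'mean' aggregator (an assumption, since the layer definitions from train_sample.py aren't shown here), one layer applied to a block $b$ computes roughly the following. Here $A_b$ is the block's $n_{\text{dst}} \times n_{\text{src}}$ adjacency matrix, $D_b$ is the diagonal matrix of destination in-degrees, and rows of $\hat{A}_b$ for nodes without in-neighbors are taken to be zero:

$$
h_v^{(l+1)} = W_{\text{self}}^{(l)} h_v^{(l)} + W_{\text{neigh}}^{(l)} \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_u^{(l)} + b^{(l)}
$$

$$
H_{\text{dst}}^{(l+1)} = H_{\text{dst}}^{(l)} W_{\text{self}}^{(l)\top} + \hat{A}_b H_{\text{src}}^{(l)} W_{\text{neigh}}^{(l)\top} + b^{(l)}, \qquad \hat{A}_b = D_b^{-1} A_b
$$

Activation and dropout follow for every layer but the last. Written this way, the only graph-dependent input a layer needs is $\hat{A}_b$, which is an ordinary tensor.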
I tried to export the model to ONNX as follows:
```python
torch.onnx.export(model, (test_g, test_g.ndata['features']), "graphsage.onnx", verbose=True, input_names=['input features'], output_names=['output labels'])
```
This results in the following error:
```
RuntimeError: Only tuples, lists and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type DGLHeteroGraph
```
I also tried converting the DGLHeteroGraph with to_networkx(), but that didn't help either. The problem is passing the graph as an input: torch.onnx.export() expects its inputs to be tensors (Variables), tuples or lists, and a DGL graph is none of these. This is NOT a DGL issue, but I am wondering whether you have any support for simplifying it. Any recommendations on how to export this model to ONNX format? @BarclayII: do you have any tricks up your sleeve? Thanks in advance.
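One possible workaround, sketched under assumptions: the error message refers to JIT inputs, i.e. the tracer used by torch.onnx.export only accepts tensors (or tuples/lists of them), so the graph structure has to be passed as tensors too. The sketch below assumes the layers use mean aggregation and replaces each block with a dense, row-normalized adjacency matrix built outside the model. `dense_mean_adj`, `DenseSAGELayer` and `DenseSAGE` are hypothetical names, not DGL APIs. For a real DGL block, `src`/`dst` could come from `block.edges()` and the sizes from `block.number_of_src_nodes()` / `block.number_of_dst_nodes()`.

```python
import torch
import torch.nn as nn


def dense_mean_adj(src, dst, num_src, num_dst):
    """Hypothetical helper: turn a block's edge list into a row-normalized
    (num_dst, num_src) matrix so that adj @ h averages each destination
    node's in-neighbor features (mean aggregation). Runs outside the model."""
    adj = torch.zeros(num_dst, num_src)
    # count edges per (dst, src) pair; duplicate edges are counted twice
    adj.index_put_((dst, src), torch.ones(src.numel()), accumulate=True)
    # divide by in-degree; nodes without in-neighbors keep an all-zero row
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
    return adj / deg


class DenseSAGELayer(nn.Module):
    """Hypothetical mean-aggregator layer: fc_self(h_dst) + fc_neigh(adj @ h)."""

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.fc_self = nn.Linear(in_feats, out_feats)  # carries the bias
        self.fc_neigh = nn.Linear(in_feats, out_feats, bias=False)

    def forward(self, adj, h):
        # assumes, as in DGL blocks, that the destination nodes are the
        # first adj.shape[0] source nodes
        h_dst = h[: adj.shape[0]]
        return self.fc_self(h_dst) + self.fc_neigh(adj @ h)


class DenseSAGE(nn.Module):
    """Same layer loop as the original forward, but every block is replaced
    by a plain tensor, so tracing (and therefore ONNX export) only sees tensors."""

    def __init__(self, in_feats, n_hidden, n_classes, n_layers, dropout=0.5):
        super().__init__()
        dims = [in_feats] + [n_hidden] * (n_layers - 1) + [n_classes]
        self.layers = nn.ModuleList(
            [DenseSAGELayer(dims[i], dims[i + 1]) for i in range(n_layers)]
        )
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, *adjs):
        h = x
        for l, (layer, adj) in enumerate(zip(self.layers, adjs)):
            h = layer(adj, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h


# Toy 2-layer example with made-up sizes: 6 input nodes -> 4 -> 2 output nodes.
torch.manual_seed(0)
x = torch.randn(6, 5)
adj0 = dense_mean_adj(torch.tensor([0, 1, 4, 5, 2, 3]),
                      torch.tensor([0, 0, 1, 2, 3, 3]), num_src=6, num_dst=4)
adj1 = dense_mean_adj(torch.tensor([1, 2, 3, 0]),
                      torch.tensor([0, 0, 1, 1]), num_src=4, num_dst=2)
example_inputs = (x, adj0, adj1)

model = DenseSAGE(in_feats=5, n_hidden=8, n_classes=3, n_layers=2).eval()
# tracing succeeds here because every input is a tensor
traced = torch.jit.trace(model, example_inputs)


def export_onnx(path="graphsage.onnx"):
    # one input name per tensor in example_inputs
    torch.onnx.export(model, example_inputs, path,
                      input_names=["features", "adj0", "adj1"],
                      output_names=["logits"])
```

Copying the trained SAGEConv weights into these layers is not shown. Dense adjacency matrices also cost memory proportional to num_dst × num_src per block. A scatter-based aggregation would scale better, though its ONNX support may depend on the opset.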