Hi,
I have pasted a code segment below that throws a DGL error if I use the builtin functions, whereas it runs perfectly when I use equivalent functions coded by me in place of the builtins.
Any idea what is going on here? Why do I get an error with the builtin functions? I certainly cannot find any limitation on node feature tensor types in the documentation for the builtin functions.
Thanks!
import numpy as np
import torch
import dgl
from dgl import DGLGraph
import dgl.function as fn
g = dgl.DGLGraph()
g.add_nodes(4)
g.ndata['x'] = torch.DoubleTensor([1.,2.,2.,7.])
g.ndata['y'] = torch.DoubleTensor([0.]*4)
g.add_edges([0, 1, 1, 2], [1, 2, 3, 3])
g.edata["ab"] = torch.DoubleTensor([11., 10., 9., 8.])
def make_e_add_v(e_feat, dst_feat, msg_feat):
    def e_add_v(eb):
        # set_trace()
        return {msg_feat: eb.data[e_feat] + eb.dst[dst_feat]}
    return e_add_v

def make_sum_reduce(msg_feat, n_feat):
    def sum_reduce(nb):
        return {n_feat: nb.mailbox[msg_feat].sum(dim=1)}
    return sum_reduce
The following call, using my own message and reduce functions, works:
g.update_all(message_func=make_e_add_v("ab", "x", "m"), reduce_func=make_sum_reduce("m", "y"))
The next line is theoretically equivalent to the previous one, but it gives an error:
g.update_all(message_func=make_e_add_v("ab", "x", "m"), reduce_func=fn.sum("m", "y"))
# same problem here:
g.update_all(message_func=fn.e_add_v("ab", "x", "m"), reduce_func=fn.sum("m", "y"))
The end of the error:
raise DGLError(py_str(_LIB.DGLGetLastError()))
dgl._ffi.base.DGLError: [14:56:03] /opt/dgl/src/kernel/cpu/../binary_reduce_impl.h:104: Unsupported dtype: _@
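In case it is relevant: all feature tensors above are created as DoubleTensors, i.e. float64. My guess (not something I found in the docs) is that the builtin kernels only accept float32, and the custom Python functions work because they never hit the C++ kernel. Here is the quick sanity check in plain PyTorch (no DGL needed) and the cast I would try as a workaround:

```python
import torch

# The features in the snippet above are DoubleTensors, i.e. float64.
x = torch.DoubleTensor([1., 2., 2., 7.])
print(x.dtype)    # torch.float64

# Hypothetical workaround: cast the features to float32 before calling
# update_all with the builtin functions (guess: the kernels expect float32).
x32 = x.float()
print(x32.dtype)  # torch.float32
```

If the builtin path really is dtype-limited, applying `.float()` to `g.ndata['x']`, `g.ndata['y']`, and `g.edata['ab']` before `update_all` should make the failing calls behave like the working one, but I have not confirmed this.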