Builtin functions give "Unsupported dtype" error on torch.double tensors

Hi,

I have pasted a code segment below that throws a DGL error when I use the builtin functions, whereas it runs perfectly when I replace them with equivalent functions I wrote myself.

Any idea what is going on here? Why do I get an error if I use the builtin functions? I certainly cannot find any limitations on node feature tensor types in the documentation when builtin functions…

Thanks!

import numpy as np
import torch
import dgl
from dgl import DGLGraph
import dgl.function as fn

g = dgl.DGLGraph()
g.add_nodes(4)
g.ndata['x'] = torch.DoubleTensor([1.,2.,2.,7.])
g.ndata['y'] = torch.DoubleTensor([0.]*4)
g.add_edges([0, 1, 1, 2], [1, 2, 3, 3])
g.edata["ab"] = torch.DoubleTensor([11., 10., 9., 8.])


def make_e_add_v(e_feat, dst_feat, msg_feat):
    # Message UDF: edge feature plus destination-node feature
    def e_add_v(eb):
        return {msg_feat: eb.data[e_feat] + eb.dst[dst_feat]}
    return e_add_v

def make_sum_reduce(msg_feat, n_feat):
    # Reduce UDF: sum the messages in each node's mailbox
    def sum_reduce(nb):
        return {n_feat: nb.mailbox[msg_feat].sum(dim=1)}
    return sum_reduce
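Both the hand-written pair above and fn.e_add_v + fn.sum are meant to compute, for every node v, the sum over its incoming edges e of ab[e] + x[v]. As a sketch, here is a dependency-free reference (the function name is hypothetical, not part of DGL) that computes the same quantity with Python floats, which are 64-bit, for the graph in the question. It is handy for checking results once the builtins run.

```python
def reference_e_add_v_sum(num_nodes, dst, x, ab):
    """Pure-Python reference for an e_add_v message followed by a sum reduce.

    Edge e delivers the message ab[e] + x[dst[e]] to node dst[e], and each
    node sums what it receives. Nodes without incoming edges get 0.0 here,
    which is an assumption of this sketch.
    """
    y = [0.0] * num_nodes
    for e, v in enumerate(dst):
        # Message: edge feature plus the destination node's feature.
        y[v] += ab[e] + x[v]
    return y


# Graph from the question: edges 0->1, 1->2, 1->3, 2->3.
y = reference_e_add_v_sum(
    num_nodes=4,
    dst=[1, 2, 3, 3],
    x=[1.0, 2.0, 2.0, 7.0],
    ab=[11.0, 10.0, 9.0, 8.0],
)
print(y)  # [0.0, 13.0, 12.0, 31.0]
```

Node 3 receives two messages, (9 + 7) and (8 + 7), which is why its value is 31.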

The following code works:

g.update_all(message_func=make_e_add_v("ab", "x", "m"), reduce_func=make_sum_reduce("m", "y"))

The next line should be equivalent to the previous one, but it raises an error:

g.update_all(message_func=make_e_add_v("ab", "x", "m"), reduce_func=fn.sum("m", "y"))
# same problem here:
g.update_all(message_func=fn.e_add_v("ab", "x", "m"), reduce_func=fn.sum("m", "y"))

The end of the traceback:

raise DGLError(py_str(_LIB.DGLGetLastError()))
dgl._ffi.base.DGLError: [14:56:03] /opt/dgl/src/kernel/cpu/../binary_reduce_impl.h:104: Unsupported dtype: _@
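The "_@" in the message is itself a clue. Assuming the code and bits fields are uint8_t, as they are in DLPack's DLDataType, streaming them with LOG(FATAL) << prints them as characters rather than numbers: kDLFloat is 2, a non-printable character, and 64 is the ASCII code of "@". So the kernel received a 64-bit float, i.e. torch.double. A small sketch that decodes such a suffix (the helper name is hypothetical):

```python
def decode_dtype_message(suffix):
    """Decode the '<code>_<bits>' suffix of the 'Unsupported dtype' message.

    Assumes code and bits were printed as single characters because they
    are uint8_t fields (as in DLPack's DLDataType).
    """
    code_char, bits_char = suffix.split("_")
    return ord(code_char), ord(bits_char)


# chr(2) (kDLFloat, invisible in the log) + "_" + chr(64) ("@")
print(decode_dtype_message("\x02_@"))  # (2, 64)
```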

Thanks for your feedback.


At the current stage we only support float (float32). However, if you need double, you can change the DGL_DTYPE_SWITCH macro to:

#define DGL_DTYPE_SWITCH(val, DType, ...)                             \
  if (val.code == kDLFloat && val.bits == 32) {                       \
    typedef float DType;                                              \
    {__VA_ARGS__}                                                     \
  } else if (val.code == kDLFloat && val.bits == 64) {                \
    typedef double DType;                                             \
    {__VA_ARGS__}                                                     \
  } else {                                                            \
    LOG(FATAL) << "Unsupported dtype: " << static_cast<int>(val.code) \
               << "_" << static_cast<int>(val.bits);                  \
  }

and then build DGL from source to get float64 support. Note that compilation may take a long time, because every templated kernel is compiled once more for double.

For most GNN applications float32 precision is sufficient, so we encourage users to use float32 by default.
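In practice, the simplest workaround on released builds is to cast the features to float32 before calling the builtins. A minimal sketch, assuming PyTorch and a plain dict of feature tensors (the helper name to_float32 is hypothetical, not a DGL API):

```python
import torch


def to_float32(features):
    """Return a new dict with every float64 tensor cast to float32.

    Hypothetical helper: tensors of other dtypes are passed through unchanged,
    and the input dict is not modified.
    """
    return {
        name: t.float() if t.dtype == torch.float64 else t
        for name, t in features.items()
    }


ndata = {
    "x": torch.tensor([1., 2., 2., 7.], dtype=torch.float64),
    "y": torch.zeros(4, dtype=torch.float64),
}
edata = {"ab": torch.tensor([11., 10., 9., 8.], dtype=torch.float64)}

ndata32 = to_float32(ndata)
edata32 = to_float32(edata)
print(ndata32["x"].dtype, edata32["ab"].dtype)  # torch.float32 torch.float32
```

With the graph from the question, the same idea is simply g.ndata['x'] = g.ndata['x'].float() (likewise for 'y' and g.edata['ab']) before g.update_all(message_func=fn.e_add_v("ab", "x", "m"), reduce_func=fn.sum("m", "y")). Call .double() on the result if float64 is needed downstream.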

We will add a note to the documentation reminding users to use float32. Support for float16 and float64 in builtin functions is planned for a future release.