Questions about DGL HeteroGraphs that you never dared to ask

Hey,

I thought I’d open a thread for all sorts of newbie questions related to heterographs.

To start right off: can the typed nodes of a heterograph have feature vectors of different sizes?

E.g. nodes of type A have a 4-dimensional feature vector, whereas nodes of type B have a feature vector with 25 dimensions.


First, thanks for starting this thread and I’d encourage threads like this :slight_smile:

Yes. In fact, this is one of the main motivations for heterogeneous graphs. Take a look at the example below:

import dgl
import torch

# 'user' and 'game' nodes get features of different sizes; 'developer' nodes get none
g = dgl.heterograph({
    ('user', 'follows', 'user'): [(0, 1), (1, 2)],
    ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
    ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
})
g.nodes['user'].data['feat'] = torch.randn(3, 4)   # 3 users, 4-dim features
g.nodes['game'].data['feat'] = torch.randn(2, 25)  # 2 games, 25-dim features

What is the best way to copy node features from a NetworkX graph to a DGL HeteroGraph?

To be more specific, what is the best way to get specific node types and assign features to them?

Are heterographs the only way to have embeddings on both nodes and edges?

A NetworkX graph has no inherent notion of node and edge types; instead, each node can carry arbitrary attributes, possibly of different dimensions. So there is currently no one-liner to copy node features from a NetworkX graph to a DGL HeteroGraph.

However, one can construct a homogeneous graph from a NetworkX graph, provided that the node labels in the NetworkX graph are consecutive integers. One can then convert the homogeneous graph to a heterogeneous graph with dgl.to_hetero, provided that the nodes have a feature named dgl.NTYPE (no quotes!) and the edges have a feature named dgl.ETYPE (again no quotes!), storing the node/edge type IDs. The heterogeneous graph returned by dgl.to_hetero stores a mapping from its nodes/edges back to the nodes/edges of the original homogeneous graph.

So the best you can do is probably something like:

import dgl
import networkx as nx

nxg = nx.DiGraph(...)
g = dgl.graph(nxg)
g.ndata[dgl.NTYPE] = ...    # node type IDs taken from nxg (1-D int64 tensor)
g.edata[dgl.ETYPE] = ...    # edge type IDs taken from nxg (1-D int64 tensor)
ntypes = [...]    # name of each node type ID
etypes = [...]    # name of each edge type ID
hg = dgl.to_hetero(g, ntypes, etypes)
for ntype in hg.ntypes:
    # node IDs in the original homogeneous graph (and the NetworkX graph)
    nxg_node_ids = hg.nodes[ntype].data[dgl.NID]
    # get_node_features is a placeholder for your own lookup code
    hg.nodes[ntype].data['feature'] = get_node_features(nxg, nxg_node_ids)
# Same for edge types...
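To make that mapping concrete, here is a plain-Python sketch (no DGL needed) of the bookkeeping behind hg.nodes[ntype].data[dgl.NID]. The helpers split_by_type and local_ids are hypothetical, and the sketch assumes that nodes of each type keep their original relative order; please check the actual dgl.NID values on your DGL version rather than relying on that.

```python
def split_by_type(type_ids, type_names):
    """Group homogeneous-graph node IDs by node type.

    Returns {type name: [original IDs]}; position i in each list plays the
    role of type-local ID i, i.e. what dgl.NID would hold for that type
    (assuming the original relative order is kept within each type).
    """
    groups = {name: [] for name in type_names}
    for orig_id, t in enumerate(type_ids):
        if not 0 <= t < len(type_names):
            raise ValueError(f"type ID {t} has no entry in {type_names}")
        groups[type_names[t]].append(orig_id)
    return groups


def local_ids(groups):
    """Inverse mapping: original ID -> (type name, type-local ID)."""
    return {orig_id: (name, local_id)
            for name, orig_ids in groups.items()
            for local_id, orig_id in enumerate(orig_ids)}


groups = split_by_type([1, 0, 1], ['tit', 'hea'])
```

For example, type IDs [1, 0, 1] with names ['tit', 'hea'] put original node 1 into 'tit' and original nodes 0 and 2 into 'hea' (as type-local IDs 0 and 1). A type ID without a matching name is rejected, which is the same kind of mistake discussed further down in this thread.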

Not necessarily; homogeneous graphs can also have both node embeddings and edge embeddings. The only restriction is that all nodes must share one embedding size, and all edges must share one embedding size (which can differ from the node one).

Thank you for your help!

I must ask, however: what should the data type and shape of dgl.NTYPE be?

I’ve tried NumPy arrays and Python lists of strings with node type names, but I always get the following error:

  File "dgl/convert.py", line 590, in to_hetero
    ntype_ids = F.asnumpy(G.ndata[ntype_field])
  File "dgl/backend/pytorch/tensor.py", line 87, in asnumpy
    return input.cpu().detach().numpy()
AttributeError: 'numpy.ndarray' object has no attribute 'cpu'

I think you must assign a tensor instead of a plain NumPy array or list.

g_nx_pet.add_edges_from([(1, 2), (1, 3)])
g_nx_pet.nodes[1]['label'] = 1
g_nx_pet.nodes[1]['m1'] = 0
g_nx_pet.nodes[2]['label'] = 1
g_nx_pet.nodes[2]['m2'] = 1
g_nx_pet.nodes[3]['label'] = 0
g_nx_pet.nodes[3]['m3'] = 1

g_dgl_pet = dgl.graph(g_nx_pet)
g_dgl_pet.ndata[dgl.NTYPE] = torch.tensor([[1.], [0.], [1.]])
g_dgl_pet.edata[dgl.ETYPE] = torch.tensor([[1.], [1.], [1.], [1.]])
ntypes = ["tit", "hea"]
etypes = ["connects"]

However, I get an error when executing this line:

hg_dgl_pet = dgl.to_hetero(g_dgl_pet, ntypes, etypes)
ValueError: object too deep for desired array

Can anybody help me with this error?

Thanks

If you are using PyTorch as the DGL backend, you should assign PyTorch tensors.


That’s what I suspected from the error about the missing .cpu() attribute! Thanks, I’ll try this out ASAP.

Can you provide a minimal code snippet for reproduction?

Still banging my head against the wall … I now get the same error as @qillbel.

# Create DGL graph from the NetworkX graph
g = dgl.graph(nx_g)

# Assign node and edge types
g.ndata[dgl.NTYPE] = torch.from_numpy(np.asarray([attr['kind'] for n, attr in nx_g.nodes(data=True)]))
g.edata[dgl.ETYPE] = torch.from_numpy(np.asarray([attr['kind'] for _, _, attr in nx_g.edges(data=True)]))

ntypes = torch.unique(g.ndata['_TYPE'])
etypes = torch.unique(g.edata['_TYPE'])

# Create heterograph
hg = dgl.to_hetero(g, ntypes, etypes)

Here’s the node type information assigned to g.ndata[dgl.NTYPE]:

tensor([[3],
        [3],
        [3],
        [2],
        [4],
        [2],
        [4],
        [2],
        [4],
        [4]], dtype=torch.int32)

This gives me the error:

  File "dgl/convert.py", line 594, in to_hetero
    ntype_count = np.bincount(ntype_ids, minlength=num_ntypes)
  File "<__array_function__ internals>", line 6, in bincount
ValueError: object too deep for desired array

(Also for @qillbel)

The tensor assigned to dgl.NTYPE (and likewise to dgl.ETYPE) must be an int64 vector, that is, it must be one-dimensional. Moreover, each value must be the index of a node type name (or edge type name) in the lists you pass to dgl.to_hetero.

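So let’s say that you have the following graph, node types, and edge types:

import dgl
import networkx
import torch

g_nx_pet = networkx.Graph([(1, 2), (1, 3)])
g_dgl_pet = dgl.graph(g_nx_pet)  # the 2 undirected edges become 4 directed edges
ntypes = ['tit', 'hea']
etypes = ['connects']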

Then this doesn’t work, because the tensors are two-dimensional and of float type:

g_dgl_pet.ndata[dgl.NTYPE] = torch.tensor([[1.], [0.], [1.]])
g_dgl_pet.edata[dgl.ETYPE] = torch.tensor([[1.],[1.],[1.],[1.]])
hg_dgl_pet = dgl.to_hetero(g_dgl_pet, ntypes, etypes)

This also doesn’t work, because edge type ID 1 refers to a second edge type, but your edge type list contains only one name (node and edge type IDs start from 0):

g_dgl_pet.ndata[dgl.NTYPE] = torch.LongTensor([1, 0, 1])
g_dgl_pet.edata[dgl.ETYPE] = torch.LongTensor([1, 1, 1, 1])
hg_dgl_pet = dgl.to_hetero(g_dgl_pet, ntypes, etypes)

This will work:

g_dgl_pet.ndata[dgl.NTYPE] = torch.LongTensor([1, 0, 1])     # node 0 -> 'hea', node 1 -> 'tit', node 2 -> 'hea'
g_dgl_pet.edata[dgl.ETYPE] = torch.LongTensor([0, 0, 0, 0])  # all 4 edges -> 'connects'
hg_dgl_pet = dgl.to_hetero(g_dgl_pet, ntypes, etypes)
hg_dgl_pet.metagraph.edges()
# OutMultiEdgeDataView([('hea', 'tit'), ('hea', 'hea'), ('tit', 'hea')])
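If your types come as arbitrary labels, such as the 'kind' values 2, 3 and 4 in the tensor posted above, you first need to remap them to 0-based IDs and build the matching name list. Below is a plain-Python sketch; encode_types is a hypothetical helper. It also rejects nested one-label-per-row input, which is the [N, 1] shape that made np.bincount fail in the traceback above.

```python
def encode_types(kinds):
    """Turn arbitrary per-node (or per-edge) labels into 0-based type IDs.

    Returns (ids, names): ids[i] is the type ID of element i, and type ID j
    refers to names[j]. Names are returned as strings, one per distinct label,
    in sorted label order.
    """
    for k in kinds:
        if isinstance(k, (list, tuple)):
            # A column of one-element rows is 2-D; flatten it first.
            raise ValueError("expected one flat label per element, "
                             "got a nested list")
    labels = sorted(set(kinds))
    index = {label: i for i, label in enumerate(labels)}
    ids = [index[k] for k in kinds]
    return ids, [str(label) for label in labels]
```

With PyTorch you would then assign torch.tensor(ids, dtype=torch.int64) to g.ndata[dgl.NTYPE] and pass names as ntypes (and do the same for edges). For the tensor posted above this gives IDs [1, 1, 1, 0, 2, 0, 2, 0, 2, 2] and names ['2', '3', '4'].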

Massive thanks @BarclayII, finally got the logic behind this and now it works perfectly!

Back with more questions. :slight_smile:

  1. Is it correct that dgl.nn modules cannot be used with DGLHeteroGraphs?

More generally, I would be very interested in learning how graph neural networks handle typed nodes. If someone can point me to a useful, relatively simple explanation, I would be very thankful.

  2. Is it possible to get a readout for a DGLHeteroGraph by averaging node features, if the node features are of different dimensionality?

Hi,
For the first question (or rather its general version), RGCN and HAN (Heterogeneous Graph Attention Network) are good answers.
For the second question, I have the same issue: it is not clear to me how to apply a readout to a heterograph or a batched heterograph.
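For anyone looking for the intuition behind RGCN: as I understand it, each relation gets its own weight matrix, so a relation whose source type has 4-dimensional features and one whose source type has 25-dimensional features can both map into the same output size, and a destination node sums the (averaged) messages over its incoming relations. Below is a toy pure-Python sketch of one such layer. It is simplified (no self-loop term, no activation, no basis decomposition) and is not DGL's implementation; rgcn_layer and the example relations are hypothetical.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def rgcn_layer(feats, relations, weights, out_dim):
    """One simplified RGCN-style update.

    feats:     {ntype: [one feature vector per node]}; sizes may differ per type
    relations: {(src_type, rel, dst_type): [(src_id, dst_id), ...]}
    weights:   {rel: out_dim x (feature size of src_type) matrix}
    Each destination node gets, summed over relations, the mean of W_rel @ h_src
    over its in-neighbours for that relation.
    """
    out = {nt: [[0.0] * out_dim for _ in vecs] for nt, vecs in feats.items()}
    for (src_t, rel, dst_t), edges in relations.items():
        msgs = {}  # dst_id -> transformed source features
        for s, d in edges:
            msgs.setdefault(d, []).append(matvec(weights[rel], feats[src_t][s]))
        for d, ms in msgs.items():
            mean = [sum(col) / len(ms) for col in zip(*ms)]
            out[dst_t][d] = [a + b for a, b in zip(out[dst_t][d], mean)]
    return out


# Users have 2-dim features, games 3-dim; both end up 1-dim.
feats = {'user': [[1.0, 0.0], [0.0, 1.0]], 'game': [[1.0, 2.0, 3.0]]}
relations = {
    ('user', 'plays', 'game'): [(0, 0), (1, 0)],
    ('game', 'played-by', 'user'): [(0, 0)],
}
weights = {'plays': [[1.0, 1.0]], 'played-by': [[1.0, 0.0, 0.0]]}
result = rgcn_layer(feats, relations, weights, out_dim=1)
```

If I'm not mistaken, newer DGL releases also provide dgl.nn.HeteroGraphConv, which applies a separate module per relation and aggregates the results in a similar spirit.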


To add to the answer from @ar795: DGL GNN modules currently support unidirectional bipartite graphs as well. Simply pass a unidirectional bipartite graph together with a pair of feature tensors, one for the source node type and one for the destination node type, and you should be good.

module = dgl.nn.SAGEConv(...)
g = dgl.bipartite(..., 'user', 'clicks', 'item')
result = module(g, (user_features, item_features))
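As for why a pair of feature tensors is enough: conceptually, a GraphSAGE-style layer with a mean aggregator applies one projection to the averaged source (neighbour) features and a separate one to the destination node's own features, so the two node types may have different feature sizes. I believe SAGEConv accepts in_feats as a (source size, destination size) pair for this case, but please check the documentation of your DGL version. Here is a toy pure-Python sketch of that update (no bias, normalization or activation; bipartite_mean_conv is a hypothetical name, not DGL code):

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def bipartite_mean_conv(edges, src_feats, dst_feats, W_neigh, W_self):
    """GraphSAGE-style mean update on a unidirectional bipartite graph.

    edges:     [(src_id, dst_id), ...], e.g. user -> item clicks
    src_feats: one vector per source node (size d_src)
    dst_feats: one vector per destination node (size d_dst)
    W_neigh:   out x d_src matrix, applied to the mean of neighbouring sources
    W_self:    out x d_dst matrix, applied to the destination's own features
    """
    d_src = len(W_neigh[0])
    neigh = [[] for _ in dst_feats]
    for s, d in edges:
        neigh[d].append(src_feats[s])
    out = []
    for d, h_dst in enumerate(dst_feats):
        if neigh[d]:
            h_neigh = [sum(col) / len(neigh[d]) for col in zip(*neigh[d])]
        else:
            h_neigh = [0.0] * d_src  # destination node without any neighbours
        out.append([a + b for a, b in
                    zip(matvec(W_neigh, h_neigh), matvec(W_self, h_dst))])
    return out


# 2 users with 2-dim features, 2 items with 1-dim features.
user_feats = [[1.0, 3.0], [3.0, 1.0]]
item_feats = [[10.0], [20.0]]
clicks = [(0, 0), (1, 0)]  # both users clicked item 0; item 1 has no clicks
result = bipartite_mean_conv(clicks, user_feats, item_feats,
                             W_neigh=[[1.0, 1.0]], W_self=[[1.0]])
```

Item 0 averages its two users to [2.0, 2.0] and adds its own feature, giving 14.0; item 1 has no neighbours and keeps only its own projected feature.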

For the second problem, we currently don’t have a one-liner for batched-heterograph readout, so you may need to implement it yourself. @mufeili could probably add his thoughts on this.
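One do-it-yourself option, which may or may not suit your model: average the node features within each node type (where the sizes agree), then concatenate the per-type averages in a fixed type order, zero-filling types that have no nodes so every graph gets a vector of the same length. Another option is to first project every type to a common size (e.g. with a per-type linear layer) and then average over all nodes. Below is a pure-Python sketch of the first option, where hetero_mean_readout is a hypothetical helper; for a batched heterograph you would apply it to each graph separately.

```python
def hetero_mean_readout(feats, dims, ntype_order):
    """Graph-level vector for a heterograph with per-type feature sizes.

    feats:       {ntype: [one feature vector per node]} (a type may be absent)
    dims:        {ntype: feature size}, used to zero-fill types without nodes
    ntype_order: fixed type order, so every graph yields the same layout
    Averages within each node type, then concatenates the per-type means.
    """
    out = []
    for nt in ntype_order:
        vecs = feats.get(nt, [])
        if vecs:
            out.extend(sum(col) / len(vecs) for col in zip(*vecs))
        else:
            out.extend([0.0] * dims[nt])
    return out


feats = {'user': [[1.0, 2.0], [3.0, 4.0]], 'game': [[1.0, 1.0, 1.0]]}
dims = {'user': 2, 'game': 3, 'developer': 5}
readout = hetero_mean_readout(feats, dims, ['user', 'game', 'developer'])
```

The result always has length 2 + 3 + 5 = 10 here, regardless of how many nodes of each type a graph has.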


@thiippal See if the workaround here is good for you.


Thank you @thiippal for this thread! Perfect for my question:

Is it possible to have multiple feature matrices for one node type?

I have a heterograph consisting of, let’s say, node types A, B and C. There are n nodes of node type A.
I would like to add a feature matrix of size (n x l) and another feature matrix of size (n x k) for node type A. For node type B, with m nodes, I would also like to add multiple feature matrices of different sizes, say one of size (m x p) and one of size (m x q).

How could I implement this?

Thank you in advance for your answer!

Hi @sopkri, does the example below help?

import dgl
import torch

g = dgl.heterograph({
    ('user', 'follows', 'user'): [(0, 1), (1, 2)],
    ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
    ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
})
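# Several fields per node type; only the first dimension must match the node count
g.nodes['user'].data['h1'] = torch.randn(3, 1)   # 3 users, (3 x 1) matrix
g.nodes['user'].data['h2'] = torch.randn(3, 2)   # same users, (3 x 2) matrix
g.nodes['game'].data['h1'] = torch.randn(2, 5)   # 2 games, widths can differ per type
g.nodes['game'].data['h2'] = torch.randn(2, 7)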
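The rule behind this, as far as I know, is that every tensor stored under g.nodes[ntype].data needs one row per node of that type; the remaining dimensions are free to differ between fields and between node types. Below is a toy pure-Python model of that rule (TypedFeatureStore is a hypothetical class, not part of DGL):

```python
class TypedFeatureStore:
    """Toy stand-in for g.nodes[ntype].data with several fields per node type."""

    def __init__(self, num_nodes):
        self.num_nodes = dict(num_nodes)            # {ntype: node count}
        self.data = {nt: {} for nt in self.num_nodes}

    def set(self, ntype, name, rows):
        # The row count must match the node count; the row width is free.
        if len(rows) != self.num_nodes[ntype]:
            raise ValueError(f"{ntype!r} has {self.num_nodes[ntype]} nodes, "
                             f"but field {name!r} has {len(rows)} rows")
        self.data[ntype][name] = rows

    def widths(self, ntype):
        """Feature size of each field stored for this node type."""
        return {name: len(rows[0]) if rows else 0
                for name, rows in self.data[ntype].items()}


# n = 3 nodes of type A with (3 x 4) and (3 x 7) fields, m = 2 nodes of type B.
store = TypedFeatureStore({'A': 3, 'B': 2})
store.set('A', 'h1', [[0.0] * 4 for _ in range(3)])
store.set('A', 'h2', [[0.0] * 7 for _ in range(3)])
store.set('B', 'h1', [[0.0] * 5 for _ in range(2)])
store.set('B', 'h2', [[0.0] * 9 for _ in range(2)])
```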