Extract a subgraph from a given graph, source nodes and dest nodes

I have a graph g, a destination node set dst_nodes, and a source node set src_nodes. I want to find all edges that start from nodes in src_nodes and point to nodes in dst_nodes, and use these edges to construct a DGLBlock. In other words, I want to extract a subgraph (as a DGLBlock) whose edges are exactly the edges u -> v in g with u ∈ src_nodes and v ∈ dst_nodes.

How can I do this efficiently, given that the src_nodes and dst_nodes sets could be large?


Hi, we are discussing how to do it efficiently. Meanwhile, we are interested in your use case. If this turns out to be a common requirement, it may be better to provide a shared utility.

If your graph is not large, here is a workaround:

import torch

g = ...   # your graph
src_nodes = ...   # src node IDs
dst_nodes = ...   # dst node IDs
# Per-node 0/1 indicators for membership in src_nodes and dst_nodes.
src_flag = torch.zeros(g.num_nodes(), device=g.device)
src_flag[src_nodes] = 1
dst_flag = torch.zeros(g.num_nodes(), device=g.device)
dst_flag[dst_nodes] = 1
g.ndata['src_flag'] = src_flag
g.ndata['dst_flag'] = dst_flag
# Mark every edge u->v that has u in src_nodes and v in dst_nodes with 1.
g.apply_edges(lambda edges: {'flag' : edges.src['src_flag'] * edges.dst['dst_flag']})
edge_flag = g.edata['flag'].bool()
# Keep only the flagged edges.
subg = g.edge_subgraph(edge_flag)

Note that this is a general way to extract subgraphs from node flags: a more complex edge function lets you express other selection rules.
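For anyone who wants to see the flag logic without any DGL machinery, here is a minimal library-free sketch of the same idea. It assumes the graph is given as two parallel lists of edge endpoints; the function name select_edges and the toy data are hypothetical and not part of DGL.

```python
def select_edges(num_nodes, edge_src, edge_dst, src_nodes, dst_nodes):
    """Return IDs of edges u->v with u in src_nodes and v in dst_nodes."""
    # Per-node 0/1 flags, mirroring the src_flag / dst_flag tensors.
    src_flag = [0] * num_nodes
    dst_flag = [0] * num_nodes
    for u in src_nodes:
        src_flag[u] = 1
    for v in dst_nodes:
        dst_flag[v] = 1
    # An edge is kept only when both endpoint flags are 1 (product == 1).
    return [
        eid
        for eid, (u, v) in enumerate(zip(edge_src, edge_dst))
        if src_flag[u] * dst_flag[v] == 1
    ]


# Toy graph with 5 nodes and 6 edges (hypothetical data).
edge_src = [0, 0, 1, 2, 3, 4]
edge_dst = [1, 2, 2, 3, 4, 0]
kept = select_edges(5, edge_src, edge_dst, src_nodes=[0, 1], dst_nodes=[2, 3])
print(kept)
```

The work is one pass over the nodes to set flags plus one pass over the edges, so the cost does not grow with the product of the two set sizes.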

Another workaround is to use g.in_subgraph(dst_nodes).out_subgraph(src_nodes). One caveat is that it does not remove isolated nodes. To drop every node other than those in src_nodes and dst_nodes, you can further apply .subgraph(torch.unique(torch.cat([src_nodes, dst_nodes]))), assuming both are ID tensors.

Or remove isolated nodes via dgl.compact_graphs(g)?
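Dropping isolated nodes amounts to relabelling the nodes that still touch an edge with consecutive IDs. The sketch below illustrates that idea in plain Python for the bipartite case, keeping separate ID spaces for the source side and the destination side, as a block would. It is not DGL's implementation, and compact_bipartite is a hypothetical helper name.

```python
def compact_bipartite(edge_src, edge_dst):
    """Relabel endpoints to consecutive IDs, separately per side.

    Returns (new_src, new_dst, src_ids, dst_ids), where src_ids[i] is the
    original ID of new source node i, and likewise for dst_ids.
    """
    # Only nodes that appear on an edge survive; isolated nodes are dropped.
    src_ids = sorted(set(edge_src))
    dst_ids = sorted(set(edge_dst))
    # Map each original ID to its new consecutive ID.
    src_map = {old: new for new, old in enumerate(src_ids)}
    dst_map = {old: new for new, old in enumerate(dst_ids)}
    new_src = [src_map[u] for u in edge_src]
    new_dst = [dst_map[v] for v in edge_dst]
    return new_src, new_dst, src_ids, dst_ids


# Three surviving edges with sparse original IDs (hypothetical data).
new_src, new_dst, src_ids, dst_ids = compact_bipartite([0, 1, 7], [2, 2, 9])
print(new_src, new_dst, src_ids, dst_ids)
```

The relabelled endpoint lists, together with len(src_ids) and len(dst_ids), are the kind of input a block constructor such as dgl.create_block expects, if your DGL version provides it; src_ids and dst_ids let you map results back to the original graph.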
