Copying a DGLGraph instance to speed up the process

Hello everyone,

First, let me briefly explain what I am trying to do. In my application I create a graph of around 300 nodes and 2000 edges (which we will call the base graph), merge it with a different graph (which we will call the variable graph), and feed the result into the GNN layers.
As the base graph is always the same (same nodes, features and edges), I want to create it just once and then, in each case, merge a copy of it with the variable graph. The aim is to speed up the process, since the base graph then doesn't have to be rebuilt for every instance.
This is what I’ve tried so far:

import copy

import dgl
from dgl import DGLGraph

class VariableGraph(DGLGraph):

    @staticmethod
    def generate_base_graph():
        [...]
        base_graph = dgl.DGLGraph()
        base_graph.add_nodes(...)
        base_graph.add_edges(...)
        base_graph.ndata['h'] = ...
        [...]

        return base_graph

    # Built once, when the class body is executed
    BASEGRAPGH = generate_base_graph.__func__()

    def __init__(self, data, alt, mode='train', debug=False):
        super(VariableGraph, self).__init__(copy.deepcopy(self.BASEGRAPGH))
        [...]
        # logic to merge base graph with variable graph
        [...]

So I create the base graph as a class variable and pass a deep copy of it to the superclass constructor of VariableGraph, which inherits from the DGLGraph class. I thought that this way the base graph would be created only once and a copy of it would be passed in each time VariableGraph is instantiated. However, I get the following error:

  File "/home/daniel/.local/lib/python3.8/site-packages/dgl/heterograph.py", line 68, in __init__
     dgl._ffi.base.DGLError: The input is already a DGLGraph. No need to create it again

I hope you understand what I am trying to do here. If you have any questions, don't hesitate to ask. Maybe there is a simpler way of doing it. I look forward to your answers.

Greetings,
Daniel.

  1. File "/home/daniel/.local/lib/python3.8/site-packages/dgl/heterograph.py", line 68, in __init__
    dgl._ffi.base.DGLError: The input is already a DGLGraph. No need to create it again

    This error is caused by passing a DGLGraph to the __init__ function of DGLGraph, which is not allowed. DGL only allows creating graphs with the following six functions: dgl.graph, dgl.from_scipy, dgl.from_networkx, dgl.bipartite_from_scipy, dgl.bipartite_from_networkx, and dgl.heterograph.

    In particular, the error is triggered by the following lines of code:

    def __init__(self, data, alt, mode='train', debug=False):
        super(VariableGraph, self).__init__(copy.deepcopy(self.BASEGRAPGH))
    
  2. DGL does not recommend inheriting from DGLGraph, although this is not yet explicitly documented anywhere.

  3. If you want to make a deep copy of a graph and then add some new nodes and edges to it, simply do:

    import copy

    g = copy.deepcopy(base_graph)
    # new_nfeats holds the features of the new nodes, one row per node.
    # Add the new nodes together with their features first; add_edges would
    # otherwise create any missing end nodes with zero-filled features.
    g.add_nodes(new_nfeats.shape[0], data={'h': new_nfeats})
    # new_src, new_dst are two tensors representing the source/destination
    # nodes of the new edges (base nodes keep their IDs; new nodes are
    # numbered from base_graph.num_nodes() upwards)
    g.add_edges(new_src, new_dst)
    
  4. That said, it will probably be more efficient to build each merged graph directly from the concatenated edge lists, without copying a graph at all:

    import dgl
    import torch

    # source and destination nodes of the edges in the base graph
    base_src, base_dst = ...
    # number of nodes in the base graph
    base_num_nodes = ...
    # node features in the base graph
    base_nfeats = ...

    # source and destination nodes of the edges in the first variable graph,
    # using merged node IDs (base nodes keep their IDs, new nodes are
    # numbered from base_num_nodes upwards)
    var1_src, var1_dst = ...
    # number of new nodes contributed by the first variable graph
    var1_num_nodes = ...
    # node features of those new nodes
    var1_nfeats = ...
    g1 = dgl.graph((torch.cat([base_src, var1_src], dim=0),
                    torch.cat([base_dst, var1_dst], dim=0)),
                   num_nodes=base_num_nodes + var1_num_nodes)
    g1.ndata['h'] = torch.cat([base_nfeats, var1_nfeats], dim=0)

    # source and destination nodes of the edges in the second variable graph
    # (same merged node ID convention as above)
    var2_src, var2_dst = ...
    # number of new nodes contributed by the second variable graph
    var2_num_nodes = ...
    # node features of those new nodes
    var2_nfeats = ...
    g2 = dgl.graph((torch.cat([base_src, var2_src], dim=0),
                    torch.cat([base_dst, var2_dst], dim=0)),
                   num_nodes=base_num_nodes + var2_num_nodes)
    g2.ndata['h'] = torch.cat([base_nfeats, var2_nfeats], dim=0)
    

Thank you very much for your answer; it was really helpful.