Help: How to Create a Custom Dataset for GNNs with DGL

Hello! I'm exploring the DGL library for working with heterogeneous graphs, specifically building my own graphs from scratch. I've created two small graphs in which every node has a single attribute: the node's degree. The graph label is defined as the sum of all node degrees.

I followed the tutorials in the DGL documentation and used the network model provided there. However, I'm getting some errors that I can't fully understand. I believe they may be related to how I'm building and saving the graphs.

class BGPDataset(DGLDataset):
    """Small synthetic heterogeneous-graph dataset for graph-level prediction.

    Every graph has two node types (``AS`` and ``BGP_package``) and three
    relations (``topological_link``, ``sends_package``, ``recv_package``).
    Node features are stored under the ``"feat"`` key with shape
    ``(num_nodes, 1)`` and the label of a graph is the sum of the
    ``topological_link`` degrees (in + out) of its ``AS`` nodes.

    The dataset is empty right after construction; call
    :meth:`generate_graphs` to populate it.
    """

    def __init__(self, url=None, raw_dir=None, save_dir=None, force_reload=False, verbose=False):
        super(BGPDataset, self).__init__(name='BGP_dataset', url=url, raw_dir=raw_dir, save_dir=save_dir, force_reload=force_reload, verbose=verbose)
        # Seed the global RNG so that generate_graphs() draws the same edges
        # on every run.
        random.seed(12)
        self.graphs = []  # all generated graphs
        self.labels = []  # one label per graph, aligned with self.graphs

        # Global counters (not updated by generate_graphs).
        self.N = 0  # total number of graphs
        self.n = 0  # total number of nodes
        self.m = 0  # total number of edges

        # Global class / feature metadata read by the training script.
        self.gclasses = 1
        self.nclasses = None
        self.eclasses = None
        self.dim_nfeats = 1  # width of the per-node "feat" tensor

    # The hooks below are part of the DGLDataset interface. They are
    # intentionally no-ops: nothing is downloaded, processed or cached, the
    # graphs are generated in memory by generate_graphs().

    def download(self):
        pass

    def process(self):
        pass

    def __getitem__(self, idx):
        """Return the ``(graph, label)`` pair stored at position ``idx``."""
        return self.graphs[idx], self.labels[idx]

    def __len__(self):
        """Return the number of graphs generated so far."""
        return len(self.graphs)

    def save(self):
        pass

    def load(self):
        pass

    def has_cache(self):
        # Nothing is ever cached on disk (the original returned None, which
        # is equally falsy).
        return False

    @staticmethod
    def _random_edges(max_src, max_dst, num_edges=2):
        """Draw ``num_edges`` random (src, dst) node-id tensors.

        Source ids are sampled from ``[0, max_src]`` and destination ids from
        ``[0, max_dst]`` (both inclusive). All sources are drawn before all
        destinations so the RNG sequence is consumed in a fixed order.
        """
        src = torch.tensor([random.randint(0, max_src) for _ in range(num_edges)])
        dst = torch.tensor([random.randint(0, max_dst) for _ in range(num_edges)])
        return src, dst

    def generate_graphs(self, num_graphs=1):
        """Append ``num_graphs`` random heterographs and their labels.

        Args:
            num_graphs: how many graphs to generate and append.
        """
        for _ in range(num_graphs):
            graph_data = {
                ('AS', 'topological_link', 'AS'): self._random_edges(3, 5),
                ('AS', 'sends_package', 'BGP_package'): self._random_edges(3, 5),
                ('BGP_package', 'recv_package', 'AS'): self._random_edges(3, 5),
            }

            # Build the heterograph.
            g = dgl.heterograph(graph_data)

            # Degree of every AS node over the AS-AS topology (in + out).
            as_degrees = g.out_degrees(etype='topological_link') + g.in_degrees(etype='topological_link')

            # Node features must be 2-D, (num_nodes, dim_nfeats), and floating
            # point so they can be multiplied by the layer weights; a 1-D
            # integer tensor does not match a model built with in_feats=1.
            g.nodes["AS"].data["feat"] = as_degrees.float().unsqueeze(1)
            g.nodes["BGP_package"].data["feat"] = torch.ones(g.num_nodes("BGP_package"), 1)

            # Store the graph together with its label.
            self.graphs.append(g)
            self.labels.append(torch.sum(as_degrees))



class GNN:
    """Thin wrapper that selects and holds the dataset used for training."""

    def __init__(self, debug=False):
        self.debug = debug
        self.graph = None

    def load_dataset(self, name_data, force_reload=False, make_bidirectional=True, num_graphs=2):
        """Instantiate the dataset called ``name_data`` and store it in ``self.dataset``.

        Args:
            name_data: ``"SyntheticDataset"`` or ``"BGPDataset"``.
            force_reload: accepted for interface compatibility; not used by
                the branches below.
            make_bidirectional: accepted for interface compatibility; not
                used by the branches below.
            num_graphs: number of graphs to generate for ``"BGPDataset"``.

        Raises:
            ValueError: if ``name_data`` is not a known dataset name.
        """
        # The chain has to start with ``if``; a leading ``elif`` is a syntax error.
        if name_data == "SyntheticDataset":
            self.dataset = SyntheticDataset()
        elif name_data == "BGPDataset":
            # Placeholder locations; the dataset neither downloads nor caches.
            url = "url"
            raw_dir = "url"
            save_dir = "url"
            my_dataset = BGPDataset(url=url, raw_dir=raw_dir, save_dir=save_dir, force_reload=True, verbose=False)
            my_dataset.generate_graphs(num_graphs)
            self.dataset = my_dataset
        else:
            # Fail here instead of leaving self.dataset undefined, which would
            # only surface later as an AttributeError.
            raise ValueError(f"Unknown dataset name: {name_data!r}")




class RGCN(nn.Module):
    """Two stacked ``HeteroGraphConv`` layers with a ReLU in between.

    Each layer holds one ``GraphConv`` per relation name in ``rel_names``
    and combines the per-relation outputs with ``aggregate='sum'``.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, rel_names):
        super().__init__()

        def hetero_layer(n_in, n_out):
            # One GraphConv per relation, summed by HeteroGraphConv.
            per_relation = {rel: GraphConv(n_in, n_out) for rel in rel_names}
            return dgl.nn.pytorch.HeteroGraphConv(per_relation, aggregate='sum')

        self.conv1 = hetero_layer(in_feats, hidden_feats)
        self.conv2 = hetero_layer(hidden_feats, out_feats)

    def forward(self, graph, inputs):
        """Apply both layers to ``graph``.

        ``inputs`` holds the node features handed to the first layer
        (presumably a dict keyed by node type -- confirm against the caller).
        Returns the output of the second layer.
        """
        hidden = self.conv1(graph, inputs)
        activated = {}
        for key, value in hidden.items():
            activated[key] = F.relu(value)
        return self.conv2(graph, activated)


class HeteroClassifier(nn.Module):
    """Graph-level predictor: RGCN encoder, mean readout, linear head."""

    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()
        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        """Return the output of the linear head for the (batched) graph ``g``.

        Reads the node features stored under ``'feat'`` and encodes them
        with the RGCN before pooling.
        """
        node_repr = self.rgcn(g, g.ndata['feat'])
        with g.local_scope():
            # Written inside local_scope so 'h' does not persist on ``g``.
            g.ndata['h'] = node_repr
            # Average readout per node type, summed across node types.
            graph_repr = sum(dgl.mean_nodes(g, 'h', ntype=node_type) for node_type in g.ntypes)
            return self.classify(graph_repr)



# Training script: build the dataset, the model, and run one epoch.
debug = True
gnn = GNN()
gnn.load_dataset(name_data="BGPDataset", make_bidirectional=False)

dim_nfeat = gnn.dataset.dim_nfeats  # 1
hidden_dim = 1
out_dim = gnn.dataset.gclasses  # 1
print(dim_nfeat,"|",hidden_dim,"|",out_dim)

dataloader = GraphDataLoader(gnn.dataset,
                             batch_size=1,
                             drop_last=False,
                             shuffle=True)
if debug:
    print("DATASET")
    print(f"Grafo: { gnn.dataset[0]} \nlabel: {gnn.dataset.labels[0]}")
    print(f"Grafo: { gnn.dataset[1]} \nlabel: {gnn.dataset.labels[1]}")
    print("---------------------------------")

etypes = ['topological_link','sends_package','recv_package']
model = HeteroClassifier(dim_nfeat, hidden_dim, out_dim, etypes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print("ENTRENAMINETO ---- LOOP --------------------------")
model.train()  # training mode only needs to be set once, not per batch
for epoch in range(1):
    for batched_graph, labels in dataloader:
        pred = model(batched_graph)
        # The label is a degree sum (a real-valued target) and the model has
        # a single output, so this is regression. cross_entropy would read
        # the label as a class index, which is out of range for one class.
        loss = F.mse_loss(pred.view(-1), labels.float().view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if debug:
            print(f"BATCH GRAPH: {batched_graph} \nATTR: {batched_graph.ndata} \nLABELS: {labels}")
            print("---------------")
    if debug:
        print(f"EPOCH {epoch}")

I managed to fix the issue I was encountering. It turned out to be a problem with the dimensions of the features I was assigning to the nodes. I changed the assignment as follows:

# Adding features to nodes: each "feat" tensor must have shape (num_nodes, 1).
# The degree tensor holds integers, so cast it to float before using it as a
# feature.
g.nodes["AS"].data["feat"] = degrees_AS.float().unsqueeze(1)
# g.num_nodes(ntype) gives the node count; len(g.nodes[ntype]) does not,
# because indexing the node view returns a data accessor, not the node ids.
g.nodes["BGP_package"].data["feat"] = torch.ones(g.num_nodes("BGP_package"), 1)
1 Like