Create dataset from DGLGraphs in memory

I am trying to create a dataset from DGLGraphs that I built in memory.

In the documentation, I can see that I can persist to and then load graphs from a file. Is there a way to do this in memory so that I don’t have to manage these files and system write permissions?

https://docs.dgl.ai/generated/dgl.data.utils.load_graphs.html#dgl.data.utils.load_graphs

| Method                                        | Description                             |
|-----------------------------------------------|-----------------------------------------|
| utils.save_graphs(filename, g_list[, labels]) | Save DGLGraphs and graph labels to file |
| utils.load_graphs(filename[, idx_list])       | Load DGLGraphs from file                |

Should I write my own dataset class, along the lines of dgl.data.MY_OWN_CLASS, with graphs and labels attributes, and just pass the graphs and labels in to it?

Name:    dgl
Version: 0.4.3.post2

Similar questions below. Looks like a strong candidate for a feature:

Ah, okay. So what was not clear in the docs is that dgl.batch is used to create a th.dataset

Why would you use a non-standard dataset in the docs without demoing it from DGLGraphs first?

I will document how to do this and post it below so that the docs can be updated.

Sorry for the inconvenience. Are you trying to do a graph classification or regression task? Does this https://docs.dgl.ai/tutorials/basics/4_batch.html help?

Here are my suggestions for creating your own data set for DGL. The first consideration is the type of task you’d like to perform. In general, there are three: node classification, edge classification or link prediction, and graph classification. The second consideration is whether you have one graph or multiple graphs. This is also related to whether your setting is inductive or transductive.

Node classification on one graph

A very minimal example will look like this:

class NodeClassificationDataset:
    def __init__(self):
        self.graph = None   # The graph structure stored in DGLGraph.
        self.node_feat1 = None  # Node feature 1, tensor of shape (num_nodes, D1)
        self.node_feat2 = None  # Node feature 2, tensor of shape (num_nodes, D2)
        self.edge_feat1 = None  # Edge feature 1, tensor of shape (num_edges, D1)
        self.edge_feat2 = None  # Edge feature 2, tensor of shape (num_edges, D2)
        self.labels = None  # Node labels, 1-D tensor of length num_nodes
        # train/val/test split
        self.train_nids = None  # A 1-D integer tensor storing node IDs for training
        self.val_nids = None  # A 1-D integer tensor storing node IDs for validation
        self.test_nids = None  # A 1-D integer tensor storing node IDs for test
        self._load()

    def _load(self):
        # Load each class member from disk
        ...

The dataset has only a single graph, two node and two edge feature tensors, and a node label tensor. The example adopts the typical semi-supervised setting from Kipf’s GCN paper, where three arrays, train_nids, val_nids and test_nids, split the whole node ID space. Using boolean masks here is also fine. The _load function reads files stored on disk and initializes each of the members. It is also a good place for pre-processing such as adding or removing edges in the graph, normalizing feature values, etc.
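
To make the ID-array versus boolean-mask remark concrete, here is a minimal sketch of a random split and the conversion between the two representations. Plain Python lists stand in for tensors, and the helper names (split_node_ids, ids_to_mask, mask_to_ids) and the 60/20/20 fractions are assumptions made for illustration; they are not part of DGL.

```python
import random

def split_node_ids(num_nodes, train_frac=0.6, val_frac=0.2, seed=0):
    """Randomly partition node IDs 0..num_nodes-1 into train/val/test lists."""
    ids = list(range(num_nodes))
    random.Random(seed).shuffle(ids)
    n_train = int(num_nodes * train_frac)
    n_val = int(num_nodes * val_frac)
    train = sorted(ids[:n_train])
    val = sorted(ids[n_train:n_train + n_val])
    test = sorted(ids[n_train + n_val:])  # whatever is left goes to test
    return train, val, test

def ids_to_mask(node_ids, num_nodes):
    """Boolean mask of length num_nodes that is True at the given node IDs."""
    mask = [False] * num_nodes
    for nid in node_ids:
        mask[nid] = True
    return mask

def mask_to_ids(mask):
    """Inverse of ids_to_mask: the sorted node IDs where the mask is True."""
    return [nid for nid, flag in enumerate(mask) if flag]

train_nids, val_nids, test_nids = split_node_ids(10)
train_mask = ids_to_mask(train_nids, 10)
```

Either form carries the same information, so the choice mostly depends on whether the training code indexes by ID or selects rows with a mask.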

Actual examples:

Link prediction on one graph

The overall data set design is quite close to the one for node classification. However, since the links to predict must not appear in the graph used for training (otherwise, the labels are leaked), we create two graphs: one for training and one for testing. In a transductive setting, both graphs have the same node set (i.e., no new node in the test graph).

class LinkPredictionDataset:
    def __init__(self):
        self.train_graph = None  # A graph containing only the edges for training.
        self.test_graph = None   # A graph containing only the edges for testing.
        self.node_feat1 = None  # Node feature 1, tensor of shape (num_nodes, D1)
        self.node_feat2 = None  # Node feature 2, tensor of shape (num_nodes, D2)
        self.train_edge_feat = None  # Edge feature, tensor of shape (num_train_edges, D1)
        self._load()

    def _load(self):
        # Load each class member from disk
        ...

Training only happens on the train graph so no test edges are visible. Again, the _load function is a good place for various kinds of pre-processing and data augmentation. The MovieLens dataset generally follows this design and the GCMC example shows how to use it.
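
As a rough sketch of how the two edge sets could be produced, the snippet below splits one edge list into disjoint train and test lists. The function name split_edges, the 20% test fraction and the toy edge list are assumptions for illustration. Building the actual DGLGraph objects from each list is left out because the constructor API differs between DGL versions.

```python
import random

def split_edges(edges, test_frac=0.2, seed=0):
    """Split (src, dst) pairs into disjoint train and test edge lists."""
    edges = list(edges)                 # copy so the caller's list is untouched
    random.Random(seed).shuffle(edges)
    n_test = int(len(edges) * test_frac)
    test_edges = edges[:n_test]         # held out, never seen during training
    train_edges = edges[n_test:]
    return train_edges, test_edges

# Toy graph on nodes 0..3. In a transductive setting both graphs would be
# built over this same node set, each with its own edge list.
edges = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 2)]
train_edges, test_edges = split_edges(edges, test_frac=0.2)
```

Because the two lists are disjoint, a graph built from train_edges cannot leak the held-out links.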

Graph classification

Most data sets for graph classification contain many graphs. In that case, we recommend implementing the data set as an indexable object (one that defines __len__ and __getitem__) so that it can be used with data loaders. We also recommend storing node and edge features directly in the graph object.

class GraphClassificationDataset:
    def __init__(self, filename):
        self.graphs = []  # A list of DGLGraph objects
        self.labels = []
        self._load(filename)

    def _load(self, filename):
        # Load graphs from disk. Save node and edge features to each graph object using
        # the `ndata` and `edata` dictionary.
        ...

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, i):
        # Get the i^th sample and label
        return self.graphs[i], self.labels[i]
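
If the graphs already exist in memory, the same interface can be sketched without any file loading. The class name InMemoryGraphDataset is hypothetical, and the strings below are placeholders standing in for DGLGraph objects. PyTorch's DataLoader treats any object with __len__ and __getitem__ as a map-style dataset, so nothing needs to be written to disk.

```python
class InMemoryGraphDataset:
    """Wraps graphs and labels that already exist in memory."""

    def __init__(self, graphs, labels):
        if len(graphs) != len(labels):
            raise ValueError("graphs and labels must have the same length")
        self.graphs = list(graphs)
        self.labels = list(labels)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, i):
        # Return the i-th (graph, label) pair
        return self.graphs[i], self.labels[i]

# Placeholder strings stand in for DGLGraph instances here.
dataset = InMemoryGraphDataset(["g0", "g1", "g2"], [1, 0, 1])
```

A plain list of (graph, label) tuples satisfies the same two methods, so it would work as a dataset as well.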

In the training script, create separate dataset objects for training and testing. Use the dgl.batch function in the collate function to merge the samples of a mini-batch into one larger graph in which each sample is a separate, disconnected component.

import torch
from torch.utils.data import DataLoader
import dgl

def collate_fn(graph_and_label_list):
    graphs, labels = list(zip(*graph_and_label_list))
    return dgl.batch(graphs), torch.tensor(labels)

def main():
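    batch_size = 32  # example value; tune for your data
    train_set = GraphClassificationDataset('train_file')
    test_set = GraphClassificationDataset('test_file')
    # collate_fn must be passed by keyword; the third positional argument is shuffle.
    train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_fn)
    test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=collate_fn)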
    ...
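
To illustrate what batching graphs into one larger graph means, here is a conceptual sketch in plain Python. It is not DGL's implementation, and batch_edge_lists is a hypothetical helper. It assumes that node IDs of each sample are shifted by the number of nodes that came before it, so that no edge crosses between samples.

```python
def batch_edge_lists(graphs):
    """Merge graphs given as (num_nodes, edges) pairs with local node IDs.

    Returns the total node count, the relabeled edge list, and the node ID
    offset assigned to each input graph.
    """
    batched_edges = []
    offsets = []
    offset = 0
    for num_nodes, edges in graphs:
        offsets.append(offset)
        # Shift local IDs so every sample occupies its own block of node IDs
        batched_edges.extend((u + offset, v + offset) for u, v in edges)
        offset += num_nodes
    return offset, batched_edges, offsets

g_a = (3, [(0, 1), (1, 2)])  # 3 nodes, path 0-1-2
g_b = (2, [(0, 1)])          # 2 nodes, single edge
total_nodes, edges, offsets = batch_edge_lists([g_a, g_b])
# total_nodes == 5, edges == [(0, 1), (1, 2), (3, 4)], offsets == [0, 3]
```

Since the samples stay disconnected, message passing on the merged graph never mixes information between them, which is why a single forward pass can handle the whole mini-batch.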

Actual examples:

Other combinations

Apologies for the lack of a more organized data set folder. We are working on a more self-explanatory and consistent one.


Here is what I ended up doing when I learned that a dataset is just a list of tuples [(g0,l0),(g1,l1)...] passed into th.DataLoader.

It runs to completion, but I am not sure if everything is working as intended.

labels = [1,1,1,1,1,1,1,0,0,0]
# simplified version: G0..G9 are DGLGraph objects already built in memory
graphs = [G0,G1,G2,G3,G4,G5,G6,G7,G8,G9]

# could fetch w numpy unique if i needed to
num_classes = 2
graph_count = len(graphs)
samples_train = []

for i in range(graph_count):    
    current_graph = graphs[i]
    current_label = labels[i]
    
    pair = (current_graph, current_label)
    
    samples_train.append(pair)
import random
random.shuffle(samples_train)
# Fraction of the shuffled samples held out for testing; the rest stay in samples_train.
test_split_pct = 0.3
test_split_index = round(graph_count*test_split_pct)
samples_test = samples_train[:test_split_index]
del samples_train[:test_split_index]
import dgl
import torch

def collate(samples):
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)
from dgl.nn.pytorch import GraphConv
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use node degree as the initial node feature. For undirected graphs, the in-degree
        # is the same as the out_degree.
        h = g.in_degrees().view(-1, 1).float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
import torch.optim as optim
from torch.utils.data import DataLoader

# Use PyTorch's DataLoader and the collate function
# defined before.
# data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
#                          collate_fn=collate)
dataset_train = DataLoader(
    samples_train,
    batch_size=32,
    shuffle=True,
    collate_fn=collate
)

#dataset_test = DataLoader(
#    samples_test,
#    batch_size=32,
#    shuffle=True,
#    collate_fn=collate
#)
model = Classifier(1, 256, num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.010)#0.001
model.train()

epochs = 50

epoch_losses = []

for epoch in range(epochs):
    epoch_loss = 0
    for step, (g, l) in enumerate(dataset_train):  # 'step' avoids shadowing the builtin iter
        prediction = model(g)
        loss = loss_func(prediction, l)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (step + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
import matplotlib.pyplot as plt

plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()
model.eval()
# Convert a list of tuples to two lists
# remember X is features and Y is labels.
test_X, test_Y = map(list, zip(*samples_test))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))

Hi, and thanks for your post. I have a question: is it possible to train a GNN for link prediction on multiple graphs instead of just one?

Question answered here: Link prediction only in subgraph / Training a GNN on various graphs
