Efficient Knowledge Graph Train, Validation, Test Splits

I wrote my own code to create training, validation, and test splits from my canonical tuples after extracting them from a graph. However, during testing I get an error saying that some entities/relations are not recognized. I suspect this is because my test set includes entities that never appear in the training set, which means I didn't do a proper split.

My split code creates the train/valid/test sets like this:

import random

all_ctups_len = len(all_ctups)  # all_ctups is a list of all canonical tuples
train_split = int(all_ctups_len * train_pct)  # train_pct is a float between 0 and 1
valid_split = int(train_split + all_ctups_len * valid_pct)  # valid_pct is a float between 0 and 1
# Randomly shuffle the canonical tuples before slicing
random.shuffle(all_ctups)
# Slice into train/valid/test sets; the test set gets whatever remains
train = all_ctups[:train_split]
valid = all_ctups[train_split:valid_split]
test  = all_ctups[valid_split:]
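One way to confirm this diagnosis is to list which entities and relations in the validation or test set never occur in the training set. The sketch below assumes each canonical tuple is a (src, rel, dst) triple; `unseen_in_train` is a hypothetical helper name, not a DGL function.

```python
def unseen_in_train(train, other):
    """Return (entities, relations) that occur in `other` but never in `train`.

    Assumes every item is a (src, rel, dst) triple, like all_ctups above.
    """
    train_ents = {h for h, _, _ in train} | {t for _, _, t in train}
    train_rels = {r for _, r, _ in train}
    other_ents = {h for h, _, _ in other} | {t for _, _, t in other}
    other_rels = {r for _, r, _ in other}
    return other_ents - train_ents, other_rels - train_rels


# Toy example: 'd' and 'knows' only appear in the test set.
train = [("a", "likes", "b"), ("b", "likes", "c")]
test = [("c", "knows", "d")]
ents, rels = unseen_in_train(train, test)
print(sorted(ents), sorted(rels))  # ['d'] ['knows']
```

If either set is non-empty for the real splits, the "not recognized" error is most likely coming from those items.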

Is there a DGL built-in method for splitting canonical tuples into train/valid/test sets that ensures every entity and relation in the validation and test sets also appears in the training set? If not, I can adjust my code accordingly. I just wanted to check whether there's an efficient built-in way I was missing first.

Thanks in advance! (And if anyone has code for how they do splits, I’d love to see it!)
Alex

For what it’s worth, I addressed the problem above with the following:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

data = pd.read_csv("data.txt", sep="\t", names=["src", "rel", "dst"])
all_ctups = data.to_numpy()
# Count how often each destination entity occurs
uniqueValues, occurCount = np.unique(all_ctups[:, 2], return_counts=True)
# Keep only well-connected destination entities (those with >= k edges)
k = 10
long_uniqueValues = uniqueValues[occurCount >= k]
# Boolean mask of rows whose destination entity is well connected
long_indices = np.isin(all_ctups[:, 2], long_uniqueValues)
# Drop rows whose destination entity is poorly connected (< k edges)
all_ctups = all_ctups[long_indices]
# Get length of canonical tuples after dropping poorly connected rows
all_ctups_len_after = len(all_ctups)
print("Number of canonical tuples after dropping poorly connected destination entities: ", all_ctups_len_after)

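# Split off the test set. Stratifying on the destination column keeps every
# (well-connected) destination entity represented in train+valid; source
# entities and relations are not stratified, so they are not guaranteed.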
X_train_and_val, X_test, y_train_and_val, y_test = train_test_split(
    all_ctups[:,:2],
    all_ctups[:,2],
    stratify=all_ctups[:,2],
    test_size=0.1,
    random_state=407
)

# Split train+valid into train and valid, again stratified on the destination
X_train, X_val, y_train, y_val = train_test_split(
    X_train_and_val,
    y_train_and_val,
    stratify=y_train_and_val,
    test_size=0.1,
    random_state=407
)
# Stack arrays to recreate tuples
train = np.column_stack((X_train,y_train))
valid = np.column_stack((X_val,y_val))
test  = np.column_stack((X_test,y_test))
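Stratifying on the destination column covers destination entities, but not source entities or relations. A different technique (not what Alex used above) is to split at random and then move any validation/test triple that mentions an unseen entity or relation back into the training set. This is a minimal sketch, assuming (src, rel, dst) triples; `split_keep_seen` is a hypothetical helper name.

```python
import random


def split_keep_seen(triples, valid_pct, test_pct, seed=0):
    """Random split where every valid/test entity and relation also occurs in train.

    Any valid/test triple with an entity or relation unseen in train is moved
    into train, so the final valid/test sizes can fall below the targets.
    """
    triples = list(triples)
    random.Random(seed).shuffle(triples)
    n = len(triples)
    n_test = int(n * test_pct)
    n_valid = int(n * valid_pct)
    test = triples[:n_test]
    valid = triples[n_test:n_test + n_valid]
    train = triples[n_test + n_valid:]

    # Entities and relations currently present in the training set
    seen_ents = {h for h, _, _ in train} | {t for _, _, t in train}
    seen_rels = {r for _, r, _ in train}

    def keep(split):
        kept = []
        for h, r, t in split:
            if h in seen_ents and t in seen_ents and r in seen_rels:
                kept.append((h, r, t))
            else:
                # Unseen entity or relation: move the triple into train
                train.append((h, r, t))
                seen_ents.update((h, t))
                seen_rels.add(r)
        return kept

    valid = keep(valid)
    test = keep(test)
    return train, valid, test
```

The trade-off is that rarely seen entities tend to be pulled into the training set, so the validation and test sets skew toward well-connected entities, which is similar in spirit to dropping entities with fewer than k edges.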

Just a note: a DGL graph allows orphan nodes (nodes with no incident edges) and empty relations (relation types with no edges). Here are some examples:

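import dgl
import torch

# Create a graph with 100 nodes but only three edges. Most are orphan nodes.
u = torch.tensor([1, 2, 3])
v = torch.tensor([4, 5, 6])
g = dgl.graph((u, v), num_nodes=100)

# Create a heterograph with an empty relation 'R' (no edges of that type)
empty = torch.tensor([], dtype=torch.int64)
hg = dgl.heterograph({
  ('V', 'E', 'V') : (u, v),
  ('V', 'R', 'X') : (empty, empty),
}, num_nodes_dict={'V': 100, 'X': 0})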

Thus, to make sure that every entity and relation in the test graph also exists in the training and validation graphs, you can first compute the total number of nodes and the full set of relations over all tuples, and then create all three graphs with the same number of nodes and the same relation set, even if some relations have no edges in a particular split.
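The preparation step can be sketched as follows. Assumptions: the triples are (src, rel, dst) with string relation names, all entities share one node type (called `"entity"` here), and each relation becomes its own edge type. `build_graph_inputs` is a hypothetical helper that only builds the plain-Python inputs; it does not call DGL.

```python
import numpy as np


def build_graph_inputs(all_triples, split_triples, ntype="entity"):
    """Build heterograph-style (data_dict, num_nodes_dict) for one split.

    IDs are assigned over ALL triples, so every split shares the same node
    count and relation set, even if a relation has no edges in this split.
    """
    entities = sorted({h for h, _, _ in all_triples} | {t for _, _, t in all_triples})
    relations = sorted({r for _, r, _ in all_triples})
    ent_id = {e: i for i, e in enumerate(entities)}

    # Start every relation with empty edge lists so none is missing
    edges = {r: ([], []) for r in relations}
    for h, r, t in split_triples:
        edges[r][0].append(ent_id[h])
        edges[r][1].append(ent_id[t])

    data_dict = {
        (ntype, r, ntype): (np.array(src, dtype=np.int64), np.array(dst, dtype=np.int64))
        for r, (src, dst) in edges.items()
    }
    return data_dict, {ntype: len(entities)}
```

With DGL and PyTorch installed, each split's inputs could then be passed to `dgl.heterograph(...)` with `num_nodes_dict=...`, after converting the arrays with `torch.from_numpy`. An alternative layout is a single edge type with the relation ID stored in `edata`; the shared-ID idea is the same.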

However, there is no free lunch. The problem with orphan nodes is that they still occupy node feature storage. For example, the graph g above has 100 nodes, so an ndata tensor will have shape (100, ...) even though only 6 nodes have edges. An empty relation, on the other hand, costs nothing here, because edata storage is proportional to the number of edges, which is zero in this case.
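To put a number on that overhead, the sketch below counts the node-feature bytes spent on orphan nodes. The feature dimension of 16 and the float32 dtype are illustrative assumptions, and `orphan_feature_bytes` is a hypothetical helper.

```python
import numpy as np


def orphan_feature_bytes(src, dst, num_nodes, feat_dim, dtype=np.float32):
    """Bytes of a (num_nodes, feat_dim) feature tensor used by orphan nodes."""
    # Nodes touched by at least one edge, as source or destination
    connected = np.union1d(np.asarray(src), np.asarray(dst))
    num_orphans = num_nodes - connected.size
    return num_orphans * feat_dim * np.dtype(dtype).itemsize


# Graph g above: 100 nodes, edges 1->4, 2->5, 3->6, so 94 orphans.
print(orphan_feature_bytes([1, 2, 3], [4, 5, 6], num_nodes=100, feat_dim=16))  # 6016
```

The waste grows linearly with both the number of orphan nodes and the feature width, which matters most on graphs with millions of nodes.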


Thanks for this note! That’s very helpful to know. In my case, my network has over 10 million nodes and over 100 million edges, so part of why I excluded poorly connected nodes is to save space/memory and to improve compute time (and in my case all nodes have at least one edge). But this is great to know for other cases in which orphan nodes may exist. Thank you again!