Trying to understand the output of `EdgeDataLoader` with negative sampling

I’m reading the example code of graphsage. I’m trying to understand the behaviour of EdgeDataLoader with negative sampling so I tried it on a testing graph. However, I’m confused about the output generated by this data loader.

The graph for testing:

import networkx as nx
import torch as th
import dgl
import numpy as np

def build_karate_club_graph():
    """Build the karate-club test graph as a DGL graph.

    ``src`` and ``dst`` list one endpoint pair per undirected edge
    (node IDs range from 0 to 33). Every pair is inserted in both
    directions so the resulting directed graph is symmetric.

    Returns:
        The DGL graph built from the bidirectional edge list.
    """
    src = np.array([1, 2, 2, 3, 3, 3, 4, 5, 6, 6, 6, 7, 7, 7, 7, 8, 8, 9, 10, 10,
        10, 11, 12, 12, 13, 13, 13, 13, 16, 16, 17, 17, 19, 19, 21, 21,
        25, 25, 27, 27, 27, 28, 29, 29, 30, 30, 31, 31, 31, 31, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33,
        33, 33, 33, 33, 33, 33, 33, 33, 33, 33])
    dst = np.array([0, 0, 1, 0, 1, 2, 0, 0, 0, 4, 5, 0, 1, 2, 3, 0, 2, 2, 0, 4,
        5, 0, 0, 3, 0, 1, 2, 3, 5, 6, 0, 1, 0, 1, 0, 1, 23, 24, 2, 23,
        24, 2, 23, 26, 1, 8, 0, 24, 25, 28, 2, 8, 14, 15, 18, 20, 22, 23,
        29, 30, 31, 8, 9, 13, 14, 15, 18, 19, 20, 22, 23, 26, 27, 28, 29, 30,
        31, 32])
    # Append the reversed pairs so each edge exists in both directions.
    u = np.concatenate([src, dst])
    v = np.concatenate([dst, src])
    # dgl.graph(data) is the recommended factory; constructing via
    # dgl.DGLGraph(data) is deprecated and emits a warning.
    return dgl.graph((u, v))

def plotGraph(G):
    """Draw ``G`` with node labels on the current matplotlib figure.

    The DGL graph is converted to a networkx graph, positioned with the
    Kamada-Kawai layout, and rendered with light-grey nodes.
    """
    directed_view = G.to_networkx().to_directed()
    layout = nx.kamada_kawai_layout(directed_view)
    light_grey = [[.7, .7, .7]]
    nx.draw(directed_view, layout, with_labels=True, node_color=light_grey)

# show the graph for clear relation check
import matplotlib.pyplot as plt
# Create an 8x8 figure that plotGraph will draw into.
plt.figure(figsize=(8,8))

# Build the test graph once; G is reused by the sampler and dataloader snippets.
G = build_karate_club_graph()
plotGraph(G)

Here is the negative sampler (the one in the graphsage unsupervised code):

class NegativeSampler(object):
    """Callable that draws ``k`` negative destinations per positive edge.

    Destination nodes are sampled with replacement, with probability
    proportional to ``in_degree ** 0.75`` of the graph given at
    construction time.

    Args:
        g: Graph whose in-degrees define the sampling weights.
        k: Number of negative samples per positive edge.
        neg_share: If true (and the batch size is a multiple of ``k``),
            every group of ``k`` consecutive positive edges shares the same
            ``k`` sampled destinations instead of drawing independent ones.
    """

    def __init__(self, g, k, neg_share=False):
        in_deg = g.in_degrees().float()
        self.weights = in_deg.pow(0.75)
        self.k = k
        self.neg_share = neg_share

    def __call__(self, g, eids):
        """Return ``(src, dst)`` tensors describing the negative edges.

        ``src`` holds the source node of every edge in ``eids`` repeated
        ``k`` times; ``dst`` holds the sampled destination nodes.
        """
        pos_src, _ = g.find_edges(eids)
        num_pos = len(pos_src)
        share_negatives = self.neg_share and num_pos % self.k == 0
        if share_negatives:
            # Draw one destination per positive edge, then tile each group
            # of k destinations across the k sources of that group.
            drawn = self.weights.multinomial(num_pos, replacement=True)
            grouped = drawn.view(-1, 1, self.k)
            neg_dst = grouped.expand(-1, self.k, -1).flatten()
        else:
            # Independent draws: k destinations for every positive edge.
            neg_dst = self.weights.multinomial(self.k * num_pos, replacement=True)
        neg_src = pos_src.repeat_interleave(self.k)
        return neg_src, neg_dst

I created the dataloader by the following code:

# edges to compute output
train_eids = th.tensor([1, 2])
# Print the (src, dst) endpoints of the training edges. A single print is
# enough: wrapping it in a second print() would also emit a stray "None".
print("edges to compute output, src, dst are: ", G.find_edges(train_eids))

fanouts = [1]  # List of neighbors to sample for each GNN layer, let's say just one layer and one neighbor.

sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
negative_sampler = NegativeSampler(G, 2)  # 2 negative samples per positive

# define the dataloader:
# each minibatch holds one positive edge (batch_size=1) plus its negatives.
dataloader = dgl.dataloading.EdgeDataLoader(
    G,
    train_eids,
    sampler,
    exclude=None,
    negative_sampler=negative_sampler,
    batch_size=1,
    shuffle=True,
    drop_last=False,
    pin_memory=True)

I’m testing the dataloader by:

# Iterate over the minibatches and dump what the dataloader yields.
# NOTE: pos_graph, neg_graph and the blocks are relabeled, so .nodes() and
# .edges() below print minibatch-local IDs rather than node/edge IDs of G.
for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
    # Check that pos_graph and neg_graph expose the same (local) node IDs.
    assert sum(th.eq(pos_graph.nodes(), neg_graph.nodes())) == \
        neg_graph.number_of_nodes() == \
        pos_graph.number_of_nodes()
    print("************ step-{} **********".format(step))
    print("input_nodes: ", input_nodes)
    print("pos_graph {} edges: ".format(pos_graph.number_of_edges()), pos_graph.edges())
    print("pos_graph {} nodes: ".format(pos_graph.number_of_nodes()), pos_graph.nodes())
    # neg_graph edges number == number defined in NegativeSampler()
    print("neg_graph {} edges: ".format(neg_graph.number_of_edges()), neg_graph.edges())
    print("neg_graph {} nodes: ".format(neg_graph.number_of_nodes()), neg_graph.nodes())
    for b in blocks:    
        # number_of_nodes() counts input-side and output-side nodes together,
        # while nodes("_U") lists only the input side, so the two differ.
        print("\tblock nodes number: ", b.number_of_nodes())
        print("\tblock nodes: ", b.nodes("_U"))

I’ve run the last code block several times. One of the outputs is as follows:

************ step-0 **********
input_nodes:  tensor([ 2,  1, 21, 26, 32, 19, 29])
pos_graph 1 edges:  (tensor([0]), tensor([1]))
pos_graph 4 nodes:  tensor([0, 1, 2, 3])
neg_graph 2 edges:  (tensor([0, 0]), tensor([2, 3]))
neg_graph 4 nodes:  tensor([0, 1, 2, 3])
	block nodes number:  11
	block nodes:  tensor([0, 1, 2, 3, 4, 5, 6])
************ step-1 **********
input_nodes:  tensor([ 2,  0,  6, 24, 28, 17,  5, 27])
pos_graph 1 edges:  (tensor([0]), tensor([1]))
pos_graph 4 nodes:  tensor([0, 1, 2, 3])
neg_graph 2 edges:  (tensor([0, 0]), tensor([2, 3]))
neg_graph 4 nodes:  tensor([0, 1, 2, 3])
	block nodes number:  12
	block nodes:  tensor([0, 1, 2, 3, 4, 5, 6, 7])

There are some results that confuse me:

  • From which graph are pos_graph and neg_graph sampled: G, or the blocks in each step?
  • Why is pos_graph.edges() always (0, 1)? The edge between 0 and 1 is surely not the only positive edge in the given block or the graph.
  • The nodes in neg_graph are always the same as in pos_graph. Does this mean the eids parameter of NegativeSampler's __call__ function is the sampled positive eids? And what about the g parameter: is it the entire original graph G, or just the sampled blocks?
  • For the sampled blocks, why does the number of nodes in b.nodes("_U") not equal b.number_of_nodes()?

Could somebody give me some explanation? Thanks for your attention!

The nodes and edges of all sampled graphs are relabeled, meaning that g.nodes() and g.edges() alone no longer reflect the node and edge IDs of the original graph.

The correct way to check the node/edge IDs sampled in pos_graph, neg_graph and blocks is this:

# Iterate over the minibatches and print everything in terms of the node/edge
# IDs of the original graph (the sampled graphs themselves are relabeled).
for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
    # Check that pos_graph and neg_graph expose the same (local) node IDs.
    assert sum(th.eq(pos_graph.nodes(), neg_graph.nodes())) == \
        neg_graph.number_of_nodes() == \
        pos_graph.number_of_nodes()
    print("************ step-{} **********".format(step))
    print("input_nodes: ", input_nodes)
    print("pos_graph {} edges: ".format(pos_graph.number_of_edges()), pos_graph.edata[dgl.EID])
    print("pos_graph {} nodes: ".format(pos_graph.number_of_nodes()), pos_graph.ndata[dgl.NID])
    # neg_graph edges number == number defined in NegativeSampler()
    # Negative edges do not exist in the original graph, so neg_graph has no
    # dgl.EID edge data (looking it up raises KeyError: '_ID'). Report each
    # negative edge by its endpoints, mapped back to original node IDs.
    neg_nids = neg_graph.ndata[dgl.NID]
    neg_src, neg_dst = neg_graph.edges()
    print("neg_graph {} edges: ".format(neg_graph.number_of_edges()), (neg_nids[neg_src], neg_nids[neg_dst]))
    print("neg_graph {} nodes: ".format(neg_graph.number_of_nodes()), neg_nids)
    for b in blocks:
        # A block is bipartite: input (src) and output (dst) nodes are counted separately.
        print("\tblock nodes number: ", b.number_of_src_nodes(), b.number_of_dst_nodes())
        print("\tblock input nodes: ", b.srcdata[dgl.NID])
        print("\tblock output nodes: ", b.dstdata[dgl.NID])
        print("\tblock edges: ", b.edata[dgl.EID])

For your individual questions:

They are sampled from the original graph G. The blocks indicate how messages propagate in each layer to compute the representations of the nodes sampled in pos_graph and neg_graph.

That’s because the nodes in pos_graph are relabeled. You can recover the corresponding node IDs in the original graph by pos_graph.ndata[dgl.NID].

Correct.

That’s the original graph g.

That’s because b is actually a bipartite graph. Currently we have a bug handling calls like b.nodes("_U") which only shows the nodes on the input side, whereas b.number_of_nodes() counts the nodes from both input side and output side. You can check the number of input nodes and output nodes via b.number_of_src_nodes() and b.number_of_dst_nodes().

1 Like

Thanks for the detailed answer. However, when I run the code you gave:

for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):  # NOTE: the neg_graph.edata[dgl.EID] lookup below raises KeyError: '_ID'
    assert sum(th.eq(pos_graph.nodes(), neg_graph.nodes())) == \
        neg_graph.number_of_nodes() == \
        pos_graph.number_of_nodes()
    print("************ step-{} **********".format(step))
    print("input_nodes: ", input_nodes)
    print("pos_graph {} edges: ".format(pos_graph.number_of_edges()), pos_graph.edata[dgl.EID])
    print("pos_graph {} nodes: ".format(pos_graph.number_of_nodes()), pos_graph.ndata[dgl.NID])
    # neg_graph edges number == number defined in NegativeSampler()
    print("neg_graph {} edges: ".format(neg_graph.number_of_edges()), neg_graph.edata[dgl.EID])
    print("neg_graph {} nodes: ".format(neg_graph.number_of_nodes()), neg_graph.ndata[dgl.NID])
    for b in blocks:    
        print("\tblock nodes number: ", b.number_of_src_nodes(), b.number_of_dst_nodes())
        print("\tblock input nodes: ", b.srcdata[dgl.NID])
        print("\tblock output nodes: ", b.dstdata[dgl.NID])
        print("\tblock edges: ", b.edata[dgl.EID])

I got:

************ step-0 **********
input_nodes:  tensor([ 2,  1, 33,  7, 28, 21])
pos_graph 1 edges:  tensor([2])
pos_graph 4 nodes:  tensor([ 2,  1, 33,  7])
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-52-6cbf1bc241aa> in <module>
      8     print("pos_graph {} nodes: ".format(pos_graph.number_of_nodes()), pos_graph.ndata[dgl.NID])
      9     # neg_graph edges number == number defined in NegativeSampler()
---> 10     print("neg_graph {} edges: ".format(neg_graph.number_of_edges()), neg_graph.edata[dgl.EID])
     11     print("neg_graph {} nodes: ".format(neg_graph.number_of_nodes()), neg_graph.ndata[dgl.NID])
     12     for b in blocks:

~/miniconda3/lib/python3.8/site-packages/dgl/view.py in __getitem__(self, key)
    180             return ret
    181         else:
--> 182             return self._graph._get_e_repr(self._etid, self._edges)[key]
    183 
    184     def __setitem__(self, key, val):

KeyError: '_ID'

It seems that the neg_graph doesn’t have the value for dgl.EID in its edata_schemes while pos_graph has. I printed out the pos_graph and neg_graph and I got:

>>> print(pos_graph)
Graph(num_nodes=4, num_edges=1,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> print(neg_graph)
Graph(num_nodes=4, num_edges=2,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})

So why doesn’t neg_graph have that scheme?

Whoops, my bad. neg_graph isn’t supposed to have edge IDs from the original graph, because its edges do not exist in the original graph (hence “negative”).

1 Like

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.