Lack of reproducibility in PinSAGE sampler

Hi!
I am using the PinSAGE sampler to build my frontiers for each layer of convolution. However, I seem to have found a reproducibility issue. Even though I set dgl.seed(47), two frontiers generated from the same seeds yield different subgraphs. Please let me know if I am somehow mistaken. Thank you in advance!

import torch 
import dgl 

#set seed
dgl.seed(47)

#generate synthetic graph
# Node-ID ranges and edge count of the random playlist/track graph.
# num_playlists matches the upper bound the source IDs are actually drawn from.
num_playlists = 100
num_tracks = 10_000
num_edges = 12_000
# Draw the edge endpoints from the declared ranges so the constants and the
# generated edges cannot drift apart.
src = torch.randint(0, num_playlists, (num_edges, ))
dst = torch.randint(0, num_tracks, (num_edges, ))
data_dict = {
     ('playlist', 'contains', 'track'): (src, dst),
     ('track', 'contained_by', 'playlist'): (dst, src),
 }
g = dgl.heterograph(data_dict)
g 
>> Graph(num_nodes={'playlist': 100, 'track': 10000},
      num_edges={('playlist', 'contains', 'track'): 12000, ('track', 'contained_by', 'playlist'): 12000},
      metagraph=[('playlist', 'track', 'contains'), ('track', 'playlist', 'contained_by')])

#set hyperparams
n_params = {'RANDOM_WALK_LENGTH': 2,
 'RANDOM_WALK_RESTART_PROB': 0.5,
 'NUM_RANDOM_WALKS': 10,
 'NUM_NEIGHBORS': 2}
seeds = torch.randint(0, 10_000, (32,))

#helper function for transferring NIDs to subgraph
def compact_and_copy(frontier, seeds):
    """Turn ``frontier`` into a block on ``seeds`` and carry over its edge data.

    Every edge feature of ``frontier`` other than ``dgl.EID`` is indexed by the
    block's ``dgl.EID`` values and stored on the block under the same name.
    """
    block = dgl.to_block(frontier, seeds)
    for name, values in frontier.edata.items():
        if name != dgl.EID:
            # Select the frontier rows that correspond to the block's edges.
            block.edata[name] = values[block.edata[dgl.EID]]
    return block

#initialize sampler 
sampler = dgl.sampling.PinSAGESampler(g, 'track', 'playlist', n_params["RANDOM_WALK_LENGTH"],
                                        n_params["RANDOM_WALK_RESTART_PROB"],
                                        n_params["NUM_RANDOM_WALKS"], 
                                        n_params["NUM_NEIGHBORS"])

#generate frontiers
frontier1 = sampler(seeds)
block1 = compact_and_copy(frontier1, seeds)
block1.ndata[dgl.NID]['track']
>>> tensor([  97, 2227, 1656, 4462, 7843, 7091, 3473, 6914, 2797, 4342, 9333, 6093,
        5736, 7377, 1065, 2359, 7149, 3757, 7067, 6887, 5156, 3322,   30,  607,
        5926, 7596,  117, 6910, 3329, 4929, 6566, 6609, 9554, 9222, 8764, 8195,
        9532, 8726, 5486, 4103, 9586, 8113, 7547, 9555, 9432, 8848, 9264, 8633,
        9862, 9593, 9757, 9262, 6972, 2441, 9673, 9356, 8859, 8181, 9292, 8135,
        9321, 2960, 9217, 7991, 9392, 8555, 1976, 9544, 9944, 9737])

frontier2 = sampler(seeds)
block2 = compact_and_copy(frontier2, seeds)
block2.ndata[dgl.NID]['track']
>>> tensor([  97, 2227, 1656, 4462, 7843, 7091, 3473, 6914, 2797, 4342, 9333, 6093,
        5736, 7377, 1065, 2359, 7149, 3757, 7067, 6887, 5156, 3322,   30,  607,
        5926, 7596,  117, 6910, 3329, 4929, 6566, 6609, 5340, 9554, 9297, 9136,
        9866, 9532, 8070, 7530, 5780, 8126, 9694, 9496, 9581, 9189, 9841, 9691,
        9573, 9406, 9262, 1253, 2985, 9321, 9870, 8664, 9291, 9286, 9639, 9578,
        8721, 8474, 9186, 8448, 9964, 9201, 9894, 9348, 8277, 6649])

block1.ndata[dgl.NID]['track'] == block2.ndata[dgl.NID]['track']
>>> tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

(block1.ndata[dgl.NID]['track'] == block2.ndata[dgl.NID]['track']).all()
>>> tensor(False) 

As you can see here, despite the fact that the seeds didn’t change, the two frontiers are not exactly equal. This causes problems with reproducibility (and leads to instability in final performance results). Perhaps there is a way to set the RNG of the sampler?

Thank you!

It seems that you did not reset the seed between the sampler calls. So essentially the random number generator state will still be different across calls. Could you try resetting the seed between the sampler calls?

PinSAGE uses random_walk which is multi-threaded, which also contributes to non-determinism. You could also try setting the environment variable OMP_NUM_THREADS=1.

Follow up on your advice @BarclayII :
I added dgl.seed() before the second call to sampler() and it worked. Setting os.environ["OMP_NUM_THREADS"] was not necessary and, when used alone, didn’t fix the issue.
For other people’s ease, I’ll put the exact code block here.

import torch 
import dgl

# Node-ID ranges and edge count of the random playlist/track graph.
# num_playlists matches the upper bound the source IDs are actually drawn from.
num_playlists = 100
num_tracks = 10_000
num_edges = 12_000
# Draw the edge endpoints from the declared ranges so the constants and the
# generated edges cannot drift apart.
src = torch.randint(0, num_playlists, (num_edges, ))
dst = torch.randint(0, num_tracks, (num_edges, ))
data_dict = {
     ('playlist', 'contains', 'track'): (src, dst),
     ('track', 'contained_by', 'playlist'): (dst, src),
 }
g = dgl.heterograph(data_dict)

n_params = {'RANDOM_WALK_LENGTH': 2,
 'RANDOM_WALK_RESTART_PROB': 0.5,
 'NUM_RANDOM_WALKS': 10,
 'NUM_NEIGHBORS': 2}
seeds = torch.randint(0, 10_000, (32,))

def compact_and_copy(frontier, seeds):
    """Build a block from ``frontier`` for ``seeds`` and copy its edge features.

    ``dgl.EID`` itself is skipped; every other edge feature is gathered with
    the block's ``dgl.EID`` values and written to the block's edge data.
    """
    block = dgl.to_block(frontier, seeds)
    for name, values in frontier.edata.items():
        if name != dgl.EID:
            # Select the frontier rows that correspond to the block's edges.
            block.edata[name] = values[block.edata[dgl.EID]]
    return block

dgl.seed(47)
sampler = dgl.sampling.PinSAGESampler(g, 'track', 'playlist', n_params["RANDOM_WALK_LENGTH"],
                                        n_params["RANDOM_WALK_RESTART_PROB"],
                                        n_params["NUM_RANDOM_WALKS"], 
                                        n_params["NUM_NEIGHBORS"])
frontier1 = sampler(seeds)
block1 = compact_and_copy(frontier1, seeds)
block1.ndata[dgl.NID]['track']
>>> tensor([3528, 1600, 4994, 2139, 5830, 3582, 8941, 9729, 9699, 4427, 8430, 9435,
        5141, 8005,  705, 4810, 6071, 7196, 3497, 8730,  574, 5666, 2404, 8546,
        6621, 1123, 6664,  708, 9757, 6052, 1680, 2151, 9989, 8108, 8703,  913,
        9382, 9697, 9273, 6241, 9101, 9260, 8670, 3230, 9189, 2345, 9624, 9876,
        9391, 9192, 8586, 7053, 8552, 8347, 5474, 8449, 7119, 4158, 2120, 6479,
        5673, 9236, 9764, 9871, 9076, 9854, 6972, 9765, 9343, 8627, 6871, 2654,
        9668, 1806, 9921])

dgl.seed(47)
frontier2 = sampler(seeds)
block2 = compact_and_copy(frontier2, seeds)
block2.ndata[dgl.NID]['track']
>>> tensor([3528, 1600, 4994, 2139, 5830, 3582, 8941, 9729, 9699, 4427, 8430, 9435,
        5141, 8005,  705, 4810, 6071, 7196, 3497, 8730,  574, 5666, 2404, 8546,
        6621, 1123, 6664,  708, 9757, 6052, 1680, 2151, 9989, 8108, 8703,  913,
        9382, 9697, 9273, 6241, 9101, 9260, 8670, 3230, 9189, 2345, 9624, 9876,
        9391, 9192, 8586, 7053, 8552, 8347, 5474, 8449, 7119, 4158, 2120, 6479,
        5673, 9236, 9764, 9871, 9076, 9854, 6972, 9765, 9343, 8627, 6871, 2654,
        9668, 1806, 9921])

(block1.ndata[dgl.NID]['track'] == block2.ndata[dgl.NID]['track']).all()
>>> tensor(True)

In terms of using this inside the PinSage Neighborhood sampler (see the GitHub repo for the example), does this mean that I need to add dgl.seed(self.seed) (to set it each time the sampler is called)?

e.g.

class NeighborSampler(object):
    """Stack of PinSAGE samplers that yields one block per layer.

    One ``dgl.sampling.PinSAGESampler`` is created per layer, and
    ``sample_blocks`` reseeds DGL with ``self.seed`` before each sampler call.
    """

    def __init__(self, g, user_type, item_type, random_walk_length, random_walk_restart_prob,
                 num_random_walks, num_neighbors, num_layers):
        self.g = g
        self.seed = cfg.RUN.SEED
        self.user_type = user_type
        self.item_type = item_type
        # Use the first edge type the metagraph lists in each direction.
        metagraph = g.metagraph()
        self.user_to_item_etype = list(metagraph[user_type][item_type])[0]
        self.item_to_user_etype = list(metagraph[item_type][user_type])[0]
        self.samplers = []
        for _ in range(num_layers):
            layer_sampler = dgl.sampling.PinSAGESampler(
                g, item_type, user_type, random_walk_length,
                random_walk_restart_prob, num_random_walks, num_neighbors)
            self.samplers.append(layer_sampler)

    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None, remove_self_loops=False):
        """Sample one block per layer, returned in reverse sampling order.

        When ``heads`` is given, the frontier edges between the (heads, tails)
        and (heads, neg_tails) node pairs are removed before the frontier is
        converted to a block. ``remove_self_loops`` is accepted but not used.
        """
        layers = []
        for sampler in self.samplers:
            # Reset DGL's RNG before every sampler call.
            dgl.seed(self.seed)
            frontier = sampler(seeds)
            if heads is not None:
                pair_src = torch.cat([heads, heads])
                pair_dst = torch.cat([tails, neg_tails])
                eids = frontier.edge_ids(pair_src, pair_dst, return_uv=True)[2]
                if len(eids) > 0:
                    frontier = dgl.remove_edges(frontier, eids)
            block = compact_and_copy(frontier, seeds)
            # The next sampler starts from this block's source nodes.
            seeds = block.srcdata[dgl.NID]
            layers.append(block)
        # Blocks were collected in sampling order; hand them back reversed.
        layers.reverse()
        return layers

Another follow up. Actually, the issue doesn’t seem to be fixed.

While re-initializing dgl.seed() within the same run produces the same results (as shown above), running the script twice as separate processes produces different results (as shown below). That is, two runs produce different subgraphs. I think this is problematic: not only are the results not reproducible, but we also cannot compare the performance of two runs, since any difference could be attributed to the neighbour selection. Please let me know if I am mistaken.

e.g.

import dgl
import torch 
import sys 
import pickle

#Initialize graph
# Node-ID ranges and edge count of the random playlist/track graph.
# num_playlists matches the upper bound the source IDs are actually drawn from.
num_playlists = 100
num_tracks = 10_000
num_edges = 12_000
# Draw the edge endpoints from the declared ranges so the constants and the
# generated edges cannot drift apart.
src = torch.randint(0, num_playlists, (num_edges, ))
dst = torch.randint(0, num_tracks, (num_edges, ))
data_dict = {
     ('playlist', 'contains', 'track'): (src, dst),
     ('track', 'contained_by', 'playlist'): (dst, src),
 }
g = dgl.heterograph(data_dict)

#Generate seeds
seeds = torch.tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
         751, 3074, 3552])
#Build Sampler 
def compact_and_copy(frontier, seeds):
    """Turn ``frontier`` into a block on ``seeds`` and carry over its edge data.

    Every edge feature of ``frontier`` other than ``dgl.EID`` is indexed by the
    block's ``dgl.EID`` values and stored on the block under the same name.
    """
    block = dgl.to_block(frontier, seeds)
    for name, values in frontier.edata.items():
        if name != dgl.EID:
            # Select the frontier rows that correspond to the block's edges.
            block.edata[name] = values[block.edata[dgl.EID]]
    return block


class NeighborSampler(object):
    """Single PinSAGE sampler that reseeds DGL on every ``sample_blocks`` call."""

    def __init__(self, g, user_type, item_type, random_walk_length, random_walk_restart_prob,
                 num_random_walks, num_neighbors):
        self.g = g
        self.seed = 47
        self.user_type = user_type
        self.item_type = item_type
        # Use the first edge type the metagraph lists in each direction.
        metagraph = g.metagraph()
        self.user_to_item_etype = list(metagraph[user_type][item_type])[0]
        self.item_to_user_etype = list(metagraph[item_type][user_type])[0]
        self.sampler = dgl.sampling.PinSAGESampler(
            g, item_type, user_type, random_walk_length,
            random_walk_restart_prob, num_random_walks, num_neighbors)

    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None, remove_self_loops=False):
        """Sample one frontier around ``seeds`` and return it as a one-block list.

        When ``heads`` is given, the frontier edges between the (heads, tails)
        and (heads, neg_tails) node pairs are removed before the frontier is
        converted to a block. ``remove_self_loops`` is accepted but not used.
        """
        # Reset DGL's RNG before drawing the frontier.
        dgl.seed(self.seed)
        frontier = self.sampler(seeds)
        if heads is not None:
            pair_src = torch.cat([heads, heads])
            pair_dst = torch.cat([tails, neg_tails])
            eids = frontier.edge_ids(pair_src, pair_dst, return_uv=True)[2]
            if len(eids) > 0:
                frontier = dgl.remove_edges(frontier, eids)
        block = compact_and_copy(frontier, seeds)
        # Rebinding of seeds is not used after this point.
        seeds = block.srcdata[dgl.NID]
        return [block]

sampler1 = NeighborSampler(g, 'playlist', 'track', 2, 0.5, 10, 2)
b1 = sampler1.sample_blocks(seeds)
pickle.dump(b1, open('./b{}.pkl'.format(sys.argv[-1]), "wb"))

where the blocks data is:

b1
>>> (tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
          751, 3074, 3552, 5384, 8382, 9445, 9384, 9412, 9133, 9578, 9282, 8813,
         8723, 8144, 1082, 2430, 9023, 8392, 8166, 4597, 9928, 9795, 9417]),
b2
>>>tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
          751, 3074, 3552, 3625, 9748, 8235, 9870, 9913, 9864, 4262, 8986, 8491,
         4922, 8832, 8193, 8929, 8848, 7117, 9539, 5346, 5071, 7599, 9778]))

As you can see, the first 15 IDs (the seeds) are identical in both runs, but the remaining node IDs differ.
Thank you very much

Did you fix the other seeds? Typically you also need to fix numpy and PyTorch seeds.

Yes. Strangely, after adding the seed settings to the code, the output alternates between the same two results from run to run (see code and output below). Is it possible that I’m making some sort of mistake in the initialization? Thank you in advance for your help :slight_smile:

Output:

RUN1:tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
         751, 3074, 3552, 3824, 9687, 9401, 9002, 3434, 8508, 9477, 9293, 3734,
        9481, 8619, 9660, 7915, 8455, 9675, 8702, 8183, 7977, 9955, 9829, 9933,
        9369, 9916, 7646, 6604, 9768])

RUN2:tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
         751, 3074, 3552, 7509, 7862, 1337, 9461, 9995, 9720, 9200, 9022, 8714,
        8189,  118, 8358, 8148, 8138, 7825, 1484, 9732, 8846, 9861, 9708, 8611,
        7735, 8466, 1990, 5961, 8639])

RUN3:tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
         751, 3074, 3552, 7509, 7862, 1337, 9461, 9995, 9720, 9200, 9022, 8714,
        8189,  118, 8358, 8148, 8138, 7825, 1484, 9732, 8846, 9861, 9708, 8611,
        7735, 8466, 1990, 5961, 8639])

RUN4:tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
         751, 3074, 3552, 3824, 9687, 9401, 9002, 3434, 8508, 9477, 9293, 3734,
        9481, 8619, 9660, 7915, 8455, 9675, 8702, 8183, 7977, 9955, 9829, 9933,
        9369, 9916, 7646, 6604, 9768])

Code:

import dgl
import torch 
import sys 
import pickle
import numpy as np 
import random 

#Fixing all seeds 
# Give every random number generator used by the script the same seed value.
# DGL's own RNG:
dgl.seed(47)
# PyTorch's default generator, then all CUDA devices:
torch.manual_seed(47)
torch.cuda.manual_seed_all(47)   
# Python's built-in random module:
random.seed(47)
# Current CUDA device (already covered by manual_seed_all above):
torch.cuda.manual_seed(47)
# NumPy's global RNG:
np.random.seed(47)
# Ask cuDNN for deterministic algorithms and turn off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#Initialize graph
# Node-ID ranges and edge count of the random playlist/track graph.
# num_playlists matches the upper bound the source IDs are actually drawn from.
num_playlists = 100
num_tracks = 10_000
num_edges = 12_000
# Draw the edge endpoints from the declared ranges so the constants and the
# generated edges cannot drift apart.
src = torch.randint(0, num_playlists, (num_edges, ))
dst = torch.randint(0, num_tracks, (num_edges, ))
data_dict = {
     ('playlist', 'contains', 'track'): (src, dst),
     ('track', 'contained_by', 'playlist'): (dst, src),
 }
g = dgl.heterograph(data_dict)

#Generate seed nodes 
seeds = torch.tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
         751, 3074, 3552])

#Build Sampler 
def compact_and_copy(frontier, seeds):
    """Build a block from ``frontier`` for ``seeds`` and copy its edge features.

    ``dgl.EID`` itself is skipped; every other edge feature is gathered with
    the block's ``dgl.EID`` values and written to the block's edge data.
    """
    block = dgl.to_block(frontier, seeds)
    for name, values in frontier.edata.items():
        if name != dgl.EID:
            # Select the frontier rows that correspond to the block's edges.
            block.edata[name] = values[block.edata[dgl.EID]]
    return block

class NeighborSampler(object):
    """Single PinSAGE sampler that reseeds DGL on every ``sample_blocks`` call."""

    def __init__(self, g, user_type, item_type, random_walk_length, random_walk_restart_prob,
                 num_random_walks, num_neighbors):
        self.g = g
        self.seed = 47
        self.user_type = user_type
        self.item_type = item_type
        # Use the first edge type the metagraph lists in each direction.
        metagraph = g.metagraph()
        self.user_to_item_etype = list(metagraph[user_type][item_type])[0]
        self.item_to_user_etype = list(metagraph[item_type][user_type])[0]
        self.sampler = dgl.sampling.PinSAGESampler(
            g, item_type, user_type, random_walk_length,
            random_walk_restart_prob, num_random_walks, num_neighbors)

    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None, remove_self_loops=False):
        """Sample one frontier around ``seeds`` and return it as a one-block list.

        When ``heads`` is given, the frontier edges between the (heads, tails)
        and (heads, neg_tails) node pairs are removed before the frontier is
        converted to a block. ``remove_self_loops`` is accepted but not used.
        """
        # Reset DGL's RNG before drawing the frontier.
        dgl.seed(self.seed)
        frontier = self.sampler(seeds)
        if heads is not None:
            pair_src = torch.cat([heads, heads])
            pair_dst = torch.cat([tails, neg_tails])
            eids = frontier.edge_ids(pair_src, pair_dst, return_uv=True)[2]
            if len(eids) > 0:
                frontier = dgl.remove_edges(frontier, eids)
        block = compact_and_copy(frontier, seeds)
        # Rebinding of seeds is not used after this point.
        seeds = block.srcdata[dgl.NID]
        return [block]

sampler1 = NeighborSampler(g, 'playlist', 'track', 2, 0.5, 10, 2)
b1 = sampler1.sample_blocks(seeds)
print("RUN{}:{}".format(sys.argv[-1], b1[0].ndata[dgl.NID]['track'])) 

Did setting OMP_NUM_THREADS=1 help in this case?

No. I tried this as well using:

import os 
# Request a single OpenMP thread through the process environment.
os.environ["OMP_NUM_THREADS"] = "1"
# NOTE(review): the line below only binds a Python variable named
# OMP_NUM_THREADS; it does not modify os.environ.
OMP_NUM_THREADS=1

and it still alternated between the two outputs. Is this what you had in mind?

Also, just to see how frequently the two versions oscillate I launched 100 runs. 42/100 had the first output version and the remaining 58 had the other output version.

It turns out that the number of threads needs to be set before numpy is imported (see this Stack Overflow answer).

The following code snippet will produce deterministic sampling:

import os
# Remove multi-threading: restrict OpenMP to a single thread.
# This must run before the numerical libraries below are imported, because the
# setting has to be in place before numpy is loaded.
os.environ["OMP_NUM_THREADS"] = "1"

import dgl
import torch 
import sys 
import pickle
import numpy as np 
import random 


#Fixing all seeds 
# Give every random number generator used by the script the same seed value.
# DGL's own RNG:
dgl.seed(47)
# NumPy's global RNG:
np.random.seed(47)
# PyTorch's default generator, then all CUDA devices and the current one:
torch.manual_seed(47)
torch.cuda.manual_seed_all(47)
torch.cuda.manual_seed(47)   
# Python's built-in random module:
random.seed(47)
# Turn cuDNN off and set its deterministic flag.
torch.backends.cudnn.enabled=False
torch.backends.cudnn.deterministic=True


#Initialize graph
# Node-ID ranges and edge count of the random playlist/track graph.
# num_playlists matches the upper bound the source IDs are actually drawn from.
num_playlists = 100
num_tracks = 10_000
num_edges = 12_000
# Draw the edge endpoints from the declared ranges so the constants and the
# generated edges cannot drift apart.
src = torch.randint(0, num_playlists, (num_edges, ))
dst = torch.randint(0, num_tracks, (num_edges, ))
data_dict = {
     ('playlist', 'contains', 'track'): (src, dst),
     ('track', 'contained_by', 'playlist'): (dst, src),
 }
g = dgl.heterograph(data_dict)

#Generate seed nodes 
seeds = torch.tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
         751, 3074, 3552])

#Build Sampler 
def compact_and_copy(frontier, seeds):
    """Turn ``frontier`` into a block on ``seeds`` and carry over its edge data.

    Every edge feature of ``frontier`` other than ``dgl.EID`` is indexed by the
    block's ``dgl.EID`` values and stored on the block under the same name.
    """
    block = dgl.to_block(frontier, seeds)
    for name, values in frontier.edata.items():
        if name != dgl.EID:
            # Select the frontier rows that correspond to the block's edges.
            block.edata[name] = values[block.edata[dgl.EID]]
    return block

class NeighborSampler(object):
    """Single PinSAGE sampler that reseeds DGL on every ``sample_blocks`` call."""

    def __init__(self, g, user_type, item_type, random_walk_length, random_walk_restart_prob,
                 num_random_walks, num_neighbors):
        self.g = g
        self.seed = 47
        self.user_type = user_type
        self.item_type = item_type
        # Use the first edge type the metagraph lists in each direction.
        metagraph = g.metagraph()
        self.user_to_item_etype = list(metagraph[user_type][item_type])[0]
        self.item_to_user_etype = list(metagraph[item_type][user_type])[0]
        self.sampler = dgl.sampling.PinSAGESampler(
            g, item_type, user_type, random_walk_length,
            random_walk_restart_prob, num_random_walks, num_neighbors)

    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None, remove_self_loops=False):
        """Sample one frontier around ``seeds`` and return it as a one-block list.

        When ``heads`` is given, the frontier edges between the (heads, tails)
        and (heads, neg_tails) node pairs are removed before the frontier is
        converted to a block. ``remove_self_loops`` is accepted but not used.
        """
        # Reset DGL's RNG before drawing the frontier.
        dgl.seed(self.seed)
        frontier = self.sampler(seeds)
        if heads is not None:
            pair_src = torch.cat([heads, heads])
            pair_dst = torch.cat([tails, neg_tails])
            eids = frontier.edge_ids(pair_src, pair_dst, return_uv=True)[2]
            if len(eids) > 0:
                frontier = dgl.remove_edges(frontier, eids)
        block = compact_and_copy(frontier, seeds)
        # Rebinding of seeds is not used after this point.
        seeds = block.srcdata[dgl.NID]
        return [block]

sampler1 = NeighborSampler(g, 'playlist', 'track', 2, 0.5, 10, 2)
b1 = sampler1.sample_blocks(seeds)
nbrs = b1[0].ndata[dgl.NID]['track']

Output:

0 INCONSISTENCIES OUT OF 100 RUNS
RUNS details tensor([4188, 1259, 5869, 5113,  391, 9271,  702, 2867, 6661, 8017, 2573,  578,
         751, 3074, 3552, 7509, 7862, 1337, 9461, 9995, 9720, 9200, 9022, 8714,
        8189,  118, 8358, 8138, 7991, 7194, 9620, 8888, 7596, 9955, 7864, 9475,
        8945, 1442, 8414, 2364, 9987])

Thank you for your help!

1 Like

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.