Questions on GAT Link prediction with Pytorch Lightning

I have looked at the DGL docs and examples on link prediction (general, stochastic, and minibatch), but I am still confused about implementing it with GAT in DGL and PyTorch Lightning.

  1. I believe adding self-loops to the graph is helpful for GAT link prediction in both training and inference. With a stochastic model, where should I add the self-loops? @mufeili suggested last time to use a custom sampler, but I was still confused about how to change it.
    I am thinking of calling add_self_loop() on the subgraph returned by dgl.sampling.sample_neighbors inside a custom sampler, while keeping train_pos_g without these self-loops for computing the edge predictor scores.
  2. For the Lightning Trainer data loaders, I split the graph into train/test/val graphs with 80/10/10% of the edges, and then fed these split graphs to dgl.dataloading.DataLoader with negative_sampler.GlobalUniform for negative samples. I feel this will mistakenly mix positive val/test edges in as negative train edges and vice versa. How should I proceed with Lightning here?
  3. The stochastic link prediction example does take the entire graph in the dataloader, so the sampler can effectively draw negative edges from it, but where exactly are the validation and test edges going to come from?

In addition, it would be helpful to see code for how the custom samplers for positive and negative samples should look in the Lightning dataloader, such that the negative edges are truly negative in the entire original graph (like GlobalUniform),
and the positive edges include 5 neighbourhood edges of the respective train/val/test set, with self-loops always added to every node in the positive sample subgraphs.

  1. You could mostly copy the NeighborSampler implementation, with a few differences:
     def sample_blocks(self, g, seed_nodes, exclude_eids=None):
         output_nodes = seed_nodes
         blocks = []
         for fanout in reversed(self.fanouts):
             frontier = g.sample_neighbors(
                 seed_nodes, fanout, edge_dir=self.edge_dir, prob=self.prob,
                 replace=self.replace, output_device=self.output_device,
                 exclude_edges=exclude_eids)
              #eid = frontier.edata[EID]    <---- remove this, since the self-loops usually don't exist in the original graph
             frontier = dgl.add_self_loop(frontier)    # <---- add this
             block = to_block(frontier, seed_nodes)
             #block.edata[EID] = eid       <---- also remove this
             seed_nodes = block.srcdata[NID]
             blocks.insert(0, block)
    
         return seed_nodes, output_nodes, blocks
    
    Then you could replace NeighborSampler in your DataLoader definition with your new NeighborSampler.
  2. GlobalUniform will not return negative samples for node pairs that already have an edge between them. So if your graph contains the training, validation, and test edges at the same time, GlobalUniform will not sample the validation and test edges as negatives.
  3. If you are asking where to get validation and test edges, usually the validation and test edges should be provided by the dataset itself (if not, see the split sketch below).
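If your dataset does not ship such a split, one common transductive recipe is to randomly partition the edge IDs yourself, roughly along these lines (just a sketch; the helper name split_edges is mine):

import torch

def split_edges(g, val_frac=0.1, test_frac=0.1):
    # Randomly permute all edge IDs and carve out val/test portions (80/10/10 by default).
    eids = torch.randperm(g.num_edges())
    num_val = int(g.num_edges() * val_frac)
    num_test = int(g.num_edges() * test_frac)
    val_eids = eids[:num_val]
    test_eids = eids[num_val:num_val + num_test]
    train_eids = eids[num_val + num_test:]
    return train_eids, val_eids, test_eids
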
  1. For the Lightning Trainer data loaders, I split the graph into train/test/val graphs with 80/10/10% of the edges, and then fed these split graphs to dgl.dataloading.DataLoader with negative_sampler.GlobalUniform for negative samples. I feel this will mistakenly mix positive val/test edges in as negative train edges and vice versa. How should I proceed with Lightning here?

I don’t think you should exclude true val/test edges from the sampled negative edges, as such information is not available during training.

Yeah, I realised that late. Here is what I think now, wrapped in a Lightning dataloader:

neighbour_sampler = CustomNeighborSampler([4, 4], output_device=args.device)
negative_sampler = dgl.dataloading.negative_sampler.GlobalUniform(1)
self.train_sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler=neighbour_sampler, negative_sampler=negative_sampler)

def train_dataloader(self):
    return dgl.dataloading.DataLoader(
        self.g,
        # <-- the entire graph with all edges, for all train/val/test dataloaders
        torch.tensor(self.train_eids, device=args.device),
        # <-- the split positive edges of the corresponding train/val/test set
        self.train_sampler,
        device=self.device,
        batch_size=self.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=self.num_workers,
    )
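
A matching val_dataloader would then be analogous (a sketch, assuming self.val_eids holds the held-out validation edge IDs and reusing the same edge-prediction sampler):

def val_dataloader(self):
    # Same full graph, but seeded with the validation edges; no shuffling needed for evaluation.
    return dgl.dataloading.DataLoader(
        self.g,
        torch.tensor(self.val_eids, device=args.device),
        self.train_sampler,
        device=self.device,
        batch_size=self.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=self.num_workers,
    )
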
  1. Thanks.
  2. Yes, but in the DGL docs example the test set is not created from the input graph. I guess either the example assumes the problem is inductive and there are separate test and val graphs, or the transductive split has already been done and kept separate from the loaded graph, since the example feeds the entire graph's edges as the positive train edges to the dataloader.

I did a similar thing earlier from the docs, which also doesn't have the edata lines,
but somehow adding the add_self_loop line causes errors like:

File ~/miniconda3/envs/dgl/lib/python3.10/site-packages/torch/utils/data/dataloader.py:530, in _BaseDataLoaderIter.__next__(self)
    528 if self._sampler_iter is None:
    529     self._reset()
--> 530 data = self._next_data()
    531 self._num_yielded += 1
    532 if self._dataset_kind == _DatasetKind.Iterable and \
    533         self._IterableDataset_len_called is not None and \
    534         self._num_yielded > self._IterableDataset_len_called:

File ~/miniconda3/envs/dgl/lib/python3.10/site-packages/torch/utils/data/dataloader.py:570, in _SingleProcessDataLoaderIter._next_data(self)
    568 def _next_data(self):
    569     index = self._next_index()  # may raise StopIteration
--> 570     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    571     if self._pin_memory:
    572         data = _utils.pin_memory.pin_memory(data)

File ~/miniconda3/envs/dgl/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:40, in _IterableDatasetFetcher.fetch(self, possibly_batched_index)
     38 else:
     39     data = next(self.dataset_iter)
---> 40 return self.collate_fn(data)

File ~/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/dataloading/dataloader.py:519, in CollateWrapper.__call__(self, items)
    516 if self.use_uva:
    517     # Only copy the indices to the given device if in UVA mode.
    518     items = recursive_apply(items, lambda x: x.to(self.device))
--> 519 batch = self.sample_func(self.g, items)
    520 return recursive_apply(batch, remove_parent_storage_columns, self.g)

File ~/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/dataloading/base.py:420, in EdgePredictionSampler.sample(self, g, seed_edges)
    414 seed_nodes = pair_graph.ndata[NID]
    416 exclude_eids = find_exclude_eids(
    417     g, seed_edges, exclude, self.reverse_eids, self.reverse_etypes,
    418     self.output_device)
--> 420 input_nodes, _, blocks = self.sampler.sample(g, seed_nodes, exclude_eids)
    422 if self.negative_sampler is None:
    423     return self.assign_lazy_features((input_nodes, pair_graph, blocks))

File ~/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/dataloading/base.py:243, in BlockSampler.sample(self, g, seed_nodes, exclude_eids)
    241 def sample(self, g, seed_nodes, exclude_eids=None):     # pylint: disable=arguments-differ
    242     """Sample a list of blocks from the given seed nodes."""
--> 243     result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids)
    244     return self.assign_lazy_features(result)

/home/neo/notebooks/phoenix/2link_prediction/lp_nopl.ipynb Cell 7' in CustomNeighborSampler.sample_blocks(self, g, seed_nodes, exclude_eids)
     29 frontier = dgl.add_self_loop(frontier)
---> 30 block = dgl.transforms.to_block(frontier, seed_nodes)
     32 seed_nodes = block.srcdata[dgl.NID]

File ~/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/transforms/functional.py:2281, in to_block(g, dst_nodes, include_dst_in_src, src_nodes)
   2277 else:
   2278     # use an empty list to signal we need to generate it
   2279     src_node_ids_nd = []
-> 2281 new_graph_index, src_nodes_ids_nd, induced_edges_nd = _CAPI_DGLToBlock(
   2282     g._graph, dst_node_ids_nd, include_dst_in_src, src_node_ids_nd)
   2284 # The new graph duplicates the original node types to SRC and DST sets.
   2285 new_ntypes = (g.ntypes, g.ntypes)

File dgl/_ffi/_cython/./function.pxi:287, in dgl._ffi._cy3.core.FunctionBase.__call__()

File dgl/_ffi/_cython/./function.pxi:232, in dgl._ffi._cy3.core.FuncCall()

File dgl/_ffi/_cython/./base.pxi:155, in dgl._ffi._cy3.core.CALL()

DGLError: [17:02:24] /opt/dgl/src/graph/transform/to_bipartite.cc:119: Check failed: new_dst.Ptr<IdType>()[i] != -1 (-1 vs. -1) : Node 0 does not exist in `rhs_nodes`. Argument `rhs_nodes` must contain all the edge destination nodes.
Stack trace:
  [bt] (0) /home/neo/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/libdgl.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4f) [0x7f7f15ff0cdf]
  [bt] (1) /home/neo/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/libdgl.so(+0x782c11) [0x7f7f16465c11]
  [bt] (2) /home/neo/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/libdgl.so(std::tuple<std::shared_ptr<dgl::BaseHeteroGraph>, std::vector<dgl::runtime::NDArray, std::allocator<dgl::runtime::NDArray> > > dgl::transform::ToBlock<(DLDeviceType)1, long>(std::shared_ptr<dgl::BaseHeteroGraph>, std::vector<dgl::runtime::NDArray, std::allocator<dgl::runtime::NDArray> > const&, bool, std::vector<dgl::runtime::NDArray, std::allocator<dgl::runtime::NDArray> >*)+0x3a) [0x7f7f164669ea]
  [bt] (3) /home/neo/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/libdgl.so(+0x7842c7) [0x7f7f164672c7]
  [bt] (4) /home/neo/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/libdgl.so(+0x784bc4) [0x7f7f16467bc4]
  [bt] (5) /home/neo/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/libdgl.so(DGLFuncCall+0x48) [0x7f7f16316228]
  [bt] (6) /home/neo/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/_ffi/_cy3/core.cpython-310-x86_64-linux-gnu.so(+0x16b77) [0x7f7f15aa5b77]
  [bt] (7) /home/neo/miniconda3/envs/dgl/lib/python3.10/site-packages/dgl/_ffi/_cy3/core.cpython-310-x86_64-linux-gnu.so(+0x1704b) [0x7f7f15aa604b]
  [bt] (8) /home/neo/miniconda3/envs/dgl/bin/python(_PyObject_MakeTpCall+0x15e) [0x55a6c14d752e]

I am not sure why this is occurring with only the add_self_loop() change, as the nodes should be the same with or without self-loops, unless the renumbering of nodes from 0 for each block in the MFG is somehow failing.

So I notice that the modified graph's self-loop edges are also considered when finding the new_dst nodes (lines 78, 82, and 116),
which I guess is what causes the failure: with the self-loops added, every node of the graph has at least one incident edge, so every node now has to be in seed_nodes for the sampling function.
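
Here is a minimal sketch of the constraint on a toy graph (names are mine): after add_self_loop, every node becomes an edge destination, so to_block fails unless all nodes are passed as dst_nodes.

import dgl
import torch

g = dgl.graph((torch.tensor([1, 2]), torch.tensor([0, 0])))  # edges 1->0 and 2->0
dgl.to_block(g, dst_nodes=torch.tensor([0]))   # works: node 0 is the only edge destination
g = dgl.add_self_loop(g)                       # adds 0->0, 1->1, 2->2
dgl.to_block(g, dst_nodes=torch.tensor([0]))   # raises DGLError: nodes 1 and 2 are now destinations too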

Is this then considered a bug, or am I misunderstanding the above code?
Is there any other way to do the same operation?

Yeah, you are right. So instead of using dgl.add_self_loop you might need:

frontier = dgl.add_edges(frontier, seed_nodes, seed_nodes)
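
Putting it together, the custom sampler could then look roughly like this (just a sketch: the class name CustomNeighborSampler comes from the snippets above, and subclassing dgl.dataloading.NeighborSampler keeps the constructor arguments unchanged):

import dgl
from dgl.dataloading import NeighborSampler

class CustomNeighborSampler(NeighborSampler):
    """NeighborSampler variant that adds self-loops to every sampled frontier."""

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        output_nodes = seed_nodes
        blocks = []
        for fanout in reversed(self.fanouts):
            frontier = g.sample_neighbors(
                seed_nodes, fanout, edge_dir=self.edge_dir, prob=self.prob,
                replace=self.replace, output_device=self.output_device,
                exclude_edges=exclude_eids)
            # Add self-loops only on the seed (destination) nodes, so every
            # edge destination is still contained in seed_nodes for to_block.
            frontier = dgl.add_edges(frontier, seed_nodes, seed_nodes)
            block = dgl.to_block(frontier, seed_nodes)
            seed_nodes = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return seed_nodes, output_nodes, blocks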
