Providing reverse edge type with ShadowKHop sampler

import dgl
import torch

def get_fanouts(canonical_etypes, num_hops=1, multiplier=3):
    """Build one fanout dictionary per hop, largest fanout first.

    Args:
        canonical_etypes: Iterable of edge-type keys; each key appears in
            every returned dictionary.
        num_hops: Number of dictionaries to produce.
        multiplier: Fanout step between consecutive hops.

    Returns:
        A list of ``num_hops`` dicts. Every key in the first dict maps to
        ``multiplier * num_hops``; the value shrinks by ``multiplier`` for
        each following dict, so the last dict maps every key to
        ``multiplier``. The list is empty when ``num_hops`` is not positive.
    """
    per_hop = []
    # Count down from num_hops to 1 so the widest fanout comes first.
    for hop in range(num_hops, 0, -1):
        per_hop.append({etype: multiplier * hop for etype in canonical_etypes})
    return per_hop

# Toy heterograph with node types Person, Country and Song. Each relation is
# stored twice under two different edge-type names: the second entry of each
# pair uses the first entry's source/destination tensors swapped.
g = dgl.heterograph({
   ("Person", "lived in", "Country"): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 0])),
   ("Country", "had", "Person"): (torch.tensor([0, 1, 0]), torch.tensor([0, 0, 1])),
   ("Person", "knows", "Song"): (torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1])),
   ("Song", "heard by", "Person"): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1]))})

# Edge IDs, keyed by canonical edge type, handed to the DataLoader below.
# Every relation above has three edges, but only "lived in" lists all three
# IDs here; the other relations list IDs 0 and 1 only.
eids = {
   ("Person", "lived in", "Country"): torch.tensor([0, 1, 2]),
   ("Country", "had", "Person"): torch.tensor([0, 1]),
   ("Person", "knows", "Song"): torch.tensor([0, 1]),
   ("Song", "heard by", "Person"): torch.tensor([0, 1])
}
    
# Maps each edge-type name to the name of the relation that reverses it.
# Keys and values are plain edge-type names, not canonical
# (src_type, edge_type, dst_type) triples.
reverse_etypes = {"lived in": "had",
                  "had": "lived in",
                  "knows": "heard by",
                  "heard by": "knows"}

# A single hop with a fanout of 3 for every canonical edge type of g.
fanouts = get_fanouts(g.canonical_etypes, 1, 3)
node_sampler = dgl.dataloading.ShaDowKHopSampler(fanouts)
# Wrap the node sampler for edge prediction. exclude="reverse_types" plus the
# reverse_etypes mapping is meant to exclude the reverse-typed counterparts of
# the seed edges during sampling (NOTE(review): semantics per the DGL docs --
# confirm against the installed DGL version).
edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
    node_sampler, exclude="reverse_types", reverse_etypes=reverse_etypes)

# One seed edge per mini-batch, sampled in the main process (num_workers=0).
dataloader = dgl.dataloading.DataLoader(
    g, eids, edge_sampler, batch_size=1, num_workers=0)

# Fetch only the first mini-batch; fetching it is what raises the KeyError
# shown in the traceback below.
for i, (input_nodes, g_, blocks) in enumerate(dataloader):
    break
# Bare expression: displays g_ when run as the last line of a notebook cell.
g_

I want to specify the reverse edge types for the ShaDowKHop sampler. In this example, every relation is stored in both directions under different edge-type names ("lived in" <-> "had" and "knows" <-> "heard by"). Running this code gives the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [1], in <cell line: 46>()
     43             break
     44     return cetype
---> 46 for i, (input_nodes, g_, blocks) in enumerate(dataloader):
     47     break
     48 g_

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/dataloading/dataloader.py:497, in _PrefetchingIter.__next__(self)
    495 def __next__(self):
    496     batch, feats, stream_event = \
--> 497         self._next_non_threaded() if not self.use_thread else self._next_threaded()
    498     batch = recursive_apply_pair(batch, feats, _assign_for)
    499     if stream_event is not None:

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/dataloading/dataloader.py:470, in _PrefetchingIter._next_non_threaded(self)
    469 def _next_non_threaded(self):
--> 470     batch = next(self.dataloader_it)
    471     batch = recursive_apply(batch, restore_parent_storage_columns, self.dataloader.graph)
    472     device = self.dataloader.device

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/torch/utils/data/dataloader.py:530, in _BaseDataLoaderIter.__next__(self)
    528 if self._sampler_iter is None:
    529     self._reset()
--> 530 data = self._next_data()
    531 self._num_yielded += 1
    532 if self._dataset_kind == _DatasetKind.Iterable and \
    533         self._IterableDataset_len_called is not None and \
    534         self._num_yielded > self._IterableDataset_len_called:

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/torch/utils/data/dataloader.py:570, in _SingleProcessDataLoaderIter._next_data(self)
    568 def _next_data(self):
    569     index = self._next_index()  # may raise StopIteration
--> 570     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    571     if self._pin_memory:
    572         data = _utils.pin_memory.pin_memory(data)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:40, in _IterableDatasetFetcher.fetch(self, possibly_batched_index)
     38 else:
     39     data = next(self.dataset_iter)
---> 40 return self.collate_fn(data)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/dataloading/dataloader.py:519, in CollateWrapper.__call__(self, items)
    516 if self.use_uva:
    517     # Only copy the indices to the given device if in UVA mode.
    518     items = recursive_apply(items, lambda x: x.to(self.device))
--> 519 batch = self.sample_func(self.g, items)
    520 return recursive_apply(batch, remove_parent_storage_columns, self.g)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/dataloading/base.py:420, in EdgePredictionSampler.sample(self, g, seed_edges)
    414 seed_nodes = pair_graph.ndata[NID]
    416 exclude_eids = find_exclude_eids(
    417     g, seed_edges, exclude, self.reverse_eids, self.reverse_etypes,
    418     self.output_device)
--> 420 input_nodes, _, blocks = self.sampler.sample(g, seed_nodes, exclude_eids)
    422 if self.negative_sampler is None:
    423     return self.assign_lazy_features((input_nodes, pair_graph, blocks))

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/dataloading/shadow.py:109, in ShaDowKHopSampler.sample(self, g, seed_nodes, exclude_eids)
    107 subg = g.subgraph(seed_nodes, relabel_nodes=True, output_device=self.output_device)
    108 if exclude_eids is not None:
--> 109     subg = EidExcluder(exclude_eids)(subg)
    111 set_node_lazy_features(subg, self.prefetch_node_feats)
    112 set_edge_lazy_features(subg, self.prefetch_edge_feats)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/sampling/utils.py:58, in EidExcluder.__call__(self, frontier)
     56 def __call__(self, frontier):
     57     parent_eids = frontier.edata[EID]
---> 58     located_eids = self._find_indices(parent_eids)
     60     if not isinstance(located_eids, Mapping):
     61         # (BarclayII) If frontier already has a EID field and located_eids is empty,
     62         # the returned graph will keep EID intact.  Otherwise, EID will change
     63         # to the mapping from the new graph to the old frontier.
     64         # So we need to test if located_eids is empty, and do the remapping ourselves.
     65         if len(located_eids) > 0:

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/sampling/utils.py:50, in EidExcluder._find_indices(self, parent_eids)
     48 if self._exclude_eids is not None:
     49     parent_eids_np = recursive_apply(parent_eids, F.zerocopy_to_numpy)
---> 50     return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids)
     51 else:
     52     assert self._filter is not None

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/sampling/utils.py:16, in _locate_eids_to_exclude(frontier_parent_eids, exclude_eids)
     11 """Find the edges whose IDs in parent graph appeared in exclude_eids.
     12 
     13 Note that both arguments are numpy arrays or numpy dicts.
     14 """
     15 func = lambda x, y: np.isin(x, y).nonzero()[0]
---> 16 result = recursive_apply_pair(frontier_parent_eids, exclude_eids, func)
     17 return recursive_apply(result, F.zerocopy_from_numpy)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/utils/internal.py:1005, in recursive_apply_pair(data1, data2, fn, *args, **kwargs)
   1001 """Recursively apply a function to every pair of elements in two containers with the
   1002 same nested structure.
   1003 """
   1004 if isinstance(data1, Mapping) and isinstance(data2, Mapping):
-> 1005     return {
   1006         k: recursive_apply_pair(data1[k], data2[k], fn, *args, **kwargs)
   1007         for k in data1.keys()}
   1008 elif is_listlike(data1) and is_listlike(data2):
   1009     return [recursive_apply_pair(x, y, fn, *args, **kwargs) for x, y in zip(data1, data2)]

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/utils/internal.py:1006, in <dictcomp>(.0)
   1001 """Recursively apply a function to every pair of elements in two containers with the
   1002 same nested structure.
   1003 """
   1004 if isinstance(data1, Mapping) and isinstance(data2, Mapping):
   1005     return {
-> 1006         k: recursive_apply_pair(data1[k], data2[k], fn, *args, **kwargs)
   1007         for k in data1.keys()}
   1008 elif is_listlike(data1) and is_listlike(data2):
   1009     return [recursive_apply_pair(x, y, fn, *args, **kwargs) for x, y in zip(data1, data2)]

KeyError: ('Person', 'knows', 'Song')

As a side note, providing exclude="self" also causes the same error, but it seems that the ShaDowKHop sampler already excludes the self edge.

Could you provide canonical edge types in reverse_etypes and see if it works? E.g.

reverse_etypes = {
    ("Person", "lived in", "Country"): ("Country", "had", "Person"),
    ...
}

I still get the same error.

This is fixed in the pull request "[Bug] Fix problem with ShaDowKHopSampler working with reverse edge type exclusion" by BarclayII (Pull Request #4145 in dmlc/dgl on GitHub).