Retrieve node embeddings from minibatch training for HeteroGraph

I’ve built a 2-layer GNN for link prediction on a heterogeneous graph using minibatch training. Now I’m looking to retrieve the node embeddings for specific node types. I read the discussion here, but it only covers how to retrieve node embeddings from a homogeneous graph.

From what I understood, I should set up a dataloader that iterates over all the nodes I’m interested in, run the resulting blocks through the sage layer of my model, and save the embeddings. I’m not sure how this would work on a heterogeneous graph, though. Should I set up a dataloader for each node type whose embeddings I want? Should I give the dataloader node IDs, or edge IDs like I did during training?

Hi, for a heterogeneous graph, just change the original embedding to a dict of embeddings keyed by node type.
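For example, something along these lines (an untested sketch; the node type names and sampler settings are just placeholders):

import dgl

# seed nodes for every type of interest, given as a dict keyed by node type
indices = {
    "User": graph.nodes("User"),
    "Item": graph.nodes("Item"),
}
loader = dgl.dataloading.DataLoader(
    graph,
    indices,
    dgl.dataloading.MultiLayerFullNeighborSampler(2),
    batch_size=1024,
    shuffle=False,
)
# the GNN output for each batch is then also a dict {node_type: embeddings}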

Hi, my current function to extract node embeddings is the following:

def get_embedding_from_graph(graph, model):
    # batch_size, device, num_out and generate_features_dict are defined
    # elsewhere in the training setup
    graph_node_indices = {
        "User": graph.nodes("User"),
        "Action": graph.nodes("Action"),
        "Episode": graph.nodes("Episode"),
    }
    nodes_dataloader = dgl.dataloading.DataLoader(
        graph=graph,
        indices=graph_node_indices,
        graph_sampler=dgl.dataloading.MultiLayerFullNeighborSampler(2),
        batch_size=batch_size,
        device=device,
        use_uva=True,
        shuffle=False,
    )
    nodes_embedding = {
        "Action": torch.empty([0, num_out]),
        "Episode": torch.empty([0, num_out]),
        "User": torch.empty([0, num_out]),
    }
    with torch.no_grad():
        for input_nodes, sampled_graph, blocks in tqdm(nodes_dataloader):
            input_features = generate_features_dict(blocks[0])
            emb = model.sage(blocks, input_features)
            # append this batch's embeddings to the per-type accumulators
            nodes_embedding = {
                node_type: torch.cat(
                    (nodes_embedding[node_type], emb[node_type].to("cpu"))
                )
                for node_type in nodes_embedding
            }
    return nodes_embedding
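(For reference, generate_features_dict just gathers the per-type input features from the first block; a minimal hypothetical sketch, assuming the features live under a "feat" key:)

def generate_features_dict(block):
    # hypothetical helper: collect input features for every source node type of the block
    return {
        node_type: block.srcnodes[node_type].data["feat"]
        for node_type in block.srctypes
    }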

This seems to work with GraphConv and SAGEConv, but not with GATv2Conv. My GAT sage layer looks like this:

class GraphLayers(nn.Module):

    def __init__(
        self, 
        input_nodes_features: Dict[str, torch.Tensor],
        graph_edges: Iterable[Tuple[str, str, str]],
        hidden_features: int,
        out_features: int,
        layer_type: str = "graphconv",
        **kwargs
    ):
        super().__init__()
        
        self.conv1 = dglnn.HeteroGraphConv(
            {
                (source, edge, destination): dglnn.GATv2Conv(
                    in_feats=(
                        input_nodes_features[source].shape[1],
                        input_nodes_features[destination].shape[1],
                    ), 
                    out_feats=hidden_features,
                    activation=nn.ReLU(),
                    num_heads=kwargs.get("gat_num_heads", 4),
                    feat_drop=kwargs.get("gat_feat_drop", 0),
                    attn_drop=kwargs.get("gat_attn_drop", 0),
                    negative_slope=kwargs.get("gat_negative_slope", 0.2),
                    residual=kwargs.get("gat_residual", False),
                    share_weights=kwargs.get("gat_share_weights", False),
                )
                for (source, edge, destination) in graph_edges
            },
            aggregate=kwargs.get("heteroconv_aggregate", "sum")
        )

        self.conv2 = dglnn.HeteroGraphConv(
            {
                (source, edge, destination): dglnn.GATv2Conv(
                    in_feats=hidden_features, 
                    out_feats=out_features,
                    activation=None,
                    num_heads=kwargs.get("gat_num_heads", 4),
                    feat_drop=kwargs.get("gat_feat_drop", 0),
                    attn_drop=kwargs.get("gat_attn_drop", 0),
                    negative_slope=kwargs.get("gat_negative_slope", 0.2),
                    residual=kwargs.get("gat_residual", False),
                    share_weights=kwargs.get("gat_share_weights", False),
                )
                for (source, edge, destination) in graph_edges
            },
            aggregate=kwargs.get("heteroconv_aggregate", "sum")
        )

    def forward(self, blocks, inputs):
        # inputs are features of nodes, keyed by node type
        h = self.conv1(blocks[0], inputs)
        # average over the GATv2 attention heads; keys are destination node types
        h = {ntype: out_gatconv.mean(dim=1) for ntype, out_gatconv in h.items()}
        h = self.conv2(blocks[1], h)
        h = {ntype: out_gatconv.mean(dim=1) for ntype, out_gatconv in h.items()}
        return h

I found out that the problem occurs when I set residual to True in GATv2Conv in the second layer. I can train the network for link prediction, but as soon as I try to extract the node embeddings it crashes. Here’s the traceback:

RuntimeError                              Traceback (most recent call last)
Cell In[32], line 4
      2 for input_nodes, sampled_graph, blocks in tqdm(nodes_dataloader):
      3     input_features = generate_features_dict(blocks[0])
----> 4     emb = m2.sage(blocks, input_features)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[2], line 60, in GraphLayers.forward(self, blocks, inputs)
     58 h = self.conv1(blocks[0], inputs)
     59 h = {edge: out_gatconv.mean(dim=1) for edge, out_gatconv in h.items()}
---> 60 h = self.conv2(blocks[1], h)
     61 h = {edge: out_gatconv.mean(dim=1) for edge, out_gatconv in h.items()}
     62 return h

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/dgl/nn/pytorch/hetero.py:198, in HeteroGraphConv.forward(self, g, inputs, mod_args, mod_kwargs)
    196         if stype not in src_inputs or dtype not in dst_inputs:
    197             continue
--> 198         dstdata = self._get_module((stype, etype, dtype))(
    199             rel_graph,
    200             (src_inputs[stype], dst_inputs[dtype]),
    201             *mod_args.get(etype, ()),
    202             **mod_kwargs.get(etype, {})
    203         )
    204         outputs[dtype].append(dstdata)
    205 else:

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/dgl/nn/pytorch/conv/gatv2conv.py:323, in GATv2Conv.forward(self, graph, feat, get_attention)
    321 # residual
    322 if self.res_fc is not None:
--> 323     resval = self.res_fc(h_dst).view(
    324         h_dst.shape[0], -1, self._out_feats
    325     )
    326     rst = rst + resval
    327 # activation

RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1, 5] because the unspecified dimension size -1 can be any value and is ambiguous

Hi @bepisworld, sorry for the late response. It seems some ops change the shape of your tensor; a quick workaround is to reshape the tensor back.
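If that 0-element tensor comes from relation blocks that happen to have no destination nodes of a given type (which can occur when sampling by node IDs instead of edges), one possible untested workaround is to wrap GATv2Conv so empty destination batches short-circuit before the residual reshape. This sketch assumes the private attributes _num_heads and _out_feats visible in the traceback:

import torch
import dgl.nn as dglnn

class SafeGATv2Conv(dglnn.GATv2Conv):
    # returns an empty result when a relation block has no destination nodes,
    # instead of hitting the ambiguous view() in the residual branch
    def forward(self, graph, feat, get_attention=False):
        # under HeteroGraphConv, feat is a (src_feat, dst_feat) pair
        dst_feat = feat[1] if isinstance(feat, tuple) else feat
        if dst_feat.shape[0] == 0:
            return torch.empty(
                (0, self._num_heads, self._out_feats),
                device=dst_feat.device,
                dtype=dst_feat.dtype,
            )
        return super().forward(graph, feat, get_attention)

You would then use SafeGATv2Conv in place of dglnn.GATv2Conv inside HeteroGraphConv.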

But why does everything work when I use the model for training and evaluation on link prediction, yet it fails when I try to extract the embeddings by passing only nodes?

I don’t quite understand; could you explain more so that I can help you better?
