Dual Embeddings for Heterogeneous Graph Product Embeddings

[Screenshot of the algorithm from the paper]
I'm trying to implement the forward-pass logic explained in this paper (the algorithm in the screenshot above).

This is what I have come up with so far:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn


class ProductGNN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, L):
        super(ProductGNN, self).__init__()
        self.L = L
        self.linear_s = nn.Linear(in_dim, hidden_dim)
        self.linear_t = nn.Linear(in_dim, hidden_dim)
        self.out_linear_s = nn.Linear(hidden_dim, out_dim)
        self.out_linear_t = nn.Linear(hidden_dim, out_dim)

    def forward(self, g, features):
        h_s = features
        h_t = features

        for l in range(self.L):
            with g.local_scope():
                g.nodes['item'].data['h_s'] = h_s
                g.nodes['item'].data['h_t'] = h_t

                # Aggregate target representation of co-purchased out-neighbors
                g.update_all(fn.copy_u('h_t', 'm'), fn.sum('m', 'h_t_agg'), etype='bought_together')
                h_t_agg = g.nodes['item'].data['h_t_agg']

                # Aggregate source representation of co-viewed out-neighbors
                g.update_all(fn.copy_u('h_s', 'm'), fn.sum('m', 'h_s_agg'), etype='viewed_together')
                h_s_agg = g.nodes['item'].data['h_s_agg']

                # Update source and target representations
                h_sn = F.relu(self.linear_s(h_t_agg + h_s))
                h_tn = F.relu(self.linear_t(h_s_agg + h_t))

                # Normalize the embeddings to a unit norm
                h_sn = F.normalize(h_sn)
                h_tn = F.normalize(h_tn)
                # Feed the updated embeddings into the next layer
                # (assumes hidden_dim == in_dim when L > 1)
                h_s, h_t = h_sn, h_tn

        # Generate final source and target product representation
        theta_s = self.out_linear_s(h_sn)
        theta_t = self.out_linear_t(h_tn)

        return theta_s, theta_t


My doubt is whether I've got the message passing for the source and target nodes right.
Any help here would be great. Thanks

It seems that your code doesn't handle the direction of the edges carefully. In the algorithm, the source features and the target features are aggregated along different directions (line 4 and line 5).
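
For a quick illustration (a toy graph of my own, not from the paper): update_all always sends messages from the source to the destination of each edge, so both of your aggregations currently flow the same way.

import dgl
import torch
import dgl.function as fn

# toy co-purchase chain 0 -> 1 -> 2
g = dgl.heterograph({('item', 'bought_together', 'item'): ([0, 1], [1, 2])})
g.nodes['item'].data['h'] = torch.eye(3)

# messages travel source -> destination, so node 0 receives nothing
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'agg'), etype='bought_together')
print(g.nodes['item'].data['agg'])
# row 0 is all zeros; row 1 holds node 0's feature; row 2 holds node 1's feature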

Any pointers as to what should be corrected in my code?

Assuming I only deal with a homogeneous co-purchase graph:

import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from torch.nn import Dropout
from dgl.nn import SAGEConv, EdgeWeightNorm, GraphConv

class GraphSAGE(nn.Module):
    def __init__(self, g, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.dropout = Dropout(0.5)
        self.conv1_s = SAGEConv(in_feats, h_feats, 'gcn')  # Separate conv layers for source
        self.conv2_s = SAGEConv(h_feats, h_feats, 'gcn')
        self.conv1_t = SAGEConv(in_feats, h_feats, 'gcn')  # and destination nodes
        self.conv2_t = SAGEConv(h_feats, h_feats, 'gcn')

    def forward(self, g, node_features, edge_features):
        g.ndata["h_src"] = node_features
        g.ndata["h_dst"] = node_features
        g.edata["w"] = edge_features
        
        g.update_all(fn.v_mul_e('h_src', 'w', 'm1'), fn.sum('m1', 'h_s_agg'))
        g.update_all(fn.u_mul_e('h_dst', 'w', 'm2'), fn.sum('m2', 'h_t_agg'))
        
        h_s_agg = g.ndata['h_s_agg']
        h_t_agg = g.ndata['h_t_agg']
        
        h_s = self.conv1_s(g, h_s_agg)  # Use separate conv layers for source
        h_s = F.relu(h_s)
        h_s = self.conv2_s(g, h_s)
        
        h_t = self.conv1_t(g, h_t_agg)  # and destination nodes
        h_t = F.relu(h_t)
        h_t = self.conv2_t(g, h_t)
        
        return h_s, h_t

Does this look right?

  1. W^l in the algorithm should be a learnable parameter shared by the source-embedding and target-embedding computations. You can implement it as an nn.Linear in PyTorch instead of passing it in as input edge features. The linear layer is applied before message passing.

  2. g.update_all(fn.v_mul_e('h_src', 'w', 'm1'), fn.sum('m1', 'h_s_agg'))
    g.update_all(fn.u_mul_e('h_dst', 'w', 'm2'), fn.sum('m2', 'h_t_agg'))
    

    In these two lines, messages are aggregated along the same direction, whereas the algorithm aggregates the source embedding and the target embedding along opposite directions. A convenient way to handle this is to maintain a rev_g in which all edges of g are reversed, then use the copy_u message function on both graphs and perform message passing on each of them.

  3. After the message passing, do the normalization as in your first version. The SAGEConv layers seem unnecessary. A rough sketch that puts these three points together is below.
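
Roughly, a minimal (untested) sketch of one layer combining the three points: a shared linear layer applied before message passing, opposite aggregation directions via a pre-built rev_g, and unit-norm normalization. It assumes a homogeneous co-purchase graph and ignores edge weights for brevity (swap copy_u for u_mul_e if you keep them); the self-term/neighbour-term combination follows your first version, so double-check it against the paper. DualEmbeddingLayer is just a hypothetical helper name.

import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class DualEmbeddingLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)  # shared W^l for both embeddings

    def forward(self, g, rev_g, h_s, h_t):
        # apply the shared linear layer before message passing
        w_s, w_t = self.linear(h_s), self.linear(h_t)

        # line 5: aggregate source embeddings along the original edge direction
        with g.local_scope():
            g.ndata['x'] = w_s
            g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'agg'))
            h_t_agg = g.ndata['agg']

        # line 4: aggregate target embeddings along the reverse direction
        with rev_g.local_scope():
            rev_g.ndata['x'] = w_t
            rev_g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'agg'))
            h_s_agg = rev_g.ndata['agg']

        # combine self term and neighbour term, then normalize to unit norm
        h_s_new = F.normalize(F.relu(w_s + h_s_agg))
        h_t_new = F.normalize(F.relu(w_t + h_t_agg))
        return h_s_new, h_t_new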

Got it, the SAGEConv can definitely be gotten rid of… Do you think I would need to reverse the graph even if I have edge weights from, let's say, a co-occurrence matrix?
Or does that mean I shouldn't use edge weights at all?

Yes, the reverse graph is required for message passing in the other direction. You can keep the edge features if they carry extra information (e.g., the co-occurrence weights); whether they help can be decided from experimental results. Remember to apply the linear layer to the source and target embeddings even if you have input edge features.
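
For example, something along these lines (a sketch only, assuming 'w' already holds the per-edge weights and self.linear is the shared layer from above): the linear layer is applied first, and the edge weight is folded into the message.

# shared linear first, then weighted sum over incoming edges
g.ndata['x'] = self.linear(h_s)  # W^l h_s
g.update_all(fn.u_mul_e('x', 'w', 'm'), fn.sum('m', 'h_t_agg'))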

So when you say reverse graph, do you mean something like this?

g.update_all(fn.v_mul_e('h_src', 'w', 'm1'), fn.sum('m1', 'h_s_agg'))
g.reverse().update_all(fn.v_mul_e('h_dst', 'w', 'm2'), fn.sum('m2', 'h_t_agg'))

More like this:

# line 5 in the algorithm
with g.local_scope():
    g.update_all(fn.u_mul_e('h_src', 'w', 'm'), fn.sum('m', 'h_t_agg'))
    h_t = g.ndata['h_t_agg']
# line 4 in the algorithm
# keep the reversed g outside to avoid repeated construction of the reversed
# graph each time the method is invoked
with rev_g.local_scope():
    # you may need different edge weights w_ because of different normalization
    # in the reverse direction
    rev_g.update_all(fn.u_mul_e('h_dst', 'w_', 'm'), fn.sum('m', 'h_s_agg'))
    h_s = rev_g.ndata['h_s_agg']
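
In case it helps, here is one way (my own untested sketch) to build the reversed graph once outside forward() and get direction-specific normalized weights with EdgeWeightNorm. raw_weights is a hypothetical 1-D tensor of positive co-occurrence counts, one per edge of g, and I'm assuming edge IDs stay aligned between g and rev_g.

import dgl
from dgl.nn import EdgeWeightNorm

rev_g = dgl.reverse(g, copy_ndata=True, copy_edata=False)

# normalize by the destination's incoming weight sum, separately per direction
norm = EdgeWeightNorm(norm='right')
g.edata['w'] = norm(g, raw_weights)           # weights for the original direction
rev_g.edata['w_'] = norm(rev_g, raw_weights)  # re-normalized for the reverse direction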

Thanks a lot. Let me try this out.
