I am trying to implement the forward-pass logic described in this paper.
This is what I have come up with so far:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class ProductGNN(nn.Module):
    """GNN that learns separate source and target embeddings for 'item' nodes.

    Every item keeps two hidden states, a source state ``h_s`` and a target
    state ``h_t``, both initialised from the input features. Each of the
    ``L`` layers:

    * sums the neighbours' ``h_t`` over ``'bought_together'`` edges and
      combines it with the node's own ``h_s`` to produce the new ``h_s``;
    * sums the neighbours' ``h_s`` over ``'viewed_together'`` edges and
      combines it with the node's own ``h_t`` to produce the new ``h_t``.

    After the last layer both states are L2-normalised and projected to
    ``out_dim``.

    Args:
        in_dim: Size of the input node features.
        hidden_dim: Size of the hidden source/target states.
        out_dim: Size of the returned embeddings.
        L: Number of message-passing layers; must be at least 1.

    Raises:
        ValueError: If ``L`` is smaller than 1.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, L):
        super().__init__()
        if L < 1:
            # With no layers the hidden states would never be produced and
            # the output projections (hidden_dim -> out_dim) could not run.
            raise ValueError(f"L must be at least 1, got {L}")
        self.L = L
        # First layer: maps raw features (in_dim) to the hidden size.
        self.linear_s = nn.Linear(in_dim, hidden_dim)
        self.linear_t = nn.Linear(in_dim, hidden_dim)
        self.out_linear_s = nn.Linear(hidden_dim, out_dim)
        self.out_linear_t = nn.Linear(hidden_dim, out_dim)
        # Layers 2..L consume hidden_dim inputs, so they need their own
        # weights. They are created last so that the four layers above keep
        # the same random initialisation as before for a given seed.
        self.hidden_linear_s = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(L - 1)]
        )
        self.hidden_linear_t = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(L - 1)]
        )

    def forward(self, g, features):
        """Run ``L`` rounds of message passing and return both embeddings.

        Args:
            g: DGL graph with node type ``'item'`` and edge types
                ``'bought_together'`` and ``'viewed_together'``.
            features: Input features of the ``'item'`` nodes; the last
                dimension must be ``in_dim``.

        Returns:
            Tuple ``(theta_s, theta_t)`` of source and target embeddings,
            each with last dimension ``out_dim``.
        """
        h_s = features
        h_t = features
        layers_s = [self.linear_s, *self.hidden_linear_s]
        layers_t = [self.linear_t, *self.hidden_linear_t]

        # local_scope() discards the temporary node/edge data on exit, so
        # the caller's graph is left unmodified.
        with g.local_scope():
            for lin_s, lin_t in zip(layers_s, layers_t):
                g.nodes['item'].data['h_s'] = h_s
                g.nodes['item'].data['h_t'] = h_t

                # NOTE(review): update_all sends messages along the edge
                # direction (src -> dst) and reduces at the destination, so
                # each node sums over its IN-neighbours. If the paper
                # aggregates over OUT-neighbours and edges are stored as
                # item -> neighbour, run this on the reversed edges instead
                # -- confirm against how the graph is built.

                # Sum of target states of co-purchased neighbours.
                g.update_all(
                    fn.copy_u('h_t', 'm'),
                    fn.sum('m', 'h_t_agg'),
                    etype='bought_together',
                )
                # Sum of source states of co-viewed neighbours.
                g.update_all(
                    fn.copy_u('h_s', 'm'),
                    fn.sum('m', 'h_s_agg'),
                    etype='viewed_together',
                )
                h_t_agg = g.nodes['item'].data['h_t_agg']
                h_s_agg = g.nodes['item'].data['h_s_agg']

                # Update both states simultaneously from the previous
                # layer's values and feed them into the next layer.
                h_s, h_t = (
                    F.relu(lin_s(h_t_agg + h_s)),
                    F.relu(lin_t(h_s_agg + h_t)),
                )

        # Normalise the final hidden states to unit L2 norm per node.
        h_s = F.normalize(h_s, p=2, dim=1)
        h_t = F.normalize(h_t, p=2, dim=1)

        # Project to the final source and target product representations.
        theta_s = self.out_linear_s(h_s)
        theta_t = self.out_linear_t(h_t)
        return theta_s, theta_t
My question is whether the message passing for the source and target node representations is implemented correctly.
Any help would be greatly appreciated. Thanks!