Hi,
First of all, thank you for such a good support community.
I wanted to check whether I am going in the right direction.
My dataset has two node types and two edge types (similar to the graph given below), and every edge carries a weight. I need each message passed along an edge to be multiplied by that edge's weight, and I need to predict the link (its weight) between a user and an item, i.e. the ('user', 'click', 'item') relation.
I have provided example code below; please advise whether I am doing it correctly.
# Build the user/item heterograph and attach edge weights and node features.
#
# NOTE(review): both 'click' and 'similar' use ``item_dst`` as their
# destination array. The click destinations and the similarity destinations
# are presumably different arrays; confirm the variable names where they are
# built.
graph = dgl.heterograph({
    ('user', 'click', 'item'): (user_src, item_dst),
    # Reverse of 'click'. HeteroGraphConv only produces outputs for node types
    # that are the destination of some relation. Without this edge type,
    # 'user' nodes receive no messages and the encoder returns no 'user'
    # embeddings, so scoring ('user', 'click', 'item') edges would fail.
    ('item', 'clicked-by', 'user'): (item_dst, user_src),
    ('item', 'similar', 'item'): (item_src, item_dst),
})
graph.edges['click'].data['click_weight'] = click_weights
# 'clicked-by' edges are created from the same arrays in the same order as
# 'click', so the same weight tensor lines up edge-for-edge.
graph.edges['clicked-by'].data['click_weight'] = click_weights
graph.edges['similar'].data['similarity_weight'] = similarity_weights
graph.nodes['user'].data['features'] = user_features
graph.nodes['item'].data['features'] = item_features
class HeteroDotProductPredictor(nn.Module):
    """Score every edge of one relation by the dot product of its endpoints.

    ``h`` maps node type -> embedding tensor. The result has one score per
    edge of ``etype``, in the graph's edge order.
    """

    def forward(self, graph, h, etype):
        # local_scope() discards the temporary 'h' / 'score' fields on exit,
        # so the caller's graph is not modified.
        with graph.local_scope():
            graph.ndata['h'] = h
            dot_fn = fn.u_dot_v('h', 'h', 'score')
            graph.apply_edges(dot_fn, etype=etype)
            edge_scores = graph.edges[etype].data['score']
        return edge_scores
def construct_negative_graph(graph, k, etype):
    """Build a graph holding ``k`` corrupted edges per positive edge of ``etype``.

    Each source node of a positive edge is repeated ``k`` times, and each copy
    is paired with a destination node drawn uniformly at random from the
    relation's destination node type. The returned graph has only ``etype``,
    but keeps the node counts of every node type in ``graph``.

    NOTE(review): random destinations may coincide with real edges; no
    filtering is done here.
    """
    _, _, dst_ntype = etype
    pos_src, _ = graph.edges(etype=etype)
    total_neg = len(pos_src) * k
    neg_src = pos_src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(dst_ntype), (total_neg,))
    node_counts = {nt: graph.num_nodes(nt) for nt in graph.ntypes}
    return dgl.heterograph({etype: (neg_src, neg_dst)}, num_nodes_dict=node_counts)
class RGCN(nn.Module):
    """Two-layer relational GCN: one GraphConv per relation, summed per node type.

    Args:
        in_feats: input feature size (shared by all node types).
        hid_feats: hidden size after the first layer.
        out_feats: output embedding size.
        rel_names: relation (edge type) names; one GraphConv is created per name.

    The output dict only contains node types that are the destination of at
    least one relation. A node type that never receives messages gets no
    embedding, so add reverse relations if you need one.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        # Each GraphConv runs on its own relation's subgraph, where some
        # destination nodes (e.g. items nobody clicked) may have no incoming
        # edge. GraphConv raises on such nodes unless this flag is set. Across
        # relations, 'sum' still combines whatever messages they do receive.
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats, allow_zero_in_degree=True)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats, allow_zero_in_degree=True)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs, edge_weights=None):
        """Encode nodes.

        Args:
            graph: the heterograph to propagate over.
            inputs: dict node type -> input feature tensor.
            edge_weights: optional dict relation name -> kwargs for that
                relation's GraphConv, e.g. ``{'click': {'edge_weight': w}}``.
                GraphConv multiplies each message by its edge's weight; the
                keyword must be ``edge_weight``.

        Returns:
            dict node type -> output embedding.
        """
        h = self.conv1(graph, inputs, mod_kwargs=edge_weights)
        h = {ntype: F.relu(feat) for ntype, feat in h.items()}
        return self.conv2(graph, h, mod_kwargs=edge_weights)
class Model(nn.Module):
    """Link-prediction model: RGCN encoder plus dot-product edge scorer.

    ``forward`` returns the scores of ``etype`` edges in the positive graph
    ``g`` and in the negative graph ``neg_g``. Both are computed from the same
    node embeddings, which are produced on ``g``.
    """

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        # Attribute names are kept as-is because they determine state_dict keys.
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype, edge_weights):
        node_emb = self.sage(g, x, edge_weights)
        positive = self.pred(g, node_emb, etype)
        negative = self.pred(neg_g, node_emb, etype)
        return positive, negative
def compute_loss(pos_score, neg_score, margin=1.0):
    """Hinge (margin ranking) loss between each positive edge and its negatives.

    Args:
        pos_score: scores of the E positive edges, shape (E,) or (E, 1).
        neg_score: scores of the E * k negative edges, with E * k elements in
            total. Negatives k*i .. k*i + k - 1 must belong to positive edge i,
            which is the order produced by repeating each source k times with
            ``repeat_interleave``.
        margin: required gap between a positive score and its negatives'
            scores (defaults to the previous hard-coded 1.0).

    Returns:
        Scalar mean of ``max(0, margin - pos_i + neg_ij)`` over all i, j.

    NOTE(review): this trains a ranking (does the edge exist?), not a
    regression onto edge weights. Predicting the weight itself would need a
    regression loss against the stored weights.
    """
    n_edges = pos_score.shape[0]
    # Reshape to (E, 1), not ``unsqueeze(1)``. The dot-product scorer yields
    # a trailing dim of 1, so unsqueeze made (E, 1, 1). That broadcast
    # against (E, k) to (E, E, k), comparing every positive edge with the
    # negatives of *all* edges and using O(E^2 * k) memory.
    pos = pos_score.reshape(n_edges, 1)
    neg = neg_score.reshape(n_edges, -1)
    return (margin - pos + neg).clamp(min=0).mean()
# Train the link predictor on ('user', 'click', 'item') with k negatives per edge.
k = 5
target_etype = ('user', 'click', 'item')

# HeteroGraphConv passes edge_weights[etype] as keyword arguments to that
# relation's GraphConv. GraphConv's forward only accepts the per-edge weight
# as ``edge_weight``, so every entry must use that keyword regardless of the
# edge-data field name it was stored under. Relations missing from the graph
# or without stored weights run unweighted.
weight_fields = {
    'click': 'click_weight',
    'clicked-by': 'click_weight',  # reverse relation, if the graph defines it
    'similar': 'similarity_weight',
}
edge_weights = {
    etype: {'edge_weight': graph.edges[etype].data[field]}
    for etype, field in weight_fields.items()
    if etype in graph.etypes and field in graph.edges[etype].data
}

# Features are stored under 'features' on the graph built above.
user_feats = graph.nodes['user'].data['features']
item_feats = graph.nodes['item'].data['features']
node_features = {'user': user_feats, 'item': item_feats}

# Model takes exactly (in, hidden, out, rel_names). The input size is taken
# from the data; this assumes user and item features share one size, which
# the shared-in_feats GraphConvs require anyway.
model = Model(user_feats.shape[1], 20, 5, graph.etypes)
opt = torch.optim.Adam(model.parameters())

for epoch in range(50):
    negative_graph = construct_negative_graph(graph, k, target_etype)
    pos_score, neg_score = model(
        graph, negative_graph, node_features, target_etype, edge_weights)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()