I basically ended up doing it like this, and it seems to be working now. Can you help me verify that it is correct?
class TestRGCN(nn.Module):
    """Two-layer relational GCN for heterogeneous graphs.

    One ``GraphConv`` is created per relation and wrapped in a
    ``HeteroGraphConv``. The first layer maps each source node type's input
    features to ``hid_feats``; the second maps ``hid_feats`` to ``out_feats``.

    Args:
        in_feats: Mapping from node type name to its input feature size.
        hid_feats: Hidden feature size shared by every relation.
        out_feats: Output feature size shared by every relation.
        canonical_etypes: Iterable of ``(src_type, edge_type, dst_type)``
            triples describing the relations of the graph.
    """

    def __init__(self, in_feats, hid_feats, out_feats, canonical_etypes):
        super().__init__()
        # The sub-modules are keyed by the bare edge-type name. If two
        # canonical edge types share a name, the later one silently replaces
        # the earlier one in these dict comprehensions.
        self.conv1 = dglnn.HeteroGraphConv({
            etype: dglnn.GraphConv(in_feats[utype], hid_feats, norm='right')
            for utype, etype, _ in canonical_etypes
        })
        self.conv2 = dglnn.HeteroGraphConv({
            etype: dglnn.GraphConv(hid_feats, out_feats, norm='right')
            for _, etype, _ in canonical_etypes
        })

    def forward(self, blocks, inputs):
        """Run both layers over the first two blocks.

        Args:
            blocks: Sequence whose first two entries are fed to layer 1 and
                layer 2 respectively.
            inputs: Mapping from node type name to the input feature tensor.

        Returns:
            The output of the second ``HeteroGraphConv`` layer.
        """
        hidden = self.conv1(blocks[0], inputs)
        # Without a non-linearity the two stacked layers stay (almost) linear,
        # because GraphConv applies no activation by default. ReLU adds no
        # parameters, so existing checkpoints still load.
        hidden = {ntype: torch.relu(feat) for ntype, feat in hidden.items()}
        return self.conv2(blocks[1], hidden)
class HeteroScorePredictor(nn.Module):
    """Score each edge by the dot product of its two endpoint features."""

    def forward(self, edge_subgraph, x):
        """Compute a ``'score'`` edge feature for every canonical edge type.

        Args:
            edge_subgraph: Graph whose edges are to be scored.
            x: Node features, stored as ``ndata['h']`` for the duration of
                the call.

        Returns:
            ``edge_subgraph.edata['score']`` -- presumably a mapping keyed by
            canonical edge type when the graph has several edge types, since
            the loss function in this file indexes it that way (TODO confirm).
        """
        dot_product = dgl.function.u_dot_v('h', 'h', 'score')
        # local_scope keeps the temporary 'h' / 'score' fields from leaking
        # into the caller's graph, so the result must be read inside it.
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = x
            for canonical_etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(dot_product, etype=canonical_etype)
            return edge_subgraph.edata['score']
class TestModel(nn.Module):
    """Link-prediction model: node encoder followed by an edge scorer.

    Node representations are computed first, then used to score the edges of
    both the positive and the negative graph.
    """

    def __init__(self, in_features, hidden_features, out_features, canonical_etypes):
        super().__init__()
        self.sage = TestRGCN(in_features, hidden_features, out_features, canonical_etypes)
        self.pred = HeteroScorePredictor()

    def forward(self, g, neg_g, blocks, x):
        """Return ``(pos_score, neg_score)`` for the two edge graphs.

        Args:
            g: Graph holding the positive edges.
            neg_g: Graph holding the negative edges.
            blocks: Blocks passed to the encoder.
            x: Input node features passed to the encoder.
        """
        node_repr = self.sage(blocks, x)
        # The same representations score both graphs; positive goes first.
        return self.pred(g, node_repr), self.pred(neg_g, node_repr)
def compute_loss(pos_score, neg_score, canonical_etypes):
    """Margin (hinge) ranking loss averaged over edge types.

    For every edge type with at least one positive edge this computes
    ``mean(max(0, 1 - pos + neg))``, so the loss is zero once each positive
    edge scores at least 1 higher than each of its negatives. The per-type
    means are then averaged.

    Args:
        pos_score: Mapping from edge type to positive-edge scores, shaped
            ``(E,)`` or ``(E, 1)``.
        neg_score: Mapping from edge type to negative-edge scores holding
            ``E * k`` values. Assumes the ``k`` negatives of positive edge
            ``i`` are stored contiguously -- TODO confirm against the
            negative sampler in use.
        canonical_etypes: Edge types to include in the loss.

    Returns:
        A scalar tensor. ``torch.stack`` raises if no edge type has any
        positive edge.
    """
    per_type_losses = []
    for etype in canonical_etypes:
        n_edges = pos_score[etype].shape[0]
        if n_edges == 0:
            # An empty type would contribute a NaN mean; leave it out.
            continue
        # Force an (E, 1) column so it lines up row by row with the (E, k)
        # negatives. unsqueeze(1) on an (E, 1) tensor would give (E, 1, 1)
        # and broadcast every positive against every edge's negatives.
        pos = pos_score[etype].reshape(n_edges, 1)
        neg = neg_score[etype].reshape(n_edges, -1)
        # Hinge on (neg - pos): penalise negatives that come within the
        # margin of, or exceed, their positive edge's score.
        per_type_losses.append((1 - pos + neg).clamp(min=0).mean())
    return torch.stack(per_type_losses, dim=0).mean()
# Input sizes per node type must match the stored node embeddings.
model = TestModel(
    in_features={'source': 700, 'user': 800},
    hidden_features=512,
    out_features=256,
    canonical_etypes=g.canonical_etypes,
)
...
# NOTE(review): the blocks are moved to CUDA below, so the model has to live
# on the same device -- confirm it is moved in the setup elided above.
device = torch.device('cuda')
for epoch in range(args.n_epochs):
    # Nothing in the batch loop leaves training mode, so set it once per epoch.
    model.train()
    for input_nodes, positive_graph, negative_graph, blocks in dataloader:
        blocks = [block.to(device) for block in blocks]
        positive_graph = positive_graph.to(device)
        negative_graph = negative_graph.to(device)
        # Input features come from the source side of the first block, one
        # tensor per node type.
        node_features = {
            'source': blocks[0].srcdata['source_embedding']['source'],
            'user': blocks[0].srcdata['user_embedding']['user'],
        }
        pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
        loss = compute_loss(pos_score, neg_score, g.canonical_etypes)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
My input graph basically has a structure like this (with more nodes of course):
# Toy graph description: each relation maps to a (source ids, destination ids)
# pair of tensors; both relations hold two edges from node 0 to node 0.
data_dict = {
    ('source', 'has_follower', 'user'): (
        torch.tensor([0, 0]),
        torch.tensor([0, 0]),
    ),
    ('user', 'follows', 'source'): (
        torch.tensor([0, 0]),
        torch.tensor([0, 0]),
    ),
}
Each node type has its own embedding, stored as node data:
# Attach an all-zero feature tensor to each node type. The second dimension
# (700 for 'source', 800 for 'user') is the per-type input feature size.
dgl_graph.nodes['source'].data['source_embedding'] = torch.zeros(1, 700)
dgl_graph.nodes['user'].data['user_embedding'] = torch.zeros(1, 800)