I am trying to implement the loss function from the GraphSAGE paper (equation 1).
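For reference, equation (1) of the paper is

J_{\mathcal{G}}(\mathbf{z}_u) = -\log\big(\sigma(\mathbf{z}_u^\top \mathbf{z}_v)\big) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)}\big[\log\big(\sigma(-\mathbf{z}_u^\top \mathbf{z}_{v_n})\big)\big]

where v is a node that co-occurs with u on a fixed-length random walk, \sigma is the sigmoid function, P_n is the negative-sampling distribution, and Q is the number of negative samples.

Here is the code I am using: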
# imports used across all snippets below
import torch
import torch.nn as nn
import dgl
import dgl.function as fn

def loss_fn(graph, embeddings, length, node_ids, size_neg_graph):
    # positive scores: sample a fixed-length random walk from each seed node and
    # take a node that occurs on the walk as the positive example
    pos_ids = dgl.sampling.random_walk(graph, node_ids, length=length)[0][:, length - 1]
    pos_score = torch.sum(embeddings * embeddings[pos_ids], dim=1)  # inner product of each node with its positive sample
    pos_score = torch.nn.functional.logsigmoid(pos_score)
    # negative scores: corrupt destination nodes by sampling uniformly from all nodes in the graph
    neg_graph = construct_negative_graph(graph, size_neg_graph)
    neg_score = dp(neg_graph, embeddings)  # dp is a DotProductPredictor instance (defined below); dot product along every negative edge
    neg_score = torch.nn.functional.logsigmoid(-neg_score)
    # combine both terms and average over the seed nodes
    loss = -torch.sum(pos_score) - torch.sum(neg_score)
    return loss / len(node_ids)
When I run this, the loss never decreases; it stays stuck at a relatively high value. Is there something wrong with my code?
The dot-product predictor and the negative-graph sampling function are defined as follows:
class DotProductPredictor(nn.Module):
    def forward(self, graph, h):
        # h contains the node representations computed from the GNN
        with graph.local_scope():
            graph.ndata['h'] = h
            # dot product of source and destination embeddings along every edge
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']

def construct_negative_graph(graph, k):
    src, dst = graph.edges()
    # repeat every source node k times and pair it with k uniformly sampled destination nodes
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.number_of_nodes(), (len(src) * k,))
    return dgl.graph((neg_src, neg_dst), num_nodes=graph.number_of_nodes())
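For context, this is roughly how I train with the loss. The encoder below is only a simplified stand-in for my actual model (a plain two-layer SAGEConv network), and the graph, features, and hyperparameters shown here are placeholders for my real data and settings:

from dgl.nn import SAGEConv

class SAGE(nn.Module):
    # minimal two-layer GraphSAGE encoder, only a stand-in for my real model
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, hid_feats, 'mean')
        self.conv2 = SAGEConv(hid_feats, out_feats, 'mean')

    def forward(self, graph, x):
        h = torch.relu(self.conv1(graph, x))
        return self.conv2(graph, h)

graph = dgl.rand_graph(1000, 5000)   # stand-in for my actual graph
features = torch.randn(1000, 32)     # stand-in for my node features

dp = DotProductPredictor()
model = SAGE(in_feats=features.shape[1], hid_feats=64, out_feats=16)
opt = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    embeddings = model(graph, features)                # (num_nodes, 16) node embeddings
    node_ids = torch.arange(graph.number_of_nodes())   # every node is a walk seed
    loss = loss_fn(graph, embeddings, length=2, node_ids=node_ids, size_neg_graph=5)
    opt.zero_grad()
    loss.backward()
    opt.step()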