@BarclayII
In my use case, I have node types A and B in a bipartite graph. I would like message passing to occur between A <-> B, but I would also like my loss function to compare the embedding of each type-A node with the embeddings of its k most important neighbors of the same type, and vice versa for type B (using PinSAGE sampling, for example).
The big difference in this approach is that I don't want to build projection graphs. If, for example, I use PinSAGE sampling in my data loader, will the message passing happen on the original graph or on the projected graph?
My use case:
1 - Build a full heterogeneous graph with node types A and B
2 - Use the whole graph for message passing, so that nodes of type A receive messages only from nodes of type B, and nodes of type B receive messages only from nodes of type A
3 - Compute a margin loss that compares, for each node type, the dot product between a node's representation and those of its most similar same-type nodes (sampled with PinSAGE) against the dot product with negative examples.
My first attempt at solving it, which did not work:
class NeighborhoodLoss():
    """Score same-type node pairs of a heterogeneous graph by dot product.

    For a node type, positive pairs come from a DGL PinSAGE sampler and
    negative pairs are drawn uniformly at random. Each pair is scored with
    the dot product of the two nodes' ``'h'`` features stored on the graph.

    The graph must already carry an ``'h'`` feature for every node type that
    is scored; this class only reads it.
    """

    def __init__(self, graph):
        """Keep a reference to the heterogeneous graph the pairs come from."""
        self.graph = graph

    def sample_positive_neighbors(self, target_node_type, aux_node_type):
        """Sample same-type positive pairs with a PinSAGE sampler.

        Args:
            target_node_type: Node type whose nodes are used as seeds.
            aux_node_type: The other node type the random walks go through.

        Returns:
            ``(nodes_ids, neighbors_ids)``: two aligned ID tensors, one entry
            per edge of the sampled frontier.
        """
        sampler = dgl.sampling.PinSAGESampler(
            self.graph,
            target_node_type,
            aux_node_type,
            3,    # random walk length
            0.5,  # random walk restart probability
            100,  # number of random walks per seed
            5,    # number of neighbors kept per seed
        )
        seeds = torch.LongTensor(self.graph.nodes(target_node_type))
        frontier = sampler(seeds)
        # Fetch the edge list once. In the frontier produced by DGL's PinSAGE
        # sampler the edges run from the sampled neighbor (source) to the
        # seed node (destination), so the seed is the second endpoint.
        neighbors_ids, nodes_ids = frontier.all_edges(form='uv')
        return nodes_ids, neighbors_ids

    def sample_negative_neighbors(self, target_node_type, neg_samples=3):
        """Draw ``neg_samples`` uniformly random partners for every node.

        Args:
            target_node_type: Node type to sample from.
            neg_samples: Number of random partners per node.

        Returns:
            ``(nodes_ids, neg_neighbors_ids)``: two aligned ID tensors of
            length ``n_nodes * neg_samples``. ``nodes_ids`` is
            ``0..n_nodes-1`` tiled ``neg_samples`` times. A random partner
            may coincide with the node itself or with a true neighbor.
        """
        # Use the graph held by this instance, not a module-level global.
        n_nodes = self.graph.number_of_nodes(target_node_type)
        nodes_ids = torch.arange(n_nodes).repeat(neg_samples)
        neg_neighbors_ids = torch.randint(0, n_nodes, (n_nodes * neg_samples,))
        return nodes_ids, neg_neighbors_ids

    def calculate_representation_dot_product(self, node_ids, neighbors_ids, target_node_type):
        """Return the dot product of ``'h'`` for each (node, neighbor) pair.

        Args:
            node_ids: IDs of the first node of every pair.
            neighbors_ids: IDs of the second node of every pair, aligned
                with ``node_ids``.
            target_node_type: Node type both ID sequences refer to.

        Returns:
            A tensor with one score per pair (the feature dimension is
            summed out).
        """
        features = self.graph.nodes[target_node_type].data['h']
        node_index = torch.as_tensor(node_ids, dtype=torch.long)
        neighbor_index = torch.as_tensor(neighbors_ids, dtype=torch.long)
        # Index the feature tensor directly rather than rebuilding it from a
        # Python list: this works for vector-valued features, yields one
        # score per pair, and keeps the result attached to autograd.
        return (features[node_index] * features[neighbor_index]).sum(dim=-1)

    def get_dot_products(self, node_types):
        """Score positive and negative pairs for both node types.

        Args:
            node_types: Sequence of the two node type names. Each one is
                used in turn as the target type, with the other one as the
                auxiliary type.

        Returns:
            A list with one ``(pos_dot_product, neg_dot_product)`` tuple per
            node type, in the order given by ``node_types``. The two tensors
            of a tuple generally differ in length, since positives and
            negatives are sampled independently.
        """
        type_pairs = [
            (node_types[0], node_types[1]),
            (node_types[1], node_types[0]),
        ]
        loss = []
        for target_type, aux_type in type_pairs:
            pos_nodes_ids, pos_neigh_ids = self.sample_positive_neighbors(target_type, aux_type)
            neg_nodes_ids, neg_neigh_ids = self.sample_negative_neighbors(target_type)
            pos_dot_product = self.calculate_representation_dot_product(
                pos_nodes_ids, pos_neigh_ids, target_type)
            neg_dot_product = self.calculate_representation_dot_product(
                neg_nodes_ids, neg_neigh_ids, target_type)
            loss.append((pos_dot_product, neg_dot_product))
        return loss
Score function:
def compute_score(self, pair_graph, user_embeddings, item_embeddings):
    """Attach the embeddings to ``pair_graph`` and score neighborhoods.

    Inside the graph's ``local_scope()`` context, the embeddings are stored
    as the ``'h'`` feature of the ``'user'`` and ``'item'`` node types, and
    a ``NeighborhoodLoss`` built on the graph computes the dot products for
    the ``['item', 'user']`` type order.

    Args:
        pair_graph: Graph with ``'user'`` and ``'item'`` node types.
        user_embeddings: Features assigned to the ``'user'`` nodes.
        item_embeddings: Features assigned to the ``'item'`` nodes.

    Returns:
        Whatever ``NeighborhoodLoss.get_dot_products`` returns.
    """
    embeddings_by_type = (
        ('user', user_embeddings),
        ('item', item_embeddings),
    )
    with pair_graph.local_scope():
        for node_type, embeddings in embeddings_by_type:
            pair_graph.nodes[node_type].data['h'] = embeddings
        return NeighborhoodLoss(pair_graph).get_dot_products(['item', 'user'])
Loss function:
def compute_margin_loss(scores):
    """Return the element-wise hinge loss ``max(0, neg - pos + 0.1)``.

    Args:
        scores: Indexable whose item 0 holds the positive scores and item 1
            the negative scores. Both must support negation and addition,
            and the result must provide ``clamp``.

    Returns:
        The unreduced loss, floored at zero.
    """
    # NOTE(review): NeighborhoodLoss.get_dot_products returns a list of
    # (pos, neg) tuples, one per node type. Passing that list here would
    # treat each tuple as a score; the caller presumably has to pass one
    # tuple at a time. Confirm at the call site.
    positive, negative = scores[0], scores[1]
    margin_violation = -positive + negative + 0.1
    return margin_violation.clamp(min=0)