Set initial values for learning embeddings

I’m trying to train a heterogeneous graph network that implements margin loss and minimizes the distance among embeddings of similar nodes and maximizes distance among dissimilar nodes (negative sampling).

In my example, I have nodes type ‘item’ and ‘user’. These two types of nodes have just one attribute ‘h’, and this vector is created based on a concatenation of word representations (word2vec) and image representation (VGG16) [inspired on PinSage paper]. Also, nodes type ‘item’ and ‘user’ have different embedding dimensions (h). Finally, in my network, ‘item’ receives messages only from ‘user’ and vice versa.

How can I start with these initial embedding vectors ‘h’ and then learn new ones based on the GCN process? I’m facing some troubles with the dimensionality in the nn.Linear methods on my convolutional layer.

class BAMGConv(nn.Module):
    def __init__(self, w_dim, b_dim):
        super().__init__()

        self.linear_w = nn.Linear(in_features=w_dim, out_features=w_dim, bias=False)
        self.linear_b = nn.Linear(in_features=b_dim, out_features=w_dim, bias=False)
        
    def compute_message(self, edges):
        affinity = edges.data['weight'] / torch.sum(edges.data['weight'])
        return affinity

    def forward(self, graph, node_features):
        with graph.local_scope():
            src_features, dst_features = node_features
            graph.srcdata['h'] = src_features
            graph.dstdata['h'] = dst_features
      
            graph.apply_edges(lambda edges: {'a': self.compute_message(edges)})

            graph.update_all(message_func=fn.u_mul_e('h', 'a', 'h_ngh'),
                             reduce_func=fn.sum('h_ngh', 'neighbors_avg'))
            
            result = F.relu(self.linear_w(graph.dstdata['h']) + self.linear_b(graph.dstdata['neighbors_avg']))
            return result
class BAMGLayer(nn.Module):
    def __init__(self, user_hidden_dim, item_hidden_dim):
        super().__init__()
        
        self.heteroconv = dglnn.HeteroGraphConv(
            # NN Module fo message passing on each individual edge type
            {'watched': BAMGConv(user_hidden_dim, item_hidden_dim), 'watched-by': BAMGConv(item_hidden_dim, user_hidden_dim)},
            aggregate='sum')
        
    def forward(self, block, input_user_features, input_item_features):
        with block.local_scope():
            h_user = input_user_features
            h_item = input_item_features
            
            src_features = {'user': h_user, 'item': h_item}
            # copy the features from the source side to the destination side.:
            dst_features = {'user': h_user[:block.number_of_dst_nodes('user')], 'item': h_item[:block.number_of_dst_nodes('item')]}
            
            # HeteroGraphConv essentially performs a heterogeneous graph convolution by:
            # (1) Performing message passing on each edge type individually.  In this case,
            #     * Items receives messages from users along "watched" relation.
            #     * Users receives messages from items along "watched-by" relation.
            # (2) Aggregating the result on each node type along incoming edge types.  In this case,
            result = self.heteroconv(block, (src_features, dst_features))
            return result['user'], result['item']
class BAMGScorer(nn.Module):
    def __init__(self, num_users, num_items, user_hidden_dim, item_hidden_dim, final_hidden_dim, num_layers=1):
        super().__init__() 
        self.layers = nn.ModuleList([
            BAMGLayer(user_hidden_dim, item_hidden_dim) for _ in range(num_layers)])
        
        self.U_h = nn.Embedding(user_hidden_dim, user_hidden_dim)
        self.I_h = nn.Embedding(item_hidden_dim, item_hidden_dim)

    def forward(self, blocks):
        user_embeddings = blocks[0].srcnodes['user'].data['h']
        item_embeddings = blocks[0].srcnodes['item'].data['h']

        user_embeddings = self.U_h(user_embeddings)
        item_embeddings = self.I_h(item_embeddings)

        for block, layer in zip(blocks, self.layers):
            user_embeddings, item_embeddings = layer(block, user_embeddings, item_embeddings)
  
        return user_embeddings, item_embeddings
        
    def compute_score(self, pair_graph, user_embeddings, item_embeddings):
        with pair_graph.local_scope():
            pair_graph.nodes['user'].data['h'] = user_embeddings
            pair_graph.nodes['item'].data['h'] = item_embeddings
            ngh_scorer = NeighborhoodScore(pair_graph)
            scores = ngh_scorer.get_similarity_scores(['item','user'])
            return scores

Take the edge type ('user', 'watched', 'item') as an example.

As far as I understand, linear_w is computing the transformation of the item’s own features, and linear_b is computing the transformation of the neighboring users’ features. So linear_w should have input dimension item_hidden_dim and output dimension item_hidden_dim, while linear_b should have input dimension user_hidden_dim and output dimension item_hidden_dim.

Therefore, the initialization

is wrong. Switch the place of user_hidden_dim and item_hidden_dim and the code should work fine.

1 Like