Message Function for Multiplication with Weight Matrix

Hi there!

I would like to have a trainable weight matrix for each edge type in my graph and use this matrix in my ScorePredictor class, similar to the one in DGL Tutorial 6.3.

I would like to multiply my source node feature z_i by the weight matrix M_r of that edge type and by the target node feature z_j, as in the following formula:
score(i, j) = z_i^T M_r z_j
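As a point of reference, the formula can be written out edge by edge in plain PyTorch. This is a minimal sketch rather than DGL code: it assumes the edges of one type are given as hypothetical index tensors `src` and `dst` (the gathering that DGL's `edges.src` / `edges.dst` does for you), and that `M_r` is a plain tensor of shape (d_src, d_dst).

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data for a single edge type r: 2 source nodes, 3 target nodes.
z_src = torch.randn(2, 4)     # source node features z_i, d_src = 4
z_dst = torch.randn(3, 5)     # target node features z_j, d_dst = 5
M_r = torch.randn(4, 5)       # one weight matrix shared by all edges of type r
src = torch.tensor([0, 1])    # source node of each edge
dst = torch.tensor([1, 2])    # target node of each edge

def score_loop(z_src, z_dst, M_r, src, dst):
    """Reference implementation of score(i, j) = z_i^T M_r z_j, one edge at a time."""
    scores = []
    for i, j in zip(src.tolist(), dst.tolist()):
        # (d_src) @ (d_src, d_dst) -> (d_dst), then @ (d_dst) -> scalar
        scores.append(z_src[i] @ M_r @ z_dst[j])
    return torch.stack(scores)

scores = score_loop(z_src, z_dst, M_r, src, dst)
print(scores.shape)  # torch.Size([2]): one score per edge
```

The loop is only meant to pin down the semantics; it is slow on real graphs, so a vectorized version is preferable in practice.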

I have been looking through the DGL message functions, but I haven't found one I could use for this. Since I do not have a feature vector for every edge, only one matrix per edge type, the normal u_mul_e and u_dot_e functions do not work here.
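One way to still reach a built-in: because M_r is shared by every edge of type r, the product can be regrouped as (z_i^T M_r) · z_j. The matrix multiplication then happens once per source node instead of once per edge, and what remains per edge is a plain dot product, which is what `dgl.function.u_dot_v` computes. The sketch below only checks the regrouping in plain PyTorch with hypothetical toy tensors. In DGL you would store the projected source features as node data (under a hypothetical name such as `'xM'`, distinct per edge type if several types share a source node type) and call `g.apply_edges(fn.u_dot_v('xM', 'x', 'score'), etype=etype)`, which yields scores of shape (num_edges, 1).

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data for one edge type r.
z_src = torch.randn(2, 4)     # source node features
z_dst = torch.randn(3, 5)     # target node features
M_r = torch.randn(4, 5)       # weight matrix of edge type r
src = torch.tensor([0, 1])    # source node of each edge
dst = torch.tensor([1, 2])    # target node of each edge

# Step 1 (node level): project every source node once with M_r -> shape (2, 5).
z_src_proj = z_src @ M_r

# Step 2 (edge level): dot product between the projected source feature and the
# target feature of each edge; in DGL this is the part u_dot_v would handle.
scores = (z_src_proj[src] * z_dst[dst]).sum(dim=-1)

# Sanity check against the formula evaluated edge by edge.
expected = torch.stack([z_src[i] @ M_r @ z_dst[j] for i, j in zip(src.tolist(), dst.tolist())])
assert torch.allclose(scores, expected, atol=1e-6)
```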

Is there any way I can use the built-in functions?

Thank you for any help in advance!

You can do something like the following:

import dgl
import dgl.function as fn
import torch
import torch.nn as nn

g = dgl.heterograph({
          ('A', 'r1', 'B'): ([0, 1], [1, 2])
    })
for nty in g.ntypes:
    g.nodes[nty].data['h'] = torch.randn(g.num_nodes(nty), 2)

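# Edge function factory: wraps a module that maps (source feature, destination feature) to a message.
def bilinear_msg(bilinear):
    def msg_func(edges):
        # edges.src['h'] / edges.dst['h'] hold the features of each edge's endpoints.
        return {'m': bilinear(edges.src['h'], edges.dst['h'])}
    return msg_func

class Model(nn.Module):
    def __init__(self, in_feats=2, out_feats=3):
        super(Model, self).__init__()
        # One nn.Bilinear (its own trainable weights) per edge type; this toy graph has a single type.
        self.rel_weights = nn.ModuleList([nn.Bilinear(in_feats, in_feats, out_feats)])

    def forward(self, g):
        # Compute one message per edge of the first edge type and sum them into the destination nodes.
        g.update_all(bilinear_msg(self.rel_weights[0]), fn.sum('m', 'h_new'),
                     etype=g.canonical_etypes[0])

model = Model()
model(g)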
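On what `nn.Bilinear` does here: PyTorch's `nn.Bilinear(in1_features, in2_features, out_features)` computes y_k = x1^T A_k x2 + b_k for each of its `out_features` outputs, so it holds one matrix A_k per output. With `out_features=1` and `bias=False` it holds exactly one matrix, which plays the role of M_r in the formula. A small check, assuming toy feature sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A single bilinear form without bias: its one weight matrix plays the role of M_r.
bilinear = nn.Bilinear(in1_features=4, in2_features=5, out_features=1, bias=False)
M_r = bilinear.weight[0]      # weight has shape (out_features, in1, in2) -> M_r is (4, 5)

z_i = torch.randn(3, 4)       # source features, one row per edge
z_j = torch.randn(3, 5)       # matching target features

out = bilinear(z_i, z_j).squeeze(-1)                  # shape (3,)
manual = torch.einsum('bn,nm,bm->b', z_i, M_r, z_j)  # z_i^T M_r z_j per edge
assert torch.allclose(out, manual, atol=1e-6)
```

So with `out_feats=1` and `bias=False`, the snippet above would compute exactly the score from the question; as written, with `out_feats=3` and the default bias, it produces three bilinear scores plus a bias per edge.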

Thank you! @mufeili :slight_smile:

@mufeili I do not use a Bilinear module (which I find a bit confusing); instead I have a feature matrix for every edge type that changes with every forward pass. How could I multiply this edge type feature matrix with the source and destination node features?

My code looks roughly like this:

import dgl
import torch as th
import torch.nn as nn
from typing import Dict, Tuple

class ScorePredictor(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # constructor arguments and layers elided in the original post

    def forward(
        self,
        edge_subgraph: dgl.DGLHeteroGraph,
        node_feature_tensor: Dict[str, th.Tensor],
        edge_feature_tensor: Dict[str, th.Tensor],
    ) -> Dict[Tuple[str, str, str], th.Tensor]:
        """Perform score prediction.

        Args:
            edge_subgraph (dgl.DGLHeteroGraph): The subgraph to be evaluated.
            node_feature_tensor (Dict[str, th.Tensor]): Node feature tensor for every node type.
            edge_feature_tensor (Dict[str, th.Tensor]): Edge feature tensor for every edge type.

        Returns:
            result_dict (dict): A dictionary mapping canonical edge types to the scores of the corresponding edges.
        """
        result_dict = {}
        with edge_subgraph.local_scope():
            # Set the node features.
            edge_subgraph.ndata['x'] = node_feature_tensor
            # Set the edge feature, which is one matrix per edge type.
            edge_subgraph.edata['edge_feature'] = {
                etype: th.diag(
                    input=th.sum(edge_feature_tensor[etype], dim=1)
                )
                for etype in edge_feature_tensor
            }

            for etype in edge_subgraph.canonical_etypes:
                # TODO: insert weight_matrix as edge feature (same for every edge of one edge type)

                # First part: multiply the source node feature with the edge type feature.
                edge_subgraph.apply_edges(
                    dgl.function.u_mul_e('x', 'edge_feature', 'tmp'),
                    etype=etype,
                )
                # Second part: multiply the intermediate product with the target node feature.
                edge_subgraph.apply_edges(
                    dgl.function.e_mul_v('tmp', 'x', 'score'),
                    etype=etype,
                )
                result_dict[etype] = edge_subgraph.edges[etype].data['score']
        return result_dict

@mufeili I updated the code snippet a bit. It now shows the approach I tried, which failed: I multiplied the source node feature with the edge type feature (u_mul_e) and then multiplied the result with the destination node feature (e_mul_v). The problem is that the edge feature is a matrix (of dimension number of edges of this edge type x hidden dimension) that is the same for all edges of one edge type, so I do not have an individual feature for each edge. Because of this, the multiplication as I wrote it does not work.

So what would really help is a multiplication with an edge-type-specific parameter. Could you help me with that?
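If you do want M_r to behave like a per-edge feature (the TODO in the snippet above), note that `u_mul_e` and `e_mul_v` multiply element-wise, so chaining them cannot produce a matrix product. The per-edge form can be sketched in plain PyTorch with `expand`, which gives every edge a view of the same matrix without copying it, followed by batched matrix products. This is an illustration with hypothetical toy tensors, not DGL code.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data for one edge type: 4 edges, d_src = 3, d_dst = 2.
num_edges = 4
src_x = torch.randn(num_edges, 3)   # source feature of each edge (like edges.src['x'])
dst_x = torch.randn(num_edges, 2)   # target feature of each edge (like edges.dst['x'])
M_r = torch.randn(3, 2)             # one matrix shared by every edge of this type

# "Edge feature" view of M_r with shape (num_edges, 3, 2); no memory is copied.
M_per_edge = M_r.expand(num_edges, -1, -1)

# Batched z_i^T M_r z_j: (E,1,3) @ (E,3,2) -> (E,1,2), then @ (E,2,1) -> (E,1,1).
scores = torch.bmm(torch.bmm(src_x.unsqueeze(1), M_per_edge), dst_x.unsqueeze(2)).reshape(-1)

# Same result as contracting directly with the shared matrix.
assert torch.allclose(scores, torch.einsum('bn,nm,bm->b', src_x, M_r, dst_x), atol=1e-6)
```

Expanding is mainly useful if the matrix really must live next to other per-edge data; contracting with the shared matrix directly avoids the extra dimension.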

You can do something like the following:

import dgl
import torch
import torch.nn as nn

g = dgl.heterograph({
    ('A', 'r1', 'B'): ([0, 1], [1, 2])
})

def bilinear_msg(eweight):
    # Edge function computing score = z_i^T M_r z_j for every edge, where eweight
    # is the (d_src, d_dst) matrix M_r shared by all edges of this type.
    def msg_func(edges):
        src_x, dst_x = edges.src['x'], edges.dst['x']  # (E, d_src), (E, d_dst)
        return {'score': torch.einsum('bn,nm,bm->b', src_x, eweight, dst_x)}  # (E,)
    return msg_func

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, g, node_feature_tensor, edge_feature_tensor):
        result_dict = {}
        with g.local_scope():
            g.ndata['x'] = node_feature_tensor
            for etype in g.canonical_etypes:
                # The edge type matrix is passed to the edge function directly,
                # so it never has to be stored as a per-edge feature.
                g.apply_edges(bilinear_msg(edge_feature_tensor[etype]), etype=etype)
                result_dict[etype] = g.edges[etype].data['score']
        return result_dict

model = Model()
# Type A nodes have 2-dim features and type B nodes 3-dim, so M_r for ('A', 'r1', 'B') is 2 x 3.
node_feature_tensor = {'A': torch.randn(g.num_nodes('A'), 2), 'B': torch.randn(g.num_nodes('B'), 3)}
edge_feature_tensor = {('A', 'r1', 'B'): torch.randn(2, 3)}
result = model(g, node_feature_tensor, edge_feature_tensor)
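The original question asked for a *trainable* matrix per edge type, whereas the reply above receives the matrices as forward arguments. A minimal plain-PyTorch sketch of a trainable variant follows. It assumes edges are given as hypothetical `(src_idx, dst_idx)` index tensors keyed by canonical edge type, and that relation names are unique, because `nn.ParameterDict` needs string keys and canonical edge types are tuples. `TypedBilinearScorer` and its arguments are illustrative names, not a DGL API; in DGL, the per-edge part would sit inside an `apply_edges` function like the one in the reply.

```python
import torch
import torch.nn as nn

class TypedBilinearScorer(nn.Module):
    """Hypothetical scorer with one trainable M_r per relation: score = z_i^T M_r z_j."""

    def __init__(self, rel_dims):
        # rel_dims: {relation_name: (d_src, d_dst)}. ParameterDict needs string keys,
        # so the relation name (middle element of a canonical edge type) is the key.
        super().__init__()
        self.weights = nn.ParameterDict({
            rel: nn.Parameter(0.1 * torch.randn(d_src, d_dst))
            for rel, (d_src, d_dst) in rel_dims.items()
        })

    def forward(self, edges_by_type, node_feats):
        # edges_by_type: {(src_type, rel, dst_type): (src_idx, dst_idx)}
        # node_feats: {node_type: tensor of shape (num_nodes, dim)}
        scores = {}
        for (src_type, rel, dst_type), (src_idx, dst_idx) in edges_by_type.items():
            z_i = node_feats[src_type][src_idx]   # (E, d_src)
            z_j = node_feats[dst_type][dst_idx]   # (E, d_dst)
            scores[(src_type, rel, dst_type)] = torch.einsum(
                'bn,nm,bm->b', z_i, self.weights[rel], z_j)
        return scores

# Toy setup mirroring the reply's graph: A has 2 nodes (dim 2), B has 3 nodes (dim 3).
torch.manual_seed(0)
model = TypedBilinearScorer({'r1': (2, 3)})
node_feats = {'A': torch.randn(2, 2), 'B': torch.randn(3, 3)}
edges = {('A', 'r1', 'B'): (torch.tensor([0, 1]), torch.tensor([1, 2]))}
scores = model(edges, node_feats)
scores[('A', 'r1', 'B')].sum().backward()  # gradients reach the per-relation matrix
```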

Thank you @mufeili, this is awesome! It worked! I also really like the th.einsum function; I didn't know it before.
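For anyone else new to `torch.einsum`: in `'bn,nm,bm->b'`, `b` indexes edges, `n` the source feature dimension and `m` the target feature dimension. Every index that does not appear after `->` is summed over, so for each edge b the result is the sum over n and m of src_x[b, n] * W[n, m] * dst_x[b, m], which is z_i^T M_r z_j. The sketch below spells out the same computation with broadcasting, using toy shapes that match the reply's example:

```python
import torch

torch.manual_seed(0)
src_x = torch.randn(2, 2)   # (b, n): 2 edges, source feature dim 2
W = torch.randn(2, 3)       # (n, m): edge type matrix
dst_x = torch.randn(2, 3)   # (b, m): target feature dim 3

via_einsum = torch.einsum('bn,nm,bm->b', src_x, W, dst_x)

# Explicit version: build every term as a (b, n, m) tensor, then sum over n and m.
terms = src_x[:, :, None] * W[None, :, :] * dst_x[:, None, :]
via_broadcast = terms.sum(dim=(1, 2))

assert torch.allclose(via_einsum, via_broadcast, atol=1e-6)
```

The broadcast version materializes a (b, n, m) tensor, so it is only for illustration; einsum lets PyTorch choose a cheaper contraction order.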
