How to implement graph.update_all() manually?

Here is some code from a GCN layer:

import dgl.function as fn

graph.srcdata['h'] = srcfeat
a = fn.u_mul_e('h', 'affine', 'm')
b = fn.sum(msg='m', out='h')
graph.update_all(a, b)
gcn_feat = torch.cat([dstfeat, graph.dstdata['h']], dim=-1)

However, the DGL graph cannot be deployed to our NPU card, so we want to implement graph.update_all() manually in PyTorch. But there seems to be a shape problem when multiplying src_feat with affine. Any code or suggestions would be greatly appreciated :pray:

self.affine = nn.Parameter(graph.edata['affine'].to(srcfeat.device))
# performs src_feat(x, 512) ? affine(90, 1) = result(10, 512)
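For reference, `update_all(fn.u_mul_e('h', 'affine', 'm'), fn.sum('m', 'h'))` can be reproduced in plain PyTorch with a gather and a scatter-add over the edge endpoint index tensors. A minimal sketch (`manual_update_all` is a hypothetical helper name; `src`/`dst` are the per-edge endpoint indices, e.g. from `graph.edges()`):

```python
import torch

def manual_update_all(srcfeat, affine, src, dst, num_dst):
    """Plain-PyTorch equivalent of
    update_all(fn.u_mul_e('h', 'affine', 'm'), fn.sum('m', 'h')).

    srcfeat: (num_src, D) source-node features
    affine:  (num_edges, 1) per-edge scalar, broadcasts over D
    src/dst: (num_edges,) int64 endpoint indices of each edge
    """
    # u_mul_e: per-edge message m = srcfeat[src] * affine
    m = srcfeat[src] * affine                       # (num_edges, D)
    # fn.sum: scatter-add each message to its destination node
    out = torch.zeros(num_dst, srcfeat.shape[1],
                      dtype=srcfeat.dtype, device=srcfeat.device)
    out.index_add_(0, dst, m)
    return out
```

Note that the message count follows the number of edges, not the number of source nodes, which is why multiplying `src_feat` by `affine` directly raises a shape error: you first need to gather `srcfeat[src]` so both tensors have `num_edges` rows.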

Since graph.edata['affine'] is a tensor with the same number of rows as the number of edges, I don’t think it’s a good idea to make it a parameter. Otherwise, your model will only work on graphs with the same number of edges (i.e. 90).

Instead, I would recommend you compute affine with some other functions (e.g. dot product of the incident node representations).
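The dot-product suggestion can be sketched in plain PyTorch as follows (hypothetical helper; `src`/`dst` are the per-edge endpoint indices, and the node features stand in for whatever representations you choose):

```python
import torch

def edge_affine(srcfeat, dstfeat, src, dst):
    """Compute a per-edge 'affine' score as the dot product of the
    incident node representations, instead of storing it as a fixed
    nn.Parameter tied to one particular edge count."""
    # gather the two endpoint features of each edge, then reduce over
    # the feature dimension -> (num_edges, 1)
    return (srcfeat[src] * dstfeat[dst]).sum(dim=-1, keepdim=True)
```

Because the score is computed from node features at every forward pass, the same module works on graphs with any number of edges.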

Thanks, the code below works:

        gidx = graph._graph
        src, dst, edge = gidx.edges(0)
        adj = self.build_adj(src, dst, srcfeat)
        out = self.gcn(srcfeat.unsqueeze(0).permute(0, 2, 1),
                       adj.to(srcfeat.device))
        out = out.permute(0, 2, 1)[0, torch.unique(dst).long(), ...]
        gcn_feat = torch.cat([dstfeat, out], dim=-1)
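`self.build_adj` is not shown above; a minimal sketch of what such a helper might look like, assuming it builds a dense adjacency matrix (with a leading batch dimension, matching the `unsqueeze(0)` on `srcfeat`) from the edge endpoint lists:

```python
import torch

def build_adj(src, dst, srcfeat):
    # Hypothetical sketch: dense (1, N, N) adjacency from edge index
    # lists, sized by the number of source nodes. Row i holds the
    # in-edges of node i, so adj @ features aggregates neighbors.
    n = srcfeat.shape[0]
    adj = torch.zeros(n, n, dtype=srcfeat.dtype)
    adj[dst.long(), src.long()] = 1.0
    return adj.unsqueeze(0)  # add batch dim
```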

And this code:

        #     self.debug_pred_conn()
        #     output_bipartite = bipartites[-1]
        #     output_bipartite.srcdata['conv_features'] = neighbor_x_src.to('cpu')
        #     output_bipartite.dstdata['conv_features'] = center_x_src.to('cpu')
        #
        # output_bipartite.apply_edges(self.pred_conn)
        # output_bipartite.edata['prob_conn'] = F.softmax(output_bipartite.edata['pred_conn'], dim=1)
        # output_bipartite.update_all(self.pred_den_msg, fn.mean('pred_den_msg', 'pred_den'))
        # return output_bipartite

can be replaced by:

        output_bipartite = bipartites[-1]
        # print("forward with raw_affine shape=", output_bipartite.edata['raw_affine'].shape, ", data=", output_bipartite.edata['raw_affine'][:9])
        gidx = output_bipartite._graph
        src, dst, edge = gidx.edges(0)
        src_feat = self.src_mlp(neighbor_x_src[src.long()])
        ###  feature alignment
        len_center_x_src = center_x_src.shape[0]
        len_src_feat = src_feat.shape[0]
        repeat = int(len_src_feat / len_center_x_src)
        center_x_src_r = torch.repeat_interleave(center_x_src, repeat, dim=0)
        ###
        dst_feat = self.dst_mlp(center_x_src_r)
        pred_conn = self.classifier_conn(src_feat + dst_feat)
        output_bipartite.edata['pred_conn'] = pred_conn.to('cpu')  # saved for computing the loss
        prob = F.softmax(pred_conn, dim=1)
        raw_affine = output_bipartite.edata['raw_affine'].to(device)
        pred_den_msg = raw_affine * (prob[:, 1] - prob[:, 0])
        ### grouped mean
        g_pred_den_msg = torch.cat([pred_den_msg[i:i + repeat].unsqueeze(0) for i in range(0, len(pred_den_msg), repeat)], dim=0)
        ###
        pred_den = torch.mean(g_pred_den_msg, dim=1)
        output_bipartite.dstdata['pred_den'] = pred_den.to('cpu')  # staged on CPU for transfer; gradients are not lost
        return output_bipartite
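One small simplification: since every destination node here has exactly `repeat` incoming edges, the grouped mean built with the list comprehension and `torch.cat` can be written as a single reshape (assuming the length of `pred_den_msg` is divisible by `repeat`):

```python
import torch

def grouped_mean(pred_den_msg, repeat):
    # equivalent to concatenating consecutive groups of `repeat` rows
    # and averaging each group: view as (num_groups, repeat), then
    # take the mean over the group dimension
    return pred_den_msg.view(-1, repeat).mean(dim=1)
```

This avoids the Python-level loop over edge groups and keeps the whole reduction on-device in one kernel.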