How to use GATConv in dgl.nn.HeteroGraphConv

# Define a Heterograph Conv model


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import greater
from typing_extensions import final

import dgl
from dgl.nn.pytorch import GATConv, SAGEConv



class RGCN(nn.Module):
    """Two-layer heterogeneous GNN: multi-head GATConv followed by GraphConv.

    Each name in ``rel_names`` gets its own convolution module, and
    ``dgl.nn.HeteroGraphConv`` sums the per-relation results for every
    destination node type (``aggregate='sum'``).

    Args:
        in_feats: Input feature size shared by all node types.
        hid_feats: Per-head output size of the GAT layer.
        out_feats: Output size of the final GraphConv layer.
        rel_names: Edge-type names; one sub-module is created per name.
        num_heads: Number of attention heads in the GAT layer (default 8).
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names, num_heads=8):
        super().__init__()
        # One GATConv per relation. allow_zero_in_degree=True because randomly
        # generated relations can leave destination nodes with no incoming
        # edges, and DGL's GATConv raises a DGLError on such nodes by default.
        self.conv1 = dgl.nn.HeteroGraphConv({
            rel: GATConv(in_feats, hid_feats, num_heads, allow_zero_in_degree=True)
            for rel in rel_names}, aggregate='sum')
        # The heads are flattened in forward(), so the second layer sees
        # hid_feats * num_heads input features per node.
        self.conv2 = dgl.nn.HeteroGraphConv({
            rel: dgl.nn.GraphConv(hid_feats * num_heads, out_feats,
                                  allow_zero_in_degree=True)
            for rel in rel_names}, aggregate='sum')

    @staticmethod
    def _activate_and_flatten(h):
        """Apply ReLU and merge every dimension after the first into one.

        Args:
            h: Dict mapping node type to a tensor whose first dimension is
                the node count (for GATConv output: nodes x heads x features).

        Returns:
            Dict with the same keys, each tensor shaped (num_nodes, -1).
        """
        return {ntype: F.relu(feat).flatten(start_dim=1) for ntype, feat in h.items()}

    def forward(self, graph, inputs):
        """Run both layers.

        Args:
            graph: The heterogeneous graph.
            inputs: Dict mapping node type to its input feature tensor.

        Returns:
            Dict mapping each destination node type to its output tensor.
        """
        h = self.conv1(graph, inputs)
        # GATConv keeps a separate head dimension; GraphConv needs 2-D input.
        h = self._activate_and_flatten(h)
        return self.conv2(graph, h)


if __name__ == '__main__':
    # Sizes of the synthetic dataset.
    n_users = 1000
    n_items = 500
    n_follows = 3000
    n_clicks = 5000
    n_dislikes = 500
    n_hetero_features = 10
    n_user_classes = 5
    n_max_clicks = 10

    # Random edge endpoints for each relation.
    follow_src = np.random.randint(0, n_users, n_follows)
    follow_dst = np.random.randint(0, n_users, n_follows)
    click_src = np.random.randint(0, n_users, n_clicks)
    click_dst = np.random.randint(0, n_items, n_clicks)
    dislike_src = np.random.randint(0, n_users, n_dislikes)
    dislike_dst = np.random.randint(0, n_items, n_dislikes)

    # Every relation is paired with a reverse edge type, so messages flow
    # into both 'user' and 'item' nodes.
    hetero_graph = dgl.heterograph({
        ('user', 'follow', 'user'): (follow_src, follow_dst),
        ('user', 'followed-by', 'user'): (follow_dst, follow_src),
        ('user', 'click', 'item'): (click_src, click_dst),
        ('item', 'clicked-by', 'user'): (click_dst, click_src),
        ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
        ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})

    hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
    hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
    hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
    hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()

    # Randomly mark users / clicks as training examples with probability 0.6.
    hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
    hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)

    model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)

    user_feats = hetero_graph.nodes['user'].data['feature']  # [n_users, n_hetero_features]
    item_feats = hetero_graph.nodes['item'].data['feature']  # [n_items, n_hetero_features]
    labels = hetero_graph.nodes['user'].data['label']
    train_mask = hetero_graph.nodes['user'].data['train_mask']

    # HeteroGraphConv must receive a dict keyed by node type. Concatenating
    # the user and item features into one tensor loses the per-type node
    # counts, which is what produced the node-count error.
    node_features = {'user': user_feats, 'item': item_feats}

    # Sanity check: a single forward pass without autograd, reporting the
    # output shape for each destination node type.
    with torch.no_grad():
        h_dict = model(hetero_graph, node_features)
    print({ntype: tuple(h.shape) for ntype, h in h_dict.items()})

    opt = torch.optim.Adam(model.parameters())

    print(model)
    print(hetero_graph)
    for epoch in range(5):
        model.train()
        # Forward pass over all node features; keep only the 'user' outputs.
        logits = model(hetero_graph, node_features)['user']
        # Cross-entropy on the training users only.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        # Validation accuracy is omitted in this example.
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())

        # Saving the trained model is omitted in this example.

        # 如果需要的话,保存训练好的模型。本例中省略。





This is my code. However, using GATConv inside HeteroGraphConv always raises an error about the number of nodes.

Hi,

You should pass the feature dict (keyed by node type) to the heterogeneous model instead of a concatenated tensor. The line `h_dict = model(hetero_graph,confeature)` should be changed to `h_dict = model(hetero_graph, node_features)`.

Also, the output of GATConv has shape `[num_nodes, num_heads, out_feats]`, so you need to flatten the last two dimensions before passing it to the next GraphConv modules, which expect `hid_feats * 8` input features. Below is your code with these fixes:



from typing_extensions import final
from numpy import greater
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import dgl

from dgl.nn.pytorch import GATConv,SAGEConv 



class RGCN(nn.Module):
    """Two-layer heterogeneous GNN: multi-head GATConv followed by GraphConv.

    Each name in ``rel_names`` gets its own convolution module, and
    ``dgl.nn.HeteroGraphConv`` sums the per-relation results for every
    destination node type (``aggregate='sum'``).

    Args:
        in_feats: Input feature size shared by all node types.
        hid_feats: Per-head output size of the GAT layer.
        out_feats: Output size of the final GraphConv layer.
        rel_names: Edge-type names; one sub-module is created per name.
        num_heads: Number of attention heads in the GAT layer (default 8).
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names, num_heads=8):
        super().__init__()
        # One GATConv per relation. allow_zero_in_degree=True because randomly
        # generated relations can leave destination nodes with no incoming
        # edges, and DGL's GATConv raises a DGLError on such nodes by default.
        self.conv1 = dgl.nn.HeteroGraphConv({
            rel: GATConv(in_feats, hid_feats, num_heads, allow_zero_in_degree=True)
            for rel in rel_names}, aggregate='sum')
        # The heads are flattened in forward(), so the second layer sees
        # hid_feats * num_heads input features per node.
        self.conv2 = dgl.nn.HeteroGraphConv({
            rel: dgl.nn.GraphConv(hid_feats * num_heads, out_feats,
                                  allow_zero_in_degree=True)
            for rel in rel_names}, aggregate='sum')

    @staticmethod
    def _activate_and_flatten(h):
        """Apply ReLU and merge every dimension after the first into one.

        Args:
            h: Dict mapping node type to a tensor whose first dimension is
                the node count (for GATConv output: nodes x heads x features).

        Returns:
            Dict with the same keys, each tensor shaped (num_nodes, -1).
        """
        return {ntype: F.relu(feat).flatten(start_dim=1) for ntype, feat in h.items()}

    def forward(self, graph, inputs):
        """Run both layers.

        Args:
            graph: The heterogeneous graph.
            inputs: Dict mapping node type to its input feature tensor.

        Returns:
            Dict mapping each destination node type to its output tensor.
        """
        h = self.conv1(graph, inputs)
        # GATConv keeps a separate head dimension; GraphConv needs 2-D input.
        h = self._activate_and_flatten(h)
        return self.conv2(graph, h)

def main():
    """Train the RGCN for a few epochs on a small random user/item heterograph."""
    # Sizes of the synthetic dataset.
    n_users = 1000
    n_items = 500
    n_follows = 3000
    n_clicks = 5000
    n_dislikes = 500
    n_hetero_features = 10
    n_user_classes = 5
    n_max_clicks = 10

    # Random edge endpoints for each relation.
    follow_src = np.random.randint(0, n_users, n_follows)
    follow_dst = np.random.randint(0, n_users, n_follows)
    click_src = np.random.randint(0, n_users, n_clicks)
    click_dst = np.random.randint(0, n_items, n_clicks)
    dislike_src = np.random.randint(0, n_users, n_dislikes)
    dislike_dst = np.random.randint(0, n_items, n_dislikes)

    # Every relation is paired with a reverse edge type, so messages flow
    # into both 'user' and 'item' nodes.
    hetero_graph = dgl.heterograph({
        ('user', 'follow', 'user'): (follow_src, follow_dst),
        ('user', 'followed-by', 'user'): (follow_dst, follow_src),
        ('user', 'click', 'item'): (click_src, click_dst),
        ('item', 'clicked-by', 'user'): (click_dst, click_src),
        ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
        ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})

    hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
    hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
    hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
    hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()

    # Randomly mark users / clicks as training examples with probability 0.6.
    hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
    hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)

    model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)

    user_feats = hetero_graph.nodes['user'].data['feature']  # [n_users, n_hetero_features]
    item_feats = hetero_graph.nodes['item'].data['feature']  # [n_items, n_hetero_features]
    labels = hetero_graph.nodes['user'].data['label']
    train_mask = hetero_graph.nodes['user'].data['train_mask']

    # HeteroGraphConv must receive a dict keyed by node type, not the user
    # and item features concatenated into one tensor.
    node_features = {'user': user_feats, 'item': item_feats}

    # Sanity check: a single forward pass without autograd, reporting the
    # output shape for each destination node type.
    with torch.no_grad():
        h_dict = model(hetero_graph, node_features)
    print({ntype: tuple(h.shape) for ntype, h in h_dict.items()})

    opt = torch.optim.Adam(model.parameters())

    print(model)
    print(hetero_graph)
    for epoch in range(5):
        model.train()
        # Forward pass over all node features; keep only the 'user' outputs.
        logits = model(hetero_graph, node_features)['user']
        # Cross-entropy on the training users only.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        # Validation accuracy is omitted in this example.
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())

        # Saving the trained model is omitted in this example.


if __name__ == '__main__':
    # Run the training demo only when this file is executed as a script.
    main()
1 Like

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.