I am currently using the GAT model built from DGL's `GATConv` layers. I would like the graph `g` passed to the model to be a weighted graph, so that during forward propagation each node takes the edge weights into account when aggregating information from its neighbors.
class GAT(nn.Module):
    """Two-layer graph attention network built from DGL ``GATConv`` layers.

    The hidden layer uses ``num_heads`` attention heads whose outputs are
    concatenated via ``flatten(1)``; the output layer uses a single head.
    Optional per-edge weights can be passed to :meth:`forward` so that, in
    addition to the learned attention, each neighbour's message is scaled
    by the weight of the edge it arrives on.

    Args:
        in_feats: Size of each input node feature vector.
        hidden_feats: Per-head output size of the hidden layer.
        n_classes: Size of the per-node output.
        activation: Activation passed to the hidden ``GATConv``.
        dropout: Dropout probability applied to the hidden representation.
        num_heads: Number of attention heads in the hidden layer. Defaults
            to 4, the value that was previously hard-coded.
    """

    def __init__(self, in_feats, hidden_feats, n_classes, activation, dropout,
                 num_heads=4):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()
        self.layers.append(
            GATConv(in_feats, hidden_feats, num_heads=num_heads,
                    activation=activation))
        # The hidden heads are concatenated by flatten(1) in forward(), so the
        # output layer's input size is hidden_feats * num_heads.
        self.layers.append(
            GATConv(hidden_feats * num_heads, n_classes, num_heads=1,
                    activation=None))

    def forward(self, g, features, edge_weight=None):
        """Run the network on graph ``g``.

        Args:
            g: DGL graph.
            features: Node feature tensor, presumably of shape
                (num_nodes, in_feats). TODO: confirm against the caller.
            edge_weight: Optional tensor holding one weight per edge of ``g``,
                e.g. ``g.edata['w']``. It may have shape (E,) or (E, 1); it is
                flattened to (E,) because ``GATConv`` expects a 1-D weight
                tensor. When ``None``, the layers are called exactly as before
                (plain attention), which keeps older DGL versions working.

        Returns:
            The last layer's output flattened from dim 1. Given DGL's
            documented ``GATConv`` output of (N, heads, out_feats), this is
            (num_nodes, n_classes).

        Raises:
            ValueError: If ``edge_weight`` does not have exactly one entry
                per edge of ``g``.

        Notes:
            Passing ``edge_weight`` requires a DGL release whose
            ``GATConv.forward`` accepts it (TODO: confirm your installed
            version). In DGL's implementation, the weight multiplies the
            softmax-normalised attention coefficients and is not
            re-normalised afterwards. Normalise the weights yourself if the
            aggregation must remain a convex combination.
        """
        layer_kwargs = {}
        if edge_weight is not None:
            if edge_weight.numel() != g.num_edges():
                raise ValueError(
                    f"edge_weight has {edge_weight.numel()} elements but the "
                    f"graph has {g.num_edges()} edges; expected one weight "
                    f"per edge")
            layer_kwargs["edge_weight"] = edge_weight.reshape(-1)

        x = features
        for layer in self.layers[:-1]:
            x = layer(g, x, **layer_kwargs).flatten(1)
            x = self.dropout(x)
        return self.layers[-1](g, x, **layer_kwargs).flatten(1)