Hi there,
I am trying a link prediction model on the following graph:
Graph(num_nodes={‘ent’: 12341},
num_edges={(‘ent’, ‘like’, ‘ent’): 11597, (‘ent’, ‘neutral’, ‘ent’): 58975, (‘ent’, ‘dislike’, ‘ent’): 8033},
metagraph=[(‘ent’, ‘ent’, ‘like’), (‘ent’, ‘ent’, ‘neutral’), (‘ent’, ‘ent’, ‘dislike’)])
My model is the following:
Define a Heterograph Conv model
class RGCN(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(in_feats, hid_feats)
for rel in rel_names}, aggregate='sum')
self.conv2 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(hid_feats, out_feats)
for rel in rel_names}, aggregate='sum')
def forward(self, graph, inputs):
# inputs are features of nodes
h = self.conv1(graph, inputs)
h = {k: F.relu(v) for k, v in h.items()}
h = self.conv2(graph, h)
return h
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, rel_names):
super().__init__()
self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
self.pred = HeteroDotProductPredictor()
def forward(self, g, neg_g, x, etype):
h = self.sage(g, x)
return self.pred(g, h, etype), self.pred(neg_g, h, etype)
When I try to compute loss with this:
def compute_loss(pos_score, neg_score):
# Margin loss
n_edges = pos_score.shape[0]
return (1 - pos_score.unsqueeze(1) + neg_score.view(n_edges, -1)).clamp(min=0).mean()
k = 5
num_nodes = g.num_nodes(‘ent’)
embeddings_dimensions = len(g.ndata[‘Feats’][1])
node_types = 1
model = Model(embeddings_dimensions, num_nodes, node_types, g.etypes)
node_feats = g.ndata[‘Feats’]
node_features = {‘ent’: node_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(1):
negative_graph = construct_negative_graph(g, k, ('ent', 'like, 'ent'))
pos_score, neg_score = model(g, negative_graph, node_features, ('ent', 'like', 'ent'))
loss = compute_loss(pos_score, neg_score)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
I get the following error:
AssertionError: The HeteroNodeDataView has only one node type. please pass a tensor directly
Could anybody tell which the problem is?
Thank you so much.