Mmm, it wasn’t my function, it’s DGL’s built-in message passing function.
Function is:
graph.apply_edges(fn.u_dot_v('h')
It’s the same as torch.dot. See the example below.
import dgl
import dgl.function as fn
import torch
g = dgl.graph(([0, 1], [1, 2]))
g.ndata['h'] = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
g.apply_edges(fn.u_dot_v('h', 'h', 'dot_out'))
print(g.edata['dot_out'])
# tensor([[11.],
# [39.]])
print(torch.dot(g.ndata['h'][0, :], g.ndata['h'][1, :]))
# tensor(11.)
print(torch.dot(g.ndata['h'][1, :], g.ndata['h'][2, :]))
# tensor(39.)
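To make the equivalence concrete without a per-edge loop, here is a minimal plain-PyTorch sketch of what u_dot_v computes on the same toy graph. It assumes the edges are given as source/destination index tensors (src, dst are illustrative names). Each edge gathers its source row and destination row of h, multiplies them elementwise and sums over the feature dimension.

```python
import torch

# Same toy graph as above: edges 0->1 and 1->2, as source/destination index tensors.
src = torch.tensor([0, 1])
dst = torch.tensor([1, 2])
h = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])

# Gather the source and destination feature rows for every edge, multiply
# elementwise, then sum over the feature dimension. keepdim=True keeps the
# (num_edges, 1) shape seen in g.edata['dot_out'] above.
dot_out = (h[src] * h[dst]).sum(dim=-1, keepdim=True)
print(dot_out)
# tensor([[11.],
#         [39.]])
```

Each row is exactly torch.dot of the two endpoint feature vectors, just computed for all edges at once.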
Thanks, I understand it now. Then it must be something else. Do you see anything wrong with the following code? Maybe I’m using different embeddings in each case?
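feats = grafo.ndata['Feats']
# Run the sage encoder; it returns a dict keyed by node type.
updated_feats = model.sage(grafo, {'ent': feats})
# Take the tensor for the only node type ('ent').
updated_feats = updated_feats.values()
updated_feats = list(updated_feats)[0]
# Project the node embeddings for the link3 edge type.
embeddings = model.pred.etype_project[('ent', 'link3', 'ent')](updated_feats)
# Score every link3 edge.
pred = model.pred(grafo, {'ent': embeddings}, ('ent', 'link3', 'ent'))
print(pred)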
This gives me the pos_score for the link3 edges, which looks like this:
tensor([[213.6073],
[ 15.9547],
[ 25.2864],
[ 16.5195],
[ 9.9416], grad_fn=<GSDDMMBackward>)
This is what I use during training. For prediction, however, I compute the same dot product for each link3 edge myself, with the following:
src, dst = grafo.edges(etype=('ent', 'link3', 'ent'))
# Dot product of the endpoint embeddings for each link3 edge.
for i in range(len(src)):
    print(torch.dot(embeddings[src[i], :], embeddings[dst[i], :]))
I get:
tensor(3274.8743, grad_fn=<DotBackward>)
tensor(240.6167, grad_fn=<DotBackward>)
tensor(412.9496, grad_fn=<DotBackward>)
tensor(238.0706, grad_fn=<DotBackward>)
tensor(128.2873, grad_fn=<DotBackward>)
Shouldn’t these dot products be the same? I really don’t see where I could be using the wrong embeddings…
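Perhaps model.pred also calls model.pred.etype_project internally before computing the dot product? In that case the embeddings you pass to it are projected a second time, while your manual loop uses embeddings that were projected only once.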
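A minimal sketch of that failure mode, assuming model.pred projects node embeddings inside its forward pass and then scores each edge with a dot product. ToyPredictor, the single nn.Linear standing in for etype_project, and the random data are all hypothetical; the real model.pred may be structured differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyPredictor(nn.Module):
    """Hypothetical stand-in for model.pred: project, then edge-wise dot product."""

    def __init__(self, dim):
        super().__init__()
        # Assumption: one nn.Linear plays the role of etype_project[link3].
        self.etype_project = nn.Linear(dim, dim)

    def forward(self, x, src, dst):
        z = self.etype_project(x)  # the projection happens INSIDE forward
        return (z[src] * z[dst]).sum(dim=-1, keepdim=True)

num_nodes, dim = 4, 3
x = torch.randn(num_nodes, dim)  # plays the role of updated_feats
src = torch.tensor([0, 1, 2])
dst = torch.tensor([1, 2, 3])
pred = ToyPredictor(dim)

with torch.no_grad():
    emb = pred.etype_project(x)  # projected once, as in the question
    manual = (emb[src] * emb[dst]).sum(dim=-1, keepdim=True)

    double_projected = pred(emb, src, dst)  # projected twice
    single_projected = pred(x, src, dst)    # projected once

print(torch.allclose(double_projected, manual))  # False (scores differ)
print(torch.allclose(single_projected, manual))  # True
```

Under that assumption, passing the unprojected updated_feats to model.pred (instead of embeddings) should give scores that match the manual loop over embeddings.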
That was the problem, thanks!