Hi, my current function to extract node embeddings is the following:
import dgl
import torch
from tqdm import tqdm


def get_embedding_from_graph(graph, model):
    # batch_size, device and num_out are defined at module level
    graph_node_indices = {
        "User": graph.nodes("User"),
        "Action": graph.nodes("Action"),
        "Episode": graph.nodes("Episode"),
    }
    nodes_dataloader = dgl.dataloading.DataLoader(
        graph=graph,
        indices=graph_node_indices,
        graph_sampler=dgl.dataloading.MultiLayerFullNeighborSampler(2),
        batch_size=batch_size,
        device=device,
        use_uva=True,
        shuffle=False,
    )
    # per-node-type accumulators for the final embeddings
    nodes_embedding = {
        "Action": torch.empty([0, num_out]),
        "Episode": torch.empty([0, num_out]),
        "User": torch.empty([0, num_out]),
    }
    with torch.no_grad():
        for input_nodes, sampled_graph, blocks in tqdm(nodes_dataloader):
            input_features = generate_features_dict(blocks[0])
            emb = model.sage(blocks, input_features)
            # append this batch's embeddings to the per-type accumulators
            nodes_embedding = {
                node_type: torch.cat(
                    (nodes_embedding[node_type], emb[node_type].to("cpu"))
                )
                for node_type in nodes_embedding
            }
    return nodes_embedding
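(generate_features_dict is not shown above; it just builds the per-node-type input feature dict for the sampled block. A minimal sketch of such a helper, assuming node features are stored under a "feat" key:)

# Sketch only: the real helper may differ; "feat" is an assumed feature field name.
def generate_features_dict(block):
    return {
        node_type: block.srcnodes[node_type].data["feat"]
        for node_type in block.srctypes
    }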
It works with GraphConv and SAGEConv, but it does not seem to work with GATv2Conv. My GAT version of the sage layers looks like this:
from typing import Dict, Iterable, Tuple

import dgl.nn as dglnn
import torch
import torch.nn as nn


class GraphLayers(nn.Module):
    def __init__(
        self,
        input_nodes_features: Dict[str, torch.Tensor],
        graph_edges: Iterable[Tuple[str, str, str]],
        hidden_features: int,
        out_features: int,
        layer_type: str = "graphconv",
        **kwargs,
    ):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv(
            {
                (source, edge, destination): dglnn.GATv2Conv(
                    in_feats=(
                        input_nodes_features[source].shape[1],
                        input_nodes_features[destination].shape[1],
                    ),
                    out_feats=hidden_features,
                    activation=nn.ReLU(),
                    num_heads=kwargs.get("gat_num_heads", 4),
                    feat_drop=kwargs.get("gat_feat_drop", 0),
                    attn_drop=kwargs.get("gat_attn_drop", 0),
                    negative_slope=kwargs.get("gat_negative_slope", 0.2),
                    residual=kwargs.get("gat_residual", False),
                    share_weights=kwargs.get("gat_share_weights", False),
                )
                for (source, edge, destination) in graph_edges
            },
            aggregate=kwargs.get("heteroconv_aggregate", "sum"),
        )
        self.conv2 = dglnn.HeteroGraphConv(
            {
                (source, edge, destination): dglnn.GATv2Conv(
                    in_feats=hidden_features,
                    out_feats=out_features,
                    activation=None,
                    num_heads=kwargs.get("gat_num_heads", 4),
                    feat_drop=kwargs.get("gat_feat_drop", 0),
                    attn_drop=kwargs.get("gat_attn_drop", 0),
                    negative_slope=kwargs.get("gat_negative_slope", 0.2),
                    residual=kwargs.get("gat_residual", False),
                    share_weights=kwargs.get("gat_share_weights", False),
                )
                for (source, edge, destination) in graph_edges
            },
            aggregate=kwargs.get("heteroconv_aggregate", "sum"),
        )

    def forward(self, blocks, inputs):
        # inputs: dict of node features keyed by node type
        h = self.conv1(blocks[0], inputs)
        # GATv2Conv outputs (num_nodes, num_heads, out_feats); average over the heads
        h = {node_type: out_gatconv.mean(dim=1) for node_type, out_gatconv in h.items()}
        h = self.conv2(blocks[1], h)
        h = {node_type: out_gatconv.mean(dim=1) for node_type, out_gatconv in h.items()}
        return h
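For reference, a minimal sketch of how a module like this gets constructed and called as model.sage (the edge list, hidden size, and "feat" key below are illustrative assumptions, not my exact setup):

# Illustrative construction; graph.canonical_etypes yields the
# (source, edge, destination) triples expected by graph_edges.
sage = GraphLayers(
    input_nodes_features={
        node_type: graph.nodes[node_type].data["feat"]
        for node_type in graph.ntypes
    },
    graph_edges=graph.canonical_etypes,
    hidden_features=64,
    out_features=num_out,  # same num_out as in the extraction function
    gat_num_heads=4,
)

# Inside the dataloader loop from get_embedding_from_graph:
input_features = generate_features_dict(blocks[0])
emb = sage(blocks, input_features)  # dict of (num_dst_nodes, num_out) tensors per node type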