I have the following definition of a heterogeneous graph:
import torch
import dgl
subsub_src = [10, 8, 5, 7, 6, 0, 1, 2, 3]
subsub_dst = [11, 9, 9, 9, 9, 1, 4, 3, 4]
msrmsr_src = [1, 2, 3]
msrmsr_dst = [0, 0, 0]
submsr_src = [4, 9, 11]
submsr_dst = [1, 2, 3]
data_dict = {
    ('subsubwatershed', 'flows', 'subsubwatershed'): (subsub_src, subsub_dst),
    ('measurement', 'flows', 'measurement'): (msrmsr_src, msrmsr_dst),
    ('subsubwatershed', 'in', 'measurement'): (submsr_src, submsr_dst)
}
g = dgl.heterograph(data_dict)
print(g)
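When `dgl.heterograph` receives only edge lists, DGL infers the number of nodes of each type as the largest node ID seen for that type plus one, and the IDs of that type run from 0 to that count minus one. The sketch below recomputes those counts in plain Python from the lists above, which shows how many feature rows each node type needs. It mirrors the inference rule rather than calling DGL, and the helper name `infer_num_nodes` is hypothetical, so compare its result with `g.num_nodes('measurement')` on your installation.

```python
# Edge lists copied from the graph definition above.
subsub_src = [10, 8, 5, 7, 6, 0, 1, 2, 3]
subsub_dst = [11, 9, 9, 9, 9, 1, 4, 3, 4]
msrmsr_src = [1, 2, 3]
msrmsr_dst = [0, 0, 0]
submsr_src = [4, 9, 11]
submsr_dst = [1, 2, 3]

data_dict = {
    ("subsubwatershed", "flows", "subsubwatershed"): (subsub_src, subsub_dst),
    ("measurement", "flows", "measurement"): (msrmsr_src, msrmsr_dst),
    ("subsubwatershed", "in", "measurement"): (submsr_src, submsr_dst),
}


def infer_num_nodes(data_dict):
    """Largest node ID seen per node type, plus one (the rule DGL applies
    when num_nodes_dict is not passed to dgl.heterograph)."""
    max_id = {}
    for (src_type, _, dst_type), (src, dst) in data_dict.items():
        for ntype, ids in ((src_type, src), (dst_type, dst)):
            if ids:
                max_id[ntype] = max(max_id.get(ntype, -1), max(ids))
    return {ntype: m + 1 for ntype, m in max_id.items()}


print(infer_num_nodes(data_dict))
# {'subsubwatershed': 12, 'measurement': 4}
```

So the `measurement` features need exactly 4 rows and the `subsubwatershed` features exactly 12, in node-ID order. If some node IDs never appear in any edge, passing `num_nodes_dict` to `dgl.heterograph` makes the counts explicit.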
Each node type and edge type has its own set of features that I also want to add to the graph. What would be the easiest way to load them in? For example, for the measurement nodes I want to add the following information:
      Centroid_X  Centroid_Y  altitude
node
0            123         321        11
1            234         432        22
2            345         543        33
3            456         654        44
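A table like this usually comes from a file on disk. A minimal sketch, assuming a CSV file with a `node` column that holds the DGL node ID of each measurement (the contents are inlined with `io.StringIO` so the example runs without a file, and the rows are deliberately out of order). Reading that column as the index keeps each feature row attached to its node ID rather than to its line position in the file.

```python
import io

import pandas as pd

# Stand-in for a file on disk; replace io.StringIO(csv_text) with a path such
# as "measurement_nodes.csv" (hypothetical file name).
csv_text = """node,Centroid_X,Centroid_Y,altitude
2,345,543,33
0,123,321,11
3,456,654,44
1,234,432,22
"""

# Use the node ID column as the index so rows are identified by node, not position.
measurement_df = pd.read_csv(io.StringIO(csv_text), index_col="node")
print(measurement_df.sort_index())
```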
The DGL heterogeneous graph user guide states that you can simply add features with, for example, g.nodes['measurement'].data['hv'] = torch.ones(4, 1)
This data is of course randomly generated, but I don’t see an easy way of constructing such tensors from data loaded from disk, as in my use case.
It seems as if I will almost have to hardcode every feature in the proper order for my graph. How do I do this easily? Is there some sort of built-in function for it?
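One way to avoid hardcoding the order is to let a single lookup table decide it: give every real-world object a key (for example a station name), assign each key a DGL node ID once, and use that same mapping both to translate the edge lists and to position the feature rows. The sketch below assumes hypothetical station names `"S0"` to `"S3"` and uses only pandas; it is one possible approach, not a DGL built-in.

```python
import pandas as pd

# Hypothetical measurement stations, identified by name in the raw data.
stations = ["S0", "S1", "S2", "S3"]
# One mapping, built once, from station name to DGL node ID (0..N-1).
station_to_id = {name: i for i, name in enumerate(stations)}

# Raw edges expressed with names; translate them to integer IDs for dgl.heterograph.
raw_edges = [("S1", "S0"), ("S2", "S0"), ("S3", "S0")]
msrmsr_src = [station_to_id[s] for s, _ in raw_edges]
msrmsr_dst = [station_to_id[d] for _, d in raw_edges]

# Raw features keyed by name, in arbitrary order.
raw_features = pd.DataFrame(
    {"altitude": [33, 11, 44, 22]}, index=["S2", "S0", "S3", "S1"]
)
# Reorder the rows with the same key order so row i describes node ID i.
aligned = raw_features.reindex(stations)
altitude = aligned["altitude"].to_numpy()

print(msrmsr_src, msrmsr_dst)  # [1, 2, 3] [0, 0, 0]
print(altitude)                 # [11 22 33 44]
```

Because edges and features go through the same `station_to_id` order, the two cannot drift apart, whatever order the files on disk use.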
EDIT:
I wrote the following function to do it.
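import copy
import pandas as pd

def add_node_features(graph: dgl.DGLGraph, node_type: str, features: pd.DataFrame):
    # Work on a copy so the original graph is left untouched.
    new_graph = copy.deepcopy(graph)
    # Iterating over a DataFrame yields its column names; row i of each
    # column is assigned to node ID i.
    for feature in features:
        new_graph.nodes[node_type].data[feature] = torch.from_numpy(features[feature].to_numpy())
    return new_graph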
static_measurement_nodes_df = pd.DataFrame(
    {
        "Centroid_X": [123, 234, 345, 456],
        "Centroid_Y": [321, 432, 543, 654],
        "altitude": [11, 22, 33, 44]
    }
)
g = add_node_features(g, "measurement", static_measurement_nodes_df)
print(g.nodes["measurement"])
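The function above stores each column as a separate node feature. GNN layers usually consume one 2-D tensor per node type with shape `(num_nodes, num_features)`, so an alternative is to stack the columns into a single float matrix first. A minimal sketch with pandas and NumPy; the feature name `feat` in the final comment is only a common convention, not something DGL requires.

```python
import numpy as np
import pandas as pd

measurement_df = pd.DataFrame(
    {
        "Centroid_X": [123, 234, 345, 456],
        "Centroid_Y": [321, 432, 543, 654],
        "altitude": [11, 22, 33, 44],
    },
    index=pd.Index([0, 1, 2, 3], name="node"),
)

feature_columns = ["Centroid_X", "Centroid_Y", "altitude"]
# One row per node (in the DataFrame's row order), one column per feature,
# as float32, the dtype most torch layers expect.
feat = measurement_df[feature_columns].to_numpy(dtype=np.float32)
print(feat.shape)  # (4, 3)

# With DGL and PyTorch installed, this would be attached as:
#   g.nodes["measurement"].data["feat"] = torch.from_numpy(feat)
```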
But this assumes that the rows of the DataFrame are in the same order as the node IDs used by the .nodes attribute. What if the two orders are not the same?
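DGL never sees the DataFrame's order or index: row `i` of the tensor assigned to `g.nodes[ntype].data[...]` simply becomes the feature of node ID `i`. Alignment is therefore the caller's job, and it can be made explicit by indexing the DataFrame by node ID and reindexing it to `0` through `num_nodes - 1` before converting. A minimal sketch, assuming the DataFrame's index holds DGL node IDs; the helper name `align_to_node_ids` is hypothetical.

```python
import pandas as pd


def align_to_node_ids(features: pd.DataFrame, num_nodes: int) -> pd.DataFrame:
    """Return `features` reordered so that row i describes node ID i.

    Assumes the index of `features` holds DGL node IDs. Raises instead of
    silently misaligning when IDs are duplicated, missing, or out of range.
    """
    if features.index.has_duplicates:
        raise ValueError("duplicate node IDs in features")
    present = set(features.index)
    expected = set(range(num_nodes))
    if present - expected:
        raise ValueError(f"node IDs not in the graph: {sorted(present - expected)}")
    if expected - present:
        raise ValueError(f"no features for node IDs: {sorted(expected - present)}")
    return features.reindex(range(num_nodes))


# Rows deliberately out of order, indexed by node ID.
shuffled = pd.DataFrame(
    {"Centroid_X": [345, 123, 456, 234], "altitude": [33, 11, 44, 22]},
    index=pd.Index([2, 0, 3, 1], name="node"),
)
aligned = align_to_node_ids(shuffled, num_nodes=4)
print(aligned["altitude"].tolist())  # [11, 22, 33, 44]
```

With a real graph, `add_node_features(g, "measurement", align_to_node_ids(df, g.num_nodes("measurement")))` would then assign the rows in node-ID order regardless of how the source file was sorted.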
And how can I check that the correct nodes actually have the correct associated data?
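To check the result, compare each row that ended up in the graph with the row the source table holds for the same node ID. In DGL, `g.nodes('measurement')` returns the node IDs `0` to `N - 1` in order, so a stored tensor can be converted back with `.numpy()` and compared against `df.loc[node_id]`. The sketch below runs that comparison on plain NumPy arrays so it needs no graph; the function name `check_alignment` is hypothetical, and with DGL you would pass `{name: g.nodes['measurement'].data[name].numpy()}` as `stored`.

```python
import numpy as np
import pandas as pd


def check_alignment(stored: dict, source: pd.DataFrame, num_nodes: int) -> list:
    """Return the node IDs whose stored features differ from `source`.

    `stored` maps feature name -> 1-D array where position i is node ID i
    (as read back from the graph); `source` is indexed by node ID.
    """
    mismatched = []
    for node_id in range(num_nodes):
        for name, values in stored.items():
            if values[node_id] != source.loc[node_id, name]:
                mismatched.append(node_id)
                break
    return mismatched


source = pd.DataFrame(
    {"altitude": [11, 22, 33, 44]}, index=pd.Index([0, 1, 2, 3], name="node")
)
good = {"altitude": np.array([11, 22, 33, 44])}
bad = {"altitude": np.array([33, 11, 44, 22])}  # rows assigned in file order
print(check_alignment(good, source, 4))  # []
print(check_alignment(bad, source, 4))   # [0, 1, 2, 3]
```

For floating-point features, `np.isclose` is a safer comparison than `!=`.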