Greetings!
I am trying to code a GCN model for multi-class node classification. In my setup, each node has a label over multiple classes, encoded as a one-hot vector:
node1_features -> [1, 0, 0]
node2_features -> [0, 1, 0]
I am having trouble with the data loader and the collate function. My collate function currently looks like this:
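    import dgl
    import torch as th

    def collate(samples):
        # samples is a list of (graph, node_labels) pairs from the dataset
        graphs, node_labels = map(list, zip(*samples))
        # Merge the graphs into one batched graph; nodes keep graph order
        batch_graph = dgl.batch(graphs)
        # Flatten the per-graph label lists into one list of node labels
        # (the original created multi_labels but appended to undefined labels)
        labels = []
        for i in range(len(node_labels)):
            for j in range(len(node_labels[i])):
                labels.append(node_labels[i][j])
        return batch_graph, th.tensor(labels)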
In the code above, “node_labels” holds one list of labels per graph, and each element node_labels[i][j] is a single node's one-hot vector (e.g. [0, 1, 0]).
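One possible way to build the label tensor, sketched under the assumption that each graph's labels arrive as a list of one-hot lists (or a 2-D tensor of shape [num_nodes, num_classes]): concatenate them graph by graph, which matches the node order that dgl.batch produces, and convert each one-hot row to a class index. The helper name collate_labels is hypothetical; inside the collate function it would replace the nested loop, e.g. return dgl.batch(graphs), collate_labels(node_labels).

```python
import torch as th


def collate_labels(node_labels):
    """Hypothetical helper: flatten per-graph one-hot labels into class indices.

    node_labels: list with one entry per graph; each entry is a list of
    one-hot lists or a 2-D tensor of shape [num_nodes, num_classes].
    Returns an int64 tensor of shape [total_num_nodes], ordered graph by graph.
    """
    # Convert each graph's labels to a float tensor [num_nodes, num_classes]
    per_graph = [th.as_tensor(lbl, dtype=th.float32) for lbl in node_labels]
    # Concatenate along the node dimension, in the same order as the graphs
    one_hot = th.cat(per_graph, dim=0)
    # One-hot row -> index of its 1, the default target format of CrossEntropyLoss
    return one_hot.argmax(dim=1)


labels = collate_labels([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1]]])
print(labels)  # tensor([0, 1, 2])
```

Using th.cat on whole tensors avoids appending node by node in Python, and returning indices rather than one-hot rows keeps the target in the shape nn.CrossEntropyLoss expects by default.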
I am also unsure how the loss function should be set up for this. Any help would be appreciated!
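A minimal sketch of the loss setup, assuming the GCN's last layer outputs raw scores (logits, no softmax) of shape [num_nodes, num_classes]; random logits stand in for the model output here. If every node has exactly one class, nn.CrossEntropyLoss on class indices is the usual choice. If a node could have several 1s in its vector (multi-label), nn.BCEWithLogitsLoss on the float vectors is the usual alternative.

```python
import torch as th
import torch.nn as nn

num_nodes, num_classes = 4, 3

# Stand-in for the GCN output: one row of raw scores (logits) per node
th.manual_seed(0)
logits = th.randn(num_nodes, num_classes)

# One-hot labels as in the question: exactly one class per node
one_hot = th.tensor([[1, 0, 0],
                     [0, 1, 0],
                     [0, 0, 1],
                     [0, 1, 0]], dtype=th.float32)

# Single-label multi-class: CrossEntropyLoss with class-index targets
ce = nn.CrossEntropyLoss()
loss_idx = ce(logits, one_hot.argmax(dim=1))

# PyTorch 1.10+ also accepts class probabilities (float, same shape as logits);
# for one-hot rows this gives the same value as the index version
loss_prob = ce(logits, one_hot)

# Multi-label alternative: independent sigmoid per class on float targets
bce = nn.BCEWithLogitsLoss()
loss_multi = bce(logits, one_hot)

print(loss_idx.item(), loss_prob.item(), loss_multi.item())
```

In training, loss_idx (or loss_multi) would be computed on the batched graph's output and the labels returned by the collate function, then followed by loss.backward() and an optimizer step.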