A general question about masking in graph neural networks

Hey everyone,

I’m a relative newcomer to machine learning, but I frequently deal with graph-structured data in my research.

If I’ve understood correctly, graph neural networks such as Graph Convolutional Networks can be trained to predict node labels using semi-supervised learning by masking the labels of some nodes.

Now I understand the general principle of masking from computer vision, but I’m having a hard time wrapping my head around its operation in graph neural networks.

If I would like to train a GCN to classify nodes, should I set the final number of neurons to correspond to the number of possible labels, followed by a softmax function?

For making predictions on node labels, should I pass the graph to the network with masked labels and then compare the predictions to the ground truth?

If I would like to train a GCN to classify nodes, should I set the final number of neurons to correspond to the number of possible labels, followed by a softmax function?

Yes.
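As a minimal sketch of what that output layer produces, here is the softmax step in plain Python. The names `num_classes` and `final_layer_output` and the numbers are made up for illustration; in a real framework the softmax is often folded into the cross-entropy loss rather than applied as a separate layer.

```python
import math


def softmax(logits):
    """Turn one node's raw class scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


num_classes = 3  # hypothetical number of possible labels

# Hypothetical final-layer output: one row per node, one score per class.
final_layer_output = [
    [2.0, 0.5, -1.0],
    [0.1, 0.3, 0.2],
]

# Per-node class probabilities and the most likely class for each node.
probs = [softmax(row) for row in final_layer_output]
predicted = [max(range(num_classes), key=row.__getitem__) for row in probs]
```

Each row of `probs` has one entry per class, and the predicted label is simply the index of the largest entry.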

For making predictions on node labels, should I pass the graph to the network with masked labels and then compare the predictions to the ground truth?

Basically yes. To avoid potential confusion, let me clarify a bit more.

The general scenario is that we have a few large graphs in which only a subset of nodes have labels, and we want to infer the labels of the remaining nodes. For example, a new paper appears on arXiv and we want to assign a tag to it. In a citation graph, the nodes stand for papers and the edges stand for citations. If we know that the papers it cites have tags like “machine learning” and “parallel computing”, then this new paper should probably have a tag like “machine learning systems”.

Graph neural networks exploit the labeled nodes, the node features, and the graph topology through an operation called message passing (or graph convolution). Basically, they update each node’s features based on the features of its neighbors. In semi-supervised node classification, we first perform message passing over the whole graph, and then compute the objective function on only a subset of nodes. That is:

  1. We assume all nodes have features, including the unlabeled ones.
  2. The masked nodes are not directly involved in the loss computation, and we assume no access to their labels during training. However, their features are still implicitly involved through message passing, which matches the scenario we are interested in.
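The two points above can be sketched in plain Python. Everything here is a toy assumption: the four-node path graph, the feature values, the `train_mask`, and the mean aggregation with no learned weights (a real GCN layer also multiplies by a trainable weight matrix and uses normalized adjacency). The point is only to show where the mask enters.

```python
import math

# Hypothetical toy graph: four nodes on a path 0-1-2-3, stored as an adjacency list.
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

# Every node has features, including the unlabeled ones.
features = [[1.0, 0.0], [0.8, 0.2], [0.2, 0.8], [0.0, 1.0]]

# Labels are known only for nodes 0 and 3; None marks an unknown label.
labels = [0, None, None, 1]
train_mask = [True, False, False, True]


def propagate(feats):
    """One round of message passing: average each node with its neighbors."""
    dim = len(feats[0])
    out = []
    for v in range(len(feats)):
        group = [v] + neighbors[v]
        out.append([sum(feats[u][d] for u in group) / len(group) for d in range(dim)])
    return out


def log_softmax(scores):
    """Numerically stable log-softmax over one node's class scores."""
    shift = max(scores)
    log_total = shift + math.log(sum(math.exp(s - shift) for s in scores))
    return [s - log_total for s in scores]


def masked_cross_entropy(scores, labels, mask):
    """Average cross-entropy over the nodes selected by the mask only."""
    losses = [-log_softmax(scores[v])[labels[v]]
              for v in range(len(scores)) if mask[v]]
    return sum(losses) / len(losses)


# Message passing runs over the whole graph, masked nodes included.
scores = propagate(features)
# The loss only looks at nodes 0 and 3, so the None labels are never read.
loss = masked_cross_entropy(scores, labels, train_mask)
# Predictions for the unlabeled nodes come from the same forward pass.
predictions = {v: max(range(2), key=scores[v].__getitem__)
               for v in range(len(scores)) if not train_mask[v]}
```

Changing the features of node 1 changes `loss`, even though node 1 is masked, because node 0 aggregates from it. Changing the label of node 1 would have no effect, since the loss never reads it.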