Confusion regarding DGLGraph.local_var

Documentation

In the API documentation for local_var, I found this statement in the Note:

However, inplace operations do change the shared tensor values, so will be reflected to the original graph.

Could you give an example of an in-place operation that changes the shared tensor values?
This is not clear to me, and I do not know how this function can be safely used when writing customized models.
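To make the distinction concrete, here is a minimal sketch using plain PyTorch tensors, with a shallow-copied dict standing in for the graph's ndata frame. This is an analogy for the tensor-sharing behaviour described in the Note, not DGL's actual implementation:

```python
import torch

# local_var() gives the graph a fresh frame, but the new frame initially
# points at the SAME tensors as the original graph (a shallow copy).
orig = {'h': torch.zeros(3)}   # stands in for g.ndata
local = dict(orig)             # stands in for g.local_var().ndata

# Out-of-place: rebinding the key makes the local frame point at a new
# tensor, so the original graph still sees its old tensor.
local['h'] = local['h'] + 1
assert torch.equal(orig['h'], torch.zeros(3))

# In-place: mutating the shared tensor itself IS visible through the
# original graph, because both frames reference the same storage.
local2 = dict(orig)
local2['h'].add_(1)            # in-place add on the shared tensor
assert torch.equal(orig['h'], torch.ones(3))
```

So writes of the form `g.ndata['h'] = ...` stay local, while in-place tensor operations such as `add_()` or `+=` on an existing feature tensor leak back to the original graph.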

Related Implementation Question:

In the code of GATConv, the forward function calls graph = graph.local_var().
Does this mean that this implementation of GATConv will not tune the graph node features during training?

For example, if I am using node features and want to tune them during training, do I need to implement another GATConv rather than use the version from dgl.nn.pytorch.GATConv?

Thank you very much for your answer.

The same question is also asked in a GitHub issue: github
By the way, is it preferred to ask in the forum or on GitHub?

This question is being answered on GitHub:

Hi Mufei, I also have some following question in the github thread, could you take a look, thanks!


Hi,

I checked the GitHub issue.

Regarding the question here, where you wish to fine-tune the embedding during training, you don’t have to implement another GATConv. Instead, you can simply do something like this (not runnable, just to give you an idea of how the code would look):

class GAT(nn.Module):
    def __init__(self,
                 g, ...):
        super(GAT, self).__init__()
        self.g = g
        ...
        # Registering the embedding as an nn.Parameter makes it trainable;
        # change the initialization to anything you like.
        self.node_embedding = nn.Parameter(torch.randn(g.number_of_nodes(), args.node_embed_size))

    def forward(self):
        h = self.node_embedding
        for l in range(self.num_layers):
            h = self.gat_layers[l](self.g, h).flatten(1)
        # output projection
        logits = self.gat_layers[-1](self.g, h).mean(1)
        return logits
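To see why this is enough for fine-tuning, here is a runnable sketch with plain PyTorch (no DGL; TinyModel and its linear projection are made up for illustration). Because the embedding is registered as an nn.Parameter, it appears in model.parameters(), receives gradients, and gets updated by the optimizer like any other weight:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, num_nodes, embed_size):
        super().__init__()
        # Trainable node embedding, analogous to self.node_embedding above.
        self.node_embedding = nn.Parameter(torch.randn(num_nodes, embed_size))
        self.proj = nn.Linear(embed_size, 1)

    def forward(self):
        # Use the learnable embedding as the input node features.
        return self.proj(self.node_embedding).sum()

model = TinyModel(num_nodes=4, embed_size=8)
before = model.node_embedding.detach().clone()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model()
loss.backward()
opt.step()

# The embedding tensor has changed after one step: it is being trained.
assert not torch.equal(before, model.node_embedding.detach())
```

The same pattern applies to the GAT class above: the optimizer you build from model.parameters() will update node_embedding along with the GATConv weights, even though GATConv itself calls local_var() internally.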