Question: What is the proper way to store embeddings outside the GPU and load onto the GPU only the ones being used, without repetition, during training?
Context: I am using a biological knowledge graph. When I am predicting something like a GENE-DRUG interaction, I sample an n-hop subgraph around the node pair and then classify the subgraph. My goal is to do graph batching. I am starting with 1-hop subgraphs, which usually have around 600 nodes and 30k edges. The whole graph has over 300,000 nodes.
Let's suppose that my batch has 2 subgraphs of 600 nodes each, and they have 100 nodes in common. Where do I store the embeddings so that backpropagation updates them properly and GPU memory is not wasted?
- In the node?
- In the dataloader?
- In the model?
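To make the scenario concrete, here is a minimal sketch of one possible arrangement, not a confirmed DGL recipe. It assumes the embedding table is a plain `torch.nn.Embedding` kept on the CPU with `sparse=True`, and that each subgraph exposes its original node IDs as a tensor (in DGL these would presumably come from something like `subgraph.ndata[dgl.NID]`). The sizes, the ID ranges, and the stand-in loss are all hypothetical. The node IDs of the batch are deduplicated, so only 600 + 600 - 100 = 1,100 rows are copied to the device.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 300,000 nodes in the full graph, 16-dim embeddings.
NUM_NODES, EMB_DIM = 300_000, 16

# The full table stays on the CPU. With sparse=True the gradient only
# covers the rows that were actually looked up.
node_emb = nn.Embedding(NUM_NODES, EMB_DIM, sparse=True)
opt = torch.optim.SparseAdam(node_emb.parameters(), lr=0.01)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two hypothetical subgraphs of 600 nodes each, sharing 100 node IDs
# (IDs 500..599 appear in both).
sub_a = torch.arange(0, 600)
sub_b = torch.arange(500, 1100)
batch_ids = torch.cat([sub_a, sub_b])  # 1,200 entries, 100 of them repeated

# Deduplicate: unique_ids are the rows to load, inverse maps every
# node in the batch back to its row in the deduplicated block.
unique_ids, inverse = torch.unique(batch_ids, return_inverse=True)

# Look up only the unique rows on the CPU, then copy them to the device.
unique_feats = node_emb(unique_ids).to(device)  # shape (1100, EMB_DIM)

# Expand on the device by indexing, so shared nodes point at the same row.
feats = unique_feats[inverse.to(device)]  # shape (1200, EMB_DIM)

# Stand-in for the real classifier loss.
loss = feats.sum()
opt.zero_grad()
loss.backward()  # gradients flow back through the copy to the CPU table
opt.step()
```

In this sketch a node shared by both subgraphs receives the summed gradient from both occurrences, and rows that were never looked up are left untouched by the optimizer step.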
Thank you so much for your time! I have been reading the documentation for a while and can't figure it out. If it is not a big deal, would you please write a quick example like the one below?
import dgl
import dgl.data
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader


class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
        self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, h):
        # Apply graph convolution and activation.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = dgl.mean_nodes(g, 'h')
            return self.classify(hg)


dataset = dgl.data.GINDataset('MUTAG', False)
dataloader = GraphDataLoader(
    dataset,
    batch_size=1024,
    drop_last=False,
    shuffle=True)

# Only an example, 7 is the input feature size
model = Classifier(7, 20, 5)
opt = torch.optim.Adam(model.parameters())

for epoch in range(20):
    for batched_graph, labels in dataloader:
        feats = batched_graph.ndata['attr']
        logits = model(batched_graph, feats)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()