How do I put a DGL graph on the right device in Lightning Fabric?

I'm using Lightning Fabric to train and evaluate my GNN model, and I use DGL's `GraphDataLoader` to load my dataset.
Here is my current module code:

class EncoderMinLSTMGNN(nn.Module):
    """Two-layer GATv2 encoder run over a graph replicated ``seq_length`` times.

    The template ``graph`` is batched ``seq_length`` times at construction and
    stored on ``self.graph``. A DGL graph is neither a parameter nor a buffer,
    so ``nn.Module.to()`` / ``.cuda()`` (what Lightning Fabric calls when it
    sets up a module) would not move it by themselves. ``_apply`` is therefore
    overridden so the graph follows the module's parameters to their device,
    and ``forward`` no longer has to move it on every call.

    Args:
        input_dim: Per-node feature size fed to the GNN stack (and its output size).
        hidden_dim: Used as the leading size when reshaping the second GNN
            layer's output in ``forward`` (see the review note there).
        output_dim: Output size of the final linear projection.
        graph: DGL graph that is batched ``seq_length`` times.
        seq_length: Number of copies of ``graph`` in the batched graph.
        dropout: Dropout probability applied after the residual connection.
        num_heads: Number of attention heads in the first GATv2 layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, graph, seq_length, dropout=0.1, num_heads=4):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.seq_length = seq_length
        self.ln0 = LayerNorm(input_dim)  # NOTE(review): not used in forward()
        self.graph = dgl.batch([graph] * seq_length)
        self.gnn0 = gatv2conv.GATv2Conv(input_dim, input_dim, num_heads=num_heads)
        # gnn0's per-head outputs are flattened before gnn1, hence input_dim * num_heads.
        self.gnn1 = gatv2conv.GATv2Conv(input_dim * num_heads, input_dim, num_heads=1)
        self.dropout = nn.Dropout(dropout)
        self.ln1 = LayerNorm(input_dim)
        self.fc = nn.Linear(input_dim, output_dim)

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to parameters/buffers, then move ``self.graph`` alongside them.

        ``to()``, ``cuda()`` and ``cpu()`` all route through ``_apply`` (also when
        invoked on a parent module), so this keeps the batched graph on the same
        device as the weights without any per-step ``.to()`` in ``forward``.
        ``*args``/``**kwargs`` forward extra arguments (e.g. ``recurse``) that some
        PyTorch versions pass.
        """
        module = super()._apply(fn, *args, **kwargs)
        param = next(self.parameters(), None)
        # Skip meta-device parameters (e.g. before to_empty()); there is nothing to move to.
        if param is not None and param.device.type != "meta" and self.graph.device != param.device:
            self.graph = self.graph.to(param.device)
        return module

    def forward(self, x):
        """Encode node features with the batched graph.

        Args:
            x: Node features, indexed as a 3-D tensor (``x[:, :, k]``); feature 0
                is kept as-is and the remaining features are log1p-compressed.
                Presumably the first dimension equals ``self.graph.num_nodes()``
                — TODO confirm.

        Returns:
            ``self.fc`` applied to the residual ``x + ln1(gnn output)`` after
            dropout; the last dimension is ``output_dim``.
        """
        # Keep feature 0 raw; log-compress the rest. log1p(v) == log(v + 1) but is
        # more accurate for small v.
        x = torch.concat([x[:, :, 0].unsqueeze(-1), torch.log1p(x[:, :, 1:])], dim=-1)
        # Fallback only: after Fabric/.to() the graph is already on x's device via _apply,
        # so this is a no-op on the normal path.
        if self.graph.device != x.device:
            self.graph = self.graph.to(x.device)
        # Replace NaNs (including those from log1p of values < -1) with 0.
        x = torch.where(x.isnan(), 0, x)
        gnn_outputs0 = self.gnn0(self.graph, feat=x)
        # Flatten per-head outputs so gnn1 sees input_dim * num_heads features per node.
        gnn_outputs1 = self.gnn1(self.graph, feat=gnn_outputs0.reshape(gnn_outputs0.shape[0], -1))
        # NOTE(review): hidden_dim is used as a size here; this only works when the output
        # has hidden_dim * seq_length * input_dim elements — confirm hidden_dim is meant to
        # be the per-graph node count.
        gnn_outputs1 = gnn_outputs1.reshape(self.hidden_dim, self.seq_length, self.input_dim)
        gnn_outputs = x + self.ln1(gnn_outputs1)
        gnn_outputs = self.dropout(gnn_outputs)
        gnn_outputs = self.fc(gnn_outputs)
        return gnn_outputs

To keep the input tensors and the batched graph on the same GPU, I call `graph.to(device)` in the `forward` method, but I don't think this is a good approach.
Is there a simpler way to have the `lightning.Fabric` object move the graph to the GPU automatically, so that I don't need to call `.to(device)` in the `forward` method?

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.