How do I specify the device for a graph in Lightning Fabric?

I'm using lightning-fabric to train and evaluate my GNN model, and I use GraphDataLoader to load my dataset.
This is my module code at the moment:

class EncoderMinLSTMGNN(nn.Module):
    """Graph-attention encoder that runs two GATv2 layers over a batched graph.

    The constructor replicates ``graph`` ``seq_length`` times into one batched
    DGL graph.  The forward pass log-scales every feature channel except the
    first, zeroes NaNs, applies the two attention layers, adds a
    layer-normalised residual, then dropout and a final linear projection.

    The batched graph is not a parameter or buffer, so ``nn.Module.to`` does
    not move it by default.  ``_apply`` is overridden so that the graph
    follows the parameters whenever the module is moved (``model.to(device)``,
    ``model.cuda()``, or a wrapper that calls those).

    Args:
        input_dim: Size of the last (feature) dimension of the input.
        hidden_dim: Used as the leading dimension when reshaping the output
            of the second attention layer.
            NOTE(review): this looks like a node/batch count rather than a
            hidden width -- confirm against the caller.
        output_dim: Output size of the final linear layer.
        graph: A single DGL graph; replicated ``seq_length`` times.
        seq_length: Number of graph copies placed in the batched graph.
        dropout: Dropout probability applied before the final projection.
        num_heads: Number of attention heads in the first GATv2 layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, graph, seq_length, dropout=0.1, num_heads=4):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.seq_length = seq_length
        # NOTE(review): ln0 is created but never used in forward().
        self.ln0 = LayerNorm(input_dim)
        # One batched graph holding seq_length copies of the input graph.
        self.graph = dgl.batch([graph] * seq_length)
        self.gnn0 = gatv2conv.GATv2Conv(input_dim, input_dim, num_heads=num_heads)
        # The second layer consumes the flattened multi-head output of gnn0.
        self.gnn1 = gatv2conv.GATv2Conv(input_dim * num_heads, input_dim, num_heads=1)
        self.dropout = nn.Dropout(dropout)
        self.ln1 = LayerNorm(input_dim)
        self.fc = nn.Linear(input_dim, output_dim)

    def _apply(self, fn, *args, **kwargs):
        """Move the batched graph together with the module's parameters.

        ``nn.Module.to`` / ``cuda`` / ``cpu`` all funnel through ``_apply``.
        After the parameters have been converted, the graph is moved to the
        device the parameters now live on, so callers do not have to move it
        by hand.  Extra positional/keyword arguments are forwarded unchanged
        because the private ``_apply`` signature differs between PyTorch
        versions.
        """
        module = super()._apply(fn, *args, **kwargs)
        param = next(self.parameters(), None)
        # Skip the meta device: there is no real storage to move a graph to.
        if param is not None and param.device.type != "meta":
            self.graph = self.graph.to(param.device)
        return module

    def forward(self, x):
        """Encode ``x`` with the two attention layers.

        Args:
            x: Tensor indexed with three dimensions; the last one holds the
                feature channels.  Channel 0 is passed through unchanged and
                the remaining channels are transformed with ``log(1 + x)``.
                # assumes the last dimension equals input_dim -- TODO confirm

        Returns:
            The output of the final linear layer (last dimension
            ``output_dim``).
        """
        # Safety net for callers that never moved the module: keep the graph
        # on the same device as the input.  This is a no-op once the module
        # has been moved via .to(), because _apply already relocated the graph.
        if self.graph.device != x.device:
            self.graph = self.graph.to(x.device)

        # Keep channel 0 as is; log1p is the numerically accurate form of
        # log(x + 1) for small x.
        x = torch.cat([x[:, :, :1], torch.log1p(x[:, :, 1:])], dim=-1)
        # Replace NaNs (from the input or from the log of values below -1).
        x = torch.where(x.isnan(), 0, x)

        gnn_outputs0 = self.gnn0(self.graph, feat=x)
        # Flatten everything except the leading dimension before layer two.
        gnn_outputs1 = self.gnn1(self.graph, feat=gnn_outputs0.reshape(gnn_outputs0.shape[0], -1))
        gnn_outputs1 = gnn_outputs1.reshape(self.hidden_dim, self.seq_length, self.input_dim)

        # Residual connection around the normalised attention output.
        gnn_outputs = x + self.ln1(gnn_outputs1)
        gnn_outputs = self.dropout(gnn_outputs)
        gnn_outputs = self.fc(gnn_outputs)
        return gnn_outputs

To keep the tensors and the batched graph on the same GPU, I call graph.to(device) in the forward method, but I don't think this is a good approach.
Is there an easier way to have the lightning.Fabric object transfer the graph to the GPU automatically, so that I don't need to call to(device) in the forward method?