I am trying to use torch.nn.DataParallel (https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html) to speed up training of the TreeLSTM model from the DGL tutorial (https://docs.dgl.ai/en/latest/tutorials/models/2_small_graph/3_tree-lstm.html). The main changes I made are:
=======
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('We have ', torch.cuda.device_count(), 'GPUs!')
model = TreeLSTM(trainset.num_vocabs, x_size, h_size, trainset.num_classes, dropout)
model = torch.nn.DataParallel(model)
model.to(device)
=======
However, I keep getting the following error:
======
AttributeError: 'tuple' object has no attribute 'graph'
(raised at the line: g = batch.graph)
======
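My unconfirmed guess at the cause: older torch.nn.DataParallel releases scatter the forward() inputs recursively, and any tuple input, namedtuples included, is rebuilt with zip(), which produces plain tuples and drops the field names. The snippet is a simplified pure-Python imitation of that scatter step, not PyTorch's actual code. Batch, FakeTensor and simplified_scatter are hypothetical names, and the field layout only assumes a namedtuple batch similar to the tutorial's SSTBatch.

```python
from collections import namedtuple

# Hypothetical stand-in for the tutorial's batch type (field names assumed).
Batch = namedtuple("Batch", ["graph", "mask", "wordid", "label"])


class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor: just holds a list of values."""

    def __init__(self, values):
        self.values = list(values)


def simplified_scatter(obj, num_devices):
    """Simplified imitation of the recursive scatter in older DataParallel."""
    if isinstance(obj, FakeTensor):
        # "Tensors" are split along the first dimension, one chunk per device.
        size = (len(obj.values) + num_devices - 1) // num_devices
        return [FakeTensor(obj.values[i * size:(i + 1) * size])
                for i in range(num_devices)]
    if isinstance(obj, tuple) and len(obj) > 0:
        # Tuples (namedtuples too) are rebuilt with zip(), which yields plain
        # tuples, so the namedtuple type and its .graph field are lost.
        return list(zip(*(simplified_scatter(x, num_devices) for x in obj)))
    # Anything else (e.g. a graph object) is replicated, not split.
    return [obj for _ in range(num_devices)]


batch = Batch(graph="<graph object>",
              mask=FakeTensor([1, 0, 1, 1]),
              wordid=FakeTensor([5, 6, 7, 8]),
              label=FakeTensor([0, 1, 2, 3]))

per_device = simplified_scatter(batch, 2)[0]
print(type(per_device).__name__)  # prints: tuple
try:
    per_device.graph
except AttributeError as err:
    print(err)  # prints: 'tuple' object has no attribute 'graph'
```

If this model is right, even a PyTorch version that preserves namedtuples would still only split the tensor fields while replicating the whole graph object to every device, so the per-device graph and tensors would no longer match.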
Any suggestions or comments on this issue?