Hi all,
I have an urgent question about the DGLGraph.update_all function.
Context:
The main process runs DistDGL and is set up as follows:
dgl.distributed.initialize(args.ip_config)
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
model = DistSAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
Then I use Python's multiprocessing.Process to start a subprocess that runs the explain function:
Process(target=explain, args=(args, model,))
I construct a local DGLGraph object in the subprocess and want to run the forward pass of the model on the newly constructed graph. But the code hangs at the graph.update_all call:
graph.ndata['h'] = n_feats
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
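For clarity, this call is meant to replace each node's 'h' with the mean of its incoming neighbors' features. Below is a minimal pure-Python sketch of the result I expect, not DGL's implementation. The names mean_aggregate, edges, and feats are hypothetical, the features are scalars for brevity, and giving nodes without incoming edges a value of 0.0 is an assumption of this sketch.

```python
def mean_aggregate(num_nodes, edges, feats):
    """Mean of source-node features over each node's incoming edges.

    edges is a list of (src, dst) pairs; feats[i] is the feature of node i.
    Nodes with no incoming edges get 0.0 (an assumption of this sketch).
    """
    sums = [0.0] * num_nodes
    counts = [0] * num_nodes
    for src, dst in edges:
        sums[dst] += feats[src]  # message step: copy the source feature
        counts[dst] += 1         # track in-degree for the mean
    # reduce step: average the received messages per destination node
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


# Toy graph (hypothetical): edges 0->2, 1->2, 2->0
edges = [(0, 2), (1, 2), (2, 0)]
feats = [1.0, 3.0, 5.0]
result = mean_aggregate(3, edges, feats)
```

This toy version finishes immediately in any process, so the computation itself is simple. The hang seems specific to calling update_all inside the subprocess.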
However, explain runs successfully if I use threading.Thread or call it in the main process. Therefore, I think the root cause is that update_all may rely on some runtime objects that were initialized in the main process but cannot be accessed from the subprocess.
When I searched for the runtime object, I found that it is already deprecated in dgl 0.6.x, and I am using the dgl-cu101==0.6.1 package.
I also added import dgl inside the explain function, but it still does not work.
Could anyone give me some advice on this? How can I run my program using multiprocessing?
Thanks a lot for your help!