Stuck in zero_() on a shared_mem tensor when using OMP_NUM_THREADS and copy_()

Here is the code that reproduces the bug:

import os
import torch
import dgl
from multiprocessing import Process
from dgl.utils.shared_mem import create_shared_mem_array, get_shared_mem_array

# Limit the OpenMP/MKL thread pools; the hang only occurs with more than one thread.
omp_num_threads = '2'
os.environ['OMP_NUM_THREADS'] = omp_num_threads
os.environ['MKL_NUM_THREADS'] = omp_num_threads

def start_daemon(omp_num_threads):
    os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
    os.environ['MKL_NUM_THREADS'] = str(omp_num_threads)

    # Re-import inside the child process.
    import torch
    import dgl
    from dgl.utils.shared_mem import get_shared_mem_array

    # Attach to the shared-memory tensor created by the parent.
    node_memory = get_shared_mem_array('node_memory', torch.Size([9999, 100]), dtype=torch.float32)

    print("before zero")
    node_memory.zero_()  # the child process hangs here
    print("after zero")

# Create a large shared-memory tensor in the parent and fill it; this copy_()
# is what triggers the hang in the child.
_a = torch.rand((157475, 172))
a = create_shared_mem_array('a', _a.shape, dtype=_a.dtype)
a.copy_(_a)

# Shared-memory tensor that the daemon process will zero out.
node_memory = create_shared_mem_array('node_memory', torch.Size([9999, 100]), dtype=torch.float32)

# Note the trailing comma: Process args must be a tuple.
mailbox_daemon = Process(target=start_daemon, args=(omp_num_threads,))
mailbox_daemon.start()

print("Done!")

As you can see, the code gets stuck at the node_memory.zero_() line. I found several ways to make the code run successfully:

  1. Set omp_num_threads to '1'.
  2. Remove the line a.copy_(_a).
  3. Move a.copy_(_a) to after mailbox_daemon.start().
  4. Use a smaller shape for _a.

Each of the methods listed above makes the code run successfully. I'm confused about where the problem is.

It may be caused by the default 'fork' start method. Please try ctx = torch.multiprocessing.get_context('spawn'), wrap the top-level code in a function, and run it under if __name__ == '__main__':.
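For reference, here is a minimal sketch of that suggestion applied to the script above. The restructuring into a main() function is my own, and it assumes the same DGL shared-memory helpers used in the original code:

import os
import torch
from multiprocessing import Process
from dgl.utils.shared_mem import create_shared_mem_array, get_shared_mem_array

def start_daemon(omp_num_threads):
    os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
    os.environ['MKL_NUM_THREADS'] = str(omp_num_threads)
    node_memory = get_shared_mem_array('node_memory', torch.Size([9999, 100]), dtype=torch.float32)
    node_memory.zero_()  # no hang: the child is a fresh interpreter, not a fork

def main():
    omp_num_threads = '2'
    os.environ['OMP_NUM_THREADS'] = omp_num_threads
    os.environ['MKL_NUM_THREADS'] = omp_num_threads

    _a = torch.rand((157475, 172))
    a = create_shared_mem_array('a', _a.shape, dtype=_a.dtype)
    a.copy_(_a)

    node_memory = create_shared_mem_array('node_memory', torch.Size([9999, 100]), dtype=torch.float32)

    # 'spawn' starts a fresh Python process instead of fork()ing the parent,
    # so the child does not inherit the parent's OpenMP thread-pool state.
    ctx = torch.multiprocessing.get_context('spawn')
    mailbox_daemon = ctx.Process(target=start_daemon, args=(omp_num_threads,))
    mailbox_daemon.start()
    mailbox_daemon.join()
    print("Done!")

if __name__ == '__main__':
    main()

The __main__ guard is required with 'spawn', because the child re-imports the module and would otherwise re-run the top-level code.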


Hi, thank you, that truly solves the problem. The bug is that with multiple OMP threads, zero_() cannot return in the child under the 'fork' start method (the forked child inherits the parent's already-initialized OpenMP thread pool, which can deadlock).
Thank you for your reply :)