Hi,
I’m trying to use DGL with the Metal Performance Shaders (MPS) framework to run operations on an AMD Radeon GPU. When I try to move the graph onto the GPU with `g.to("mps")`, I get the error below:
```
anaconda3/envs/pytorch/lib/python3.9/site-packages/dgl/utils/internal.py", line 585, in to_dgl_context
    device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)]
KeyError: 'mps'
```
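As far as I can tell from the traceback, `F.device_type(ctx)` returns `"mps"` and `DGLContext.STR2MASK` simply has no entry for that device type, so the dictionary lookup fails. Here is a stdlib-only sketch of that lookup. The table contents and the `to_dgl_context_sketch` name are my own hypothetical simplification, not DGL's actual code:

```python
# Hypothetical, simplified stand-in for the lookup in dgl/utils/internal.py.
# The real DGLContext.STR2MASK has more entries; only its shape matters here.
STR2MASK = {"cpu": 1, "cuda": 2}


def to_dgl_context_sketch(device_type: str) -> int:
    """Map a PyTorch device type string to a DGL device mask (sketch only)."""
    # torch.device("mps").type is "mps", which has no entry in the table,
    # so this plain dict lookup raises KeyError: 'mps'.
    return STR2MASK[device_type]


if __name__ == "__main__":
    print(to_dgl_context_sketch("cuda"))  # a device type the table knows
    try:
        to_dgl_context_sketch("mps")
    except KeyError as err:
        print("KeyError:", err)
```

If this reading is right, the problem isn't the graph itself but that DGL 2.0 has no device mapping for MPS at all.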
Without DGL, I can run operations on the GPU through PyTorch 2.2’s MPS backend, but moving a graph there fails with DGL 2.0. I’d appreciate any help getting DGL to run on an AMD GPU via the MPS device.
Thank you!