DGL with MPS device

Hi,

I’m trying to use DGL with the Metal Performance Shaders (MPS) framework to run operations on an AMD Radeon GPU. When I try to move the graph onto the GPU with g.to("mps"), I get the error below:

anaconda3/envs/pytorch/lib/python3.9/site-packages/dgl/utils/internal.py", line 585, in to_dgl_context
device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)]
KeyError: 'mps'

I’m able to run operations on the GPU with PyTorch 2.2 alone, without DGL. However, it doesn’t work with DGL 2.0. I’d appreciate any help running DGL on an AMD GPU through the MPS device.

Thank you!

Sorry, we don’t have plans to support the MPS backend for now due to limited resources. I have created a feature request on GitHub to track it, and we’ll discuss it there.
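
Until then, one possible workaround is to fall back to CPU for anything DGL touches, so the graph code at least runs. The sketch below is a minimal illustration under stated assumptions. The helper name `pick_dgl_device` and the set `DGL_SUPPORTED_DEVICE_TYPES` are hypothetical and are not part of DGL's API. Treating only "cpu" and "cuda" as usable is also an assumption about a typical DGL 2.0 build.

```python
# Hypothetical helper, not part of DGL: pick a device string that DGL 2.0
# can map to one of its contexts, falling back to CPU otherwise.
# Assumption: only "cpu" and "cuda" are usable device types in this DGL build.
DGL_SUPPORTED_DEVICE_TYPES = {"cpu", "cuda"}


def pick_dgl_device(preferred):
    """Return `preferred` if DGL can use it, else "cpu".

    `preferred` may be a string such as "mps" or "cuda:0", or a torch.device
    (str() of a torch.device gives the same "type[:index]" form).
    """
    name = str(preferred)
    device_type = name.split(":")[0]  # "cuda:0" -> "cuda"
    if device_type in DGL_SUPPORTED_DEVICE_TYPES:
        return name
    # "mps" is not a device type DGL can map, so keep the graph on CPU instead
    return "cpu"


if __name__ == "__main__":
    for dev in ("mps", "cuda:0", "cpu"):
        print(dev, "->", pick_dgl_device(dev))
```

With DGL installed, you would call `g = g.to(pick_dgl_device("mps"))` and move any tensors you pass to DGL functions to that same device. Graph operations then run on the CPU instead of raising the KeyError above.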

Okay, that makes sense! Thank you so much for the clarification. I appreciate you considering it as an additional feature.