Mixed precision training on a heterogeneous graph

import dgl
import dgl.nn as dglnn
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

class RGCN(nn.Module):
    """Two-layer relational GCN built from one ``GraphConv`` per relation.

    Each layer is a ``dglnn.HeteroGraphConv`` holding a separate
    ``dglnn.GraphConv`` for every relation name; results coming from
    different relations are combined with ``aggregate='sum'``.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        """
        Args:
            in_feats: input feature size fed to the first layer.
            hid_feats: feature size produced by the first layer.
            out_feats: feature size produced by the second layer.
            rel_names: relation (edge type) names; each layer gets one
                ``GraphConv`` per name.
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it PyTorch raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        """Run both layers and return the per-node-type output features.

        Args:
            graph: heterograph whose relation names match ``rel_names``.
            inputs: dict mapping node type to input node features.

        Returns:
            Dict mapping node type to output node features.
        """
        h = self.conv1(graph, inputs)
        # Non-linearity between the two relational layers.
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

class HeteroDotProductPredictor(nn.Module):
    """Score the edges of a single edge type by an endpoint dot product."""

    def forward(self, graph, h, etype):
        """Return one score per ``etype`` edge, computed with ``fn.u_dot_v``.

        Args:
            graph: the graph whose ``etype`` edges are scored.
            h: node representations (per node type) produced by the GNN
                encoder, e.g. ``RGCN``.
            etype: the edge type to score.

        Returns:
            The ``'score'`` edge data written by ``fn.u_dot_v``.
        """
        # local_scope keeps the temporary 'h' / 'score' fields off the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = h
            dot_udf = fn.u_dot_v('h', 'h', 'score')
            graph.apply_edges(dot_udf, etype=etype)
            edge_scores = graph.edges[etype].data['score']
            return edge_scores
def construct_negative_graph(graph, k, etype):
    """Build a graph holding ``k`` corrupted edges per ``etype`` edge.

    Every source node of ``etype`` is repeated ``k`` times (grouped, via
    ``repeat_interleave``) and paired with destination ids drawn uniformly
    at random from the destination node type. Node counts of all node types
    are copied from ``graph``, so node representations computed on ``graph``
    can be scored on the returned graph directly.

    Args:
        graph: heterograph to sample negatives from.
        k: number of negative edges generated per positive edge.
        etype: canonical edge type ``(src_type, relation, dst_type)``.

    Returns:
        A DGL heterograph containing only ``etype`` edges.
    """
    _, _, vtype = etype
    src, _ = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    # The original called ``torch.randint`` but torch is imported as ``th``.
    # Matching the real edges' id dtype and device keeps the negative
    # (src, dst) pair consistent, including when ``graph`` is on the GPU.
    neg_dst = th.randint(0, graph.num_nodes(vtype), (len(src) * k,),
                         dtype=src.dtype, device=src.device)
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})
class Model(nn.Module):
    """Link-prediction model: an RGCN encoder plus a dot-product edge scorer."""

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        """
        Args:
            in_features: input node feature size.
            hidden_features: hidden size of the RGCN encoder.
            out_features: output embedding size of the RGCN encoder.
            rel_names: relation names passed on to ``RGCN``.
        """
        # Required before assigning submodules; otherwise nn.Module raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        # Attribute name kept as ``sage`` so existing state_dict keys still match.
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        """Score ``etype`` edges of the positive and the negative graph.

        Node representations are computed once on ``g`` and reused to score
        ``neg_g``, which is assumed to share ``g``'s node sets (as produced by
        ``construct_negative_graph``).

        Args:
            g: the positive (real) graph.
            neg_g: the graph of sampled negative edges.
            x: dict mapping node type to input node features.
            etype: edge type to score.

        Returns:
            Tuple ``(positive_scores, negative_scores)``.
        """
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

def compute_loss(pos_score, neg_score):
    """Hinge (margin) loss between positive and negative edge scores.

    For positive edge ``i`` and each of its negatives ``j`` the term is
    ``max(0, 1 - pos_i + neg_ij)``; the mean over all terms is returned.

    Args:
        pos_score: one score per positive edge (first dimension = #edges).
        neg_score: negative scores whose element count is a multiple of
            the number of positive edges, grouped per positive edge.

    Returns:
        Scalar tensor with the mean hinge loss.
    """
    num_pos = pos_score.shape[0]
    # One row per positive edge, holding that edge's negatives.
    neg_rows = neg_score.view(num_pos, -1)
    hinge = 1 - pos_score + neg_rows
    return hinge.clamp(min=0).mean()

# Toy link-prediction training loop with optional mixed precision (AMP).
use_fp16 = True
g = dgl.heterograph({
    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))})
g.nodes['drug'].data['hv'] = th.rand(3, 10)
g.nodes['disease'].data['hv'] = th.rand(3, 10)

k = 5  # negative edges sampled per positive edge
r = ("drug", "treats", "disease")  # canonical edge type being predicted
device = th.device("cuda", 0)

model = Model(10, 20, 5, g.etypes).to(device)
user_feats = g.nodes['drug'].data['hv']
item_feats = g.nodes['disease'].data['hv']
node_features = {'drug': user_feats, 'disease': item_feats}
# Loop-invariant transfers are done once rather than every epoch.
node_features = {ntype: feat.to(device) for ntype, feat in node_features.items()}
g_cpu = g
g = g.to(device)

optimizer = th.optim.Adam(model.parameters())
# With enabled=False both autocast and GradScaler become no-ops, so the
# same loop works for fp32 training.
scaler = GradScaler(enabled=use_fp16)

for epoch in range(10):
    # Negatives are sampled on the CPU copy, then moved to the device.
    negative_graph = construct_negative_graph(g_cpu, k, r).to(device)
    # NOTE(review): with DGL 0.9 the backward pass through fn.u_dot_v under
    # autocast raised "'NoneType' object has no attribute 'dtype'" (see the
    # traceback in this thread); a predictor that computes the dot product
    # with plain tensor ops avoids that code path.
    with autocast(enabled=use_fp16):
        pos_score, neg_score = model(g, negative_graph, node_features, r)
        loss = compute_loss(pos_score, neg_score)

    optimizer.zero_grad()
    if use_fp16:
        # Backprop w/ gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()

Most of the code snippets are copied from the DGL documentation. If I run this with `use_fp16 = True`, I get:

AttributeError                            Traceback (most recent call last)
Input In [12], in <cell line: 19>()
     31 optimizer.zero_grad()
     32 if use_fp16:
     33     # Backprop w/ gradient scaling
---> 34     scaler.scale(loss).backward()
     35     scaler.step(optimizer)
     36     scaler.update()

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/torch/_tensor.py:307, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    298 if has_torch_function_unary(self):
    299     return handle_torch_function(
    300         Tensor.backward,
    301         (self,),
    305         create_graph=create_graph,
    306         inputs=inputs)
--> 307 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/torch/autograd/__init__.py:154, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    151 if retain_graph is None:
    152     retain_graph = create_graph
--> 154 Variable._execution_engine.run_backward(
    155     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    156     allow_unreachable=True, accumulate_grad=True)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/torch/autograd/function.py:199, in BackwardCFunction.apply(self, *args)
    195     raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
    196                        "Function is not allowed. You should only implement one "
    197                        "of them.")
    198 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 199 return user_fn(self, *args)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py:111, in custom_bwd.<locals>.decorate_bwd(*args, **kwargs)
    108 @functools.wraps(bwd)
    109 def decorate_bwd(*args, **kwargs):
    110     with autocast(args[0]._fwd_used_autocast):
--> 111         return bwd(*args, **kwargs)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/backend/pytorch/sparse.py:341, in GSDDMM.backward(ctx, dZ)
    339             dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * Y)
    340         else:  # rhs_target = !lhs_target
--> 341             dX = gspmm(_gidx, 'mul', 'sum', Y, dZ)
    342 else:  # lhs_target == 'e'
    343     if op in ['add', 'copy_lhs']:

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/backend/pytorch/sparse.py:757, in gspmm(gidx, op, reduce_op, lhs_data, rhs_data)
    755     op = 'mul'
    756     rhs_data = 1. / rhs_data
--> 757 return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py:94, in custom_fwd.<locals>.decorate_fwd(*args, **kwargs)
     92         return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
     93 else:
---> 94     return fwd(*args, **kwargs)

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/backend/pytorch/sparse.py:126, in GSpMM.forward(ctx, gidx, op, reduce_op, X, Y)
    123 @staticmethod
    124 @custom_fwd(cast_inputs=th.float16)
    125 def forward(ctx, gidx, op, reduce_op, X, Y):
--> 126     out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
    127     reduce_last = _need_reduce_last_dim(X, Y)
    128     X_shape = X.shape if X is not None else None

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/sparse.py:194, in _gspmm(gidx, op, reduce_op, u, e)
    192 use_e = op != 'copy_lhs'
    193 if use_u and use_e:
--> 194     if F.dtype(u) != F.dtype(e):
    195         raise DGLError("The node features' data type {} doesn't match edge"
    196                        " features' data type {}, please convert them to the"
    197                        " same type.".format(F.dtype(u), F.dtype(e)))
    198 # deal with scalar features.

File ~/miniconda3/envs/gnn-p39/lib/python3.9/site-packages/dgl-0.9-py3.9-linux-x86_64.egg/dgl/backend/pytorch/tensor.py:76, in dtype(input)
     75 def dtype(input):
---> 76     return input.dtype

AttributeError: 'NoneType' object has no attribute 'dtype'

It seems that a `None` value is being passed where a tensor is expected. Any idea how to fix or work around this?

DGL version: 0.9
Pytorch version: 1.10.2

You have not used the autocast API in PyTorch, so the custom_fwd decorator has no effect here.

You encounter this issue because the edge data and node data have different data types. You can either check their data types or wrap the forward stage with autocast(enabled=True).

Could you elaborate? Are you saying that wrapping the forward call in autocast is not enough?

The issue seems to be in the compute_loss function, not in the heterogeneous graph. It runs fine if I change this to a normal edge classifier (something like HeteroMLPPredictor here).

It may be a bug related to the u_dot_v operator. What if you change your HeteroDotProductPredictor as follows?

class HeteroDotProductPredictor(nn.Module):
    """Dot-product edge scorer using plain tensor ops instead of ``fn.u_dot_v``.

    Meant as a drop-in replacement for the ``fn.u_dot_v``-based predictor:
    the scores are computed by indexing node representations with the edge
    endpoints, so DGL's SDDMM backward (which failed under autocast in the
    traceback above) is never used.
    """

    def forward(self, graph, h, etype):
        """Return one score per ``etype`` edge, shaped ``(num_edges, 1)``.

        Args:
            graph: the graph whose ``etype`` edges are scored.
            h: either a dict mapping node type to node representations
                (heterograph) or a single tensor (homogeneous graph).
            etype: the edge type to score.
        """
        src, dst = graph.edges(etype=etype)
        if isinstance(h, dict):
            # ``h[src]`` on a node-type dict fails; pick the endpoint types first.
            utype, _, vtype = graph.to_canonical_etype(etype)
            h_src, h_dst = h[utype], h[vtype]
        else:
            h_src = h_dst = h
        # keepdim=True reproduces u_dot_v's (E, 1) shape, so compute_loss
        # broadcasts it row-wise against the (E, k) negative scores instead
        # of along the wrong axis.
        return (h_src[src] * h_dst[dst]).sum(dim=-1, keepdim=True)