I ran sample code for an RGCN on a heterogeneous graph, following the tutorial here. However, when inspecting the gradient flow, I found that some parameters (with requires_grad=True) receive no gradient in the backward pass: their .grad is None. Could you explain why this happens?
The code to reproduce the issue:
import numpy as np
import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
# Sizes of the synthetic graph, its node features and its labels.
n_users = 1000
n_items = 500
n_follows = 3000
n_clicks = 5000
n_dislikes = 500
n_hetero_features = 10
n_user_classes = 5
n_max_clicks = 10

# Random edge endpoints. Sources are always user IDs; 'follow' edges end at
# users, while 'click' and 'dislike' edges end at items.
follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)
dislike_src = np.random.randint(0, n_users, n_dislikes)
dislike_dst = np.random.randint(0, n_items, n_dislikes)

# Every relation is stored together with its reverse relation.
hetero_graph = dgl.heterograph(
    {
        ('user', 'follow', 'user'): (follow_src, follow_dst),
        ('user', 'followed-by', 'user'): (follow_dst, follow_src),
        ('user', 'click', 'item'): (click_src, click_dst),
        ('item', 'clicked-by', 'user'): (click_dst, click_src),
        ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
        ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src),
    },
    # Pin the node counts. Without this DGL infers them from the largest
    # sampled ID, so an unlucky draw that never hits the last user/item ID
    # would produce a smaller graph and make the assignments below fail
    # with a shape mismatch.
    num_nodes_dict={'user': n_users, 'item': n_items},
)

# Random node features for both node types.
hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
# Integer class label per user; float label per 'click' edge
# (torch.randint excludes the upper bound, so values are 1..n_max_clicks-1).
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()
# randomly generate training masks on user nodes and click edges
# (each entry is True with probability 0.6)
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)
def check_grad(named_parameters):
    """Report which trainable, non-bias parameters received a gradient.

    For every ``(name, parameter)`` pair whose parameter has
    ``requires_grad`` set and whose name does not contain ``"bias"``, prints
    ``"- <name>"`` when ``parameter.grad`` is ``None`` and ``"+ <name>"``
    otherwise.

    Args:
        named_parameters: Iterable of ``(name, parameter)`` pairs, e.g. the
            result of ``module.named_parameters()``.

    Returns:
        A dict mapping the name of each reported parameter that has a
        gradient to a ``(mean_abs_grad, max_abs_grad)`` tuple of floats.
        Parameters without a gradient are printed but not included.
    """
    grad_stats = {}
    for n, p in named_parameters:
        # Frozen parameters and biases are not reported at all.
        if not p.requires_grad or "bias" in n:
            continue
        if p.grad is None:
            print("-", n)
            continue
        print("+", n)
        abs_grad = p.grad.abs()
        grad_stats[n] = (abs_grad.mean().item(), abs_grad.max().item())
    return grad_stats
class RGCN(nn.Module):
    """Two-layer relational graph convolution network.

    Each layer is a ``dglnn.HeteroGraphConv`` holding one ``dglnn.GraphConv``
    per relation name and combining the per-relation results with
    ``aggregate='sum'``. A ReLU is applied between the two layers.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        """Build both layers.

        Args:
            in_feats: Input feature size of the first layer.
            hid_feats: Output size of the first layer / input of the second.
            out_feats: Output feature size of the second layer.
            rel_names: Relation names; one GraphConv is created per name in
                each layer.
        """
        super().__init__()

        # All first-layer modules are created before any second-layer one.
        first_layer_mods = {}
        for rel in rel_names:
            first_layer_mods[rel] = dglnn.GraphConv(in_feats, hid_feats)
        self.conv1 = dglnn.HeteroGraphConv(first_layer_mods, aggregate='sum')

        second_layer_mods = {}
        for rel in rel_names:
            second_layer_mods[rel] = dglnn.GraphConv(hid_feats, out_feats)
        self.conv2 = dglnn.HeteroGraphConv(second_layer_mods, aggregate='sum')

    def forward(self, graph, inputs):
        """Run both layers and return the second layer's output dict.

        Args:
            graph: The graph handed to both convolution layers.
            inputs: Node features; the caller in this script passes a dict
                keyed by node type.
        """
        hidden = self.conv1(graph, inputs)
        activated = {}
        for key, feat in hidden.items():
            activated[key] = F.relu(feat)
        return self.conv2(graph, activated)
model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']
node_features = {'user': user_feats, 'item': item_feats}

# One forward pass outside the loop; yields per-node-type outputs.
h_dict = model(hetero_graph, node_features)
h_user = h_dict['user']
h_item = h_dict['item']

opt = torch.optim.Adam(model.parameters())
for epoch in range(5):
    model.train()
    # forward propagation by using all nodes and extracting the user embeddings
    logits = model(hetero_graph, node_features)['user']
    # compute loss
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # Compute validation accuracy. Omitted in this example.
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
    # NOTE: the loss reads only the 'user' logits. In conv2 the 'click' and
    # 'dislike' relations have destination type 'item', so their results go
    # into the 'item' output, which the loss never touches. Their weights
    # are therefore not part of the autograd graph of `loss` and keep
    # grad = None. The conv1 weights of those relations do get gradients,
    # because the hidden 'item' features feed conv2's 'clicked-by' and
    # 'disliked-by' relations, which write to 'user'.
    check_grad(model.named_parameters())
    # Stop after the first epoch: one backward pass is enough to inspect
    # the gradients.
    break
which produces the following output (indicating that conv2.mods.click.weight
and conv2.mods.dislike.weight
receive no gradient):
1.8087488412857056
+ conv1.mods.clicked-by.weight
+ conv1.mods.disliked-by.weight
+ conv1.mods.click.weight
+ conv1.mods.dislike.weight
+ conv1.mods.follow.weight
+ conv1.mods.followed-by.weight
+ conv2.mods.clicked-by.weight
+ conv2.mods.disliked-by.weight
- conv2.mods.click.weight
- conv2.mods.dislike.weight
+ conv2.mods.follow.weight
+ conv2.mods.followed-by.weight