Hi, thanks for your reply.
Here is how I do it in my project:
class GNNModule(nn.Module):
    """One LGNN layer that jointly updates features on a graph and on its line graph.

    ``x`` lives on the nodes of ``g`` and ``y`` lives on the nodes of ``lg``
    (equivalently, on the edges of ``g``). Each update combines a node's own
    features, degree-scaled features, multi-scale neighbourhood sums, and
    features fused in from the other graph.

    Args:
        in_feats: Feature size of the incoming ``x`` and ``y``.
        out_feats: Feature size of the returned ``x`` and ``y``.
        radius: Number of aggregation scales; scale ``k`` propagates features
            ``2**k`` times (see ``aggregate``).
    """

    def __init__(self, in_feats, out_feats, radius):
        super().__init__()
        self.out_feats = out_feats
        self.radius = radius

        def new_linear():
            return nn.Linear(in_feats, out_feats)

        def new_linear_list():
            # One projection per aggregation scale produced by ``aggregate``.
            return nn.ModuleList([new_linear() for _ in range(radius)])

        # Registration order is kept identical to the original so that
        # ``state_dict`` keys (and saved checkpoints) stay compatible.
        self.theta_x, self.theta_deg, self.theta_y = (
            new_linear(), new_linear(), new_linear())
        self.theta_list = new_linear_list()
        self.gamma_y, self.gamma_deg, self.gamma_x = (
            new_linear(), new_linear(), new_linear())
        self.gamma_list = new_linear_list()
        self.bn_x = nn.BatchNorm1d(out_feats)
        self.bn_y = nn.BatchNorm1d(out_feats)

    def _half_relu(self, h):
        """Keep the first ``out_feats // 2`` columns linear; apply ReLU to the rest."""
        n = self.out_feats // 2
        return torch.cat([h[:, :n], F.relu(h[:, n:])], 1)

    def aggregate(self, g, z):
        """Return ``self.radius`` multi-scale neighbourhood sums of ``z`` over ``g``.

        One propagation step replaces every node's value with the sum of the
        values of its in-neighbours (``copy_src`` + ``sum``). Element ``k`` of
        the returned list is ``z`` after ``2**k`` such steps.

        The temporary node feature ``'z'`` is removed from ``g`` before
        returning, so intermediate tensors are not left attached to the
        caller's graph.
        """
        def propagate():
            g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))

        z_list = []
        g.set_n_repr({'z': z})
        propagate()
        z_list.append(g.get_n_repr()['z'])
        for i in range(self.radius - 1):
            # Doubling: after these 2**i extra steps the total is 2**(i+1).
            for _ in range(2 ** i):
                propagate()
            z_list.append(g.get_n_repr()['z'])
        g.pop_n_repr('z')
        return z_list

    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
        """Compute updated features for ``g`` (``x``) and ``lg`` (``y``).

        Args:
            g: Graph whose nodes carry ``x``.
            lg: Graph whose nodes carry ``y``; it must have one node per edge
                of ``g``, because ``y`` is also stored as edge data on ``g``.
            x: 2-D node features of ``g`` with ``in_feats`` columns.
            y: 2-D features with one row per edge of ``g`` and ``in_feats``
                columns.
            deg_g: Multiplied element-wise with ``x``; presumably an
                ``(num_nodes, 1)`` degree column (TODO confirm with caller).
            deg_lg: Same role as ``deg_g`` but for ``y`` / ``lg``.
            pm_pd: Index tensor with one entry per row of ``y``;
                ``F.embedding(pm_pd, x)`` gathers the row of ``x`` fused into
                each row of ``y``.

        Returns:
            Tuple ``(x, y)`` of updated features, each with ``out_feats``
            columns.
        """
        # Gather x into edge/line-graph space before x is overwritten.
        pmpd_x = F.embedding(pm_pd, x)
        sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x)))

        # Sum the incoming-edge features y at each node of g, then drop the
        # temporary edge feature so it does not stay attached to g.
        g.set_e_repr({'y': y})
        g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum(msg='m', out='pmpd_y'))
        pmpd_y = g.pop_n_repr('pmpd_y')
        g.pop_e_repr('y')

        x_new = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
        x_new = self.bn_x(self._half_relu(x_new))

        sum_y = sum(gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y)))
        y_new = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum_y + self.gamma_x(pmpd_x)
        y_new = self.bn_y(self._half_relu(y_new))
        return x_new, y_new
And here is how I define the LGNN model:
# Layer widths: 1 input column (presumably the in-degree feature), then
# n_layers hidden widths of nfeat, then a final width of 2 * nfeat.
feats = [1, *([nfeat] * n_layers), nfeat * 2]
# One GNNModule per consecutive (input width, output width) pair.
self.module_list = nn.ModuleList(
    GNNModule(width_in, width_out, radius)
    for width_in, width_out in zip(feats, feats[1:])
)
When I call the model (my input data is a batch of adjacency matrices of shape bs × N × N, stored in `obs[:, 0, :, :]`):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: graphs, line graphs and degrees are rebuilt on every call; if the same
# adjacency matrices recur, caching these structures would avoid that cost.
# Convert the batch to NumPy once: calling obs.cpu().numpy() inside the
# comprehension copied the whole batch (device -> host) once per graph.
adj_batch = obs.cpu().numpy()[:, 0, :, :]
# from_numpy_array replaces from_numpy_matrix, which NetworkX 3.0 removed.
nx_graphs = [nx.from_numpy_array(adj) for adj in adj_batch]
gs = [DGLGraph(nx_graph) for nx_graph in nx_graphs]
lgs = [g.line_graph(backtracking=False) for g in gs]

def in_degrees(graph):
    """Return the in-degree of every node of ``graph`` as a float column vector."""
    node_ids = Index(torch.arange(0, graph.number_of_nodes()))
    return graph.in_degrees(node_ids).unsqueeze(1).float()

g_degs = [in_degrees(g) for g in gs]
lg_degs = [in_degrees(lg) for lg in lgs]
# First tensor returned by g.edges(): the source node ID of every edge.
pm_pds = [g.edges()[0] for g in gs]
g_batch = batch(gs)
lg_batch = batch(lgs)
degg_batch = torch.cat(g_degs, dim=0).to(device)
deglg_batch = torch.cat(lg_degs, dim=0).to(device)

# Shift each graph's node IDs into the numbering of the batched graph, using
# the real node count of every preceding graph (identical to i * N when all
# graphs have N nodes, and still correct if sizes ever differ).
shifted_pm_pds, node_offset = [], 0
for g, src in zip(gs, pm_pds):
    shifted_pm_pds.append(src + node_offset)
    node_offset += g.number_of_nodes()
pm_pd_batch = torch.cat(shifted_pm_pds, dim=0).to(device)

# The initial features are the degree columns themselves.
x, y = degg_batch, deglg_batch
for module in self.module_list:
    x, y = module(g_batch, lg_batch, x, y, degg_batch, deglg_batch, pm_pd_batch)
x = torch.reshape(x, (obs.size(0), obs.size(2), -1))  # Bs * N * D
I think the calculation of the pm_pd matrix could be moved into the forward function of the GNNModule class. The forward function would then only need to receive the graph (g) as input, and the computation of lg, x, y, deg_g, deg_lg and pm_pd could be encapsulated inside it, which would make the interface much more user-friendly.
The drawback is that the LGNN model (30 layers) is very slow this way compared to a GAT model. I do not know whether this is because we need to construct lg and pm_pd every time, or simply because LGNN itself is slow. If speed is not a problem, this would be very nice.
Thanks for your reply again.
Yours sincerely,
Shanchao Yang