Incorporating multiple edge weights into message passing

Hi

I’m interested in using the GCN model with additional edge weights defined by a Gaussian kernel, as described here in Equation 1:

z^{(l)}_{v,p} = \sum_{u \in N_{v}} \sum_{q=1}^{M_{l}} \sum_{k=1}^{K_{l}} w_{p,q,k}^{(l)} \times y_{u,q}^{(l)} \times \phi(\mu_{v}, \mu_{u}; \hat{\mu}_{k}, \hat{\sigma}_{k}) + b_{p}^{(l)}

where

\phi(\mu_{v}, \mu_{u}; \hat{\mu}_{k}, \hat{\sigma}_{k}) = \exp\left(-\hat{\sigma}_{k} \, \lVert (\mu_{v} - \mu_{u}) - \hat{\mu}_{k} \rVert^{2}\right)

is the kernel weight between nodes v and u for kernel k, given the node feature vectors at layer l. Here w_{p,q,k} is the usual weight connecting the q-th input feature to the p-th output feature, now augmented with a kernel index k, so that W^{(l)} \in \mathbb{R}^{P \times M_{l} \times K_{l}}, where P is the number of output features.
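As a reference for the indices in Equation 1, here is a deliberately literal sketch. It uses NumPy instead of PyTorch only to stay dependency-light, and it assumes: the graph is given as a hypothetical `neighbors` dict mapping v to N_v; the kernel compares per-node vectors `mu` (following the text above, these can simply be the layer-l features, which is what the example passes); and W is stored as (P, M, K) in the w_{p,q,k} index order. The function names are illustrative, not part of DGL.

```python
import numpy as np


def kernel_weight(mu_v, mu_u, mu_hat_k, sigma_hat_k):
    """phi(mu_v, mu_u; mu_hat_k, sigma_hat_k) = exp(-sigma_hat_k * ||(mu_v - mu_u) - mu_hat_k||^2)."""
    d = (mu_v - mu_u) - mu_hat_k
    return np.exp(-sigma_hat_k * np.dot(d, d))


def gaussian_kernel_conv_naive(neighbors, y, mu, W, b, mu_hat, sigma_hat):
    """Loop-for-loop evaluation of Equation 1 (slow; meant only as a reference).

    neighbors: dict mapping node v to the list of its neighbors N_v (hypothetical format)
    y:         (|V|, M) input features y^{(l)}
    mu:        (|V|, D) per-node vectors compared by the kernel
    W:         (P, M, K) weights, indexed as w_{p,q,k}
    b:         (P,) bias b_p
    mu_hat:    (K, D) kernel means
    sigma_hat: (K,) kernel scales
    """
    num_nodes, M = y.shape
    P, _, K = W.shape
    z = np.zeros((num_nodes, P))
    for v in range(num_nodes):
        for p in range(P):
            total = b[p]  # the bias is added once per node and output feature
            for u in neighbors.get(v, []):
                for q in range(M):
                    for k in range(K):
                        phi = kernel_weight(mu[v], mu[u], mu_hat[k], sigma_hat[k])
                        total += W[p, q, k] * y[u, q] * phi
            z[v, p] = total
    return z


# Tiny example: 3 nodes, M = 2 input features, P = 4 outputs, K = 3 kernels.
rng = np.random.default_rng(0)
y = rng.normal(size=(3, 2))
W = rng.normal(size=(4, 2, 3))
b = rng.normal(size=4)
mu_hat = rng.normal(size=(3, 2))
sigma_hat = rng.uniform(size=3)
neighbors = {0: [1, 2], 1: [0]}  # node 2 has no neighbors, so z[2] is just the bias
z = gaussian_kernel_conv_naive(neighbors, y, y, W, b, mu_hat, sigma_hat)
print(z.shape)  # (3, 4)
```

The loops mirror the three sums one-to-one; a real layer would vectorize them, but this form makes it easy to check a vectorized version against.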

I found this post showing how to use a single edge weight. However, I’m interested in parameterizing the number of kernels for each layer, so that we are not restricted to one weight per edge but instead have K weights per edge:

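with graph.local_scope():
    graph.ndata['h'] = h
    graph.apply_edges(fn.u_sub_v('h', 'h', 'diff'))
    # (|E|, 1, P) - (K, P) broadcasts to (|E|, K, P)
    difference = graph.edata['diff'].unsqueeze(1) - mu
    # reduce over the feature axis (dim=-1, not dim=1, which indexes kernels): (|E|, K)
    return -1 * sigma * torch.sqrt(torch.sum(difference * difference, dim=-1))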

where mu is (K x P) and sigma is (1 x K), so that our edge weights are (|E| x K).
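To make these shapes concrete, here is a NumPy sketch of the same per-edge computation (NumPy stands in for torch, and `unsqueeze(1)` becomes `[:, None, :]`). The random arrays are placeholders for graph.edata['diff'], mu, and sigma. Like the snippet, it uses the unsquared L2 distance (dropping the square root would match Equation 1 exactly), and it applies the final exp so the result is the (|E| x K) edge-weight matrix.

```python
import numpy as np

num_edges, P, K = 5, 4, 3
rng = np.random.default_rng(0)
edge_diff = rng.normal(size=(num_edges, P))  # placeholder for graph.edata['diff']
mu = rng.normal(size=(K, P))                 # one mean vector per kernel
sigma = rng.uniform(size=(1, K))             # one scale per kernel

# (|E|, 1, P) - (K, P) broadcasts to (|E|, K, P)
difference = edge_diff[:, None, :] - mu
# Reduce over the feature axis (the last one); axis 1 indexes kernels
dist = np.sqrt(np.sum(difference * difference, axis=-1))  # (|E|, K)
# (1, K) broadcasts against (|E|, K): K weights per edge
edge_weights = np.exp(-sigma * dist)
print(edge_weights.shape)  # (5, 3)
```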

How can I go about doing this?

Thank you.

Krisitan

Edit:

I guess it should be possible to combine the GATConv and GraphConv layers: instead of computing attention weights with the heads of the GATConv layer, we could reformulate the GATConv layer to compute kernel weights. My question then is how to combine the kernel weights and the GraphConv weights in this process.

So only \hat{\sigma}_k and \hat{\mu}_k are kernel-specific, and we have K kernels corresponding to K sets of edge weights, correct? Also, note that torch.sum(difference*difference) over the feature dimension is the squared L2 norm (the quantity in the exponent of Equation 1), and its torch.sqrt is the L2 norm; an L1 norm would instead sum absolute values. For a general Lp norm, you can use torch.norm directly.
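For reference, a quick NumPy check of these norm identities, with np.linalg.norm playing the role of torch.norm (torch.norm(d, p=2, dim=-1) computes the same per-row L2 norm): summing squared entries gives the squared L2 norm, its square root gives the L2 norm, and the L1 norm sums absolute values.

```python
import numpy as np

d = np.array([[3.0, -4.0], [1.0, 2.0]])  # two example difference vectors

sq_l2 = np.sum(d * d, axis=-1)       # squared L2 norm: [25., 5.]
l2 = np.sqrt(sq_l2)                  # L2 norm: [5., sqrt(5)]
l1 = np.sum(np.abs(d), axis=-1)      # L1 norm needs the absolute value: [7., 3.]

# The built-in vector norms agree with the manual versions
assert np.allclose(l2, np.linalg.norm(d, ord=2, axis=-1))
assert np.allclose(l1, np.linalg.norm(d, ord=1, axis=-1))
```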

Hi

Yes, that’s correct: each layer will have K^{(l)} kernels, each with its own \mu_{k}^{(l)} vector and isotropic \sigma_{k}^{(l)}.

Ah, I was not aware of the torch.norm function, thanks.

k

I guess you can do something like the following:

import dgl
import dgl.function as fn
import torch
# Create a graph for demo
graph = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
h = torch.randn(graph.num_nodes(), 4)

# h is of shape (|V|, P), V is the node set and P is the feature size
# mu is of shape (K, P), K is the number of kernels
# sigma is of shape (1, K)
graph.ndata['h'] = h
graph.apply_edges(fn.v_sub_u('h', 'h', 'diff'))
K = 2
mu = torch.randn(K, h.shape[1])
sigma = torch.randn(1, K)
# difference with broadcasting
diff = graph.edata['diff'] - mu.unsqueeze(1) # (K, |E|, P)
diff = torch.norm(diff, p=2, dim=2)          # (K, |E|)
logit = - sigma.transpose(1, 0) * diff       # (K, |E|)
out = logit.exp()                            # (K, |E|)
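The code above stops at the kernel weights. One way to finish the layer, i.e. to combine them with GraphConv-style weights as asked in the edit, is sketched below in NumPy under these assumptions: each kernel k gets its own weight matrix W[k] of shape (M, P_out) (a hypothetical layout holding the same numbers as w_{p,q,k}); the message along edge u -> v is y_u W[k] scaled by that edge's kernel-k weight; and messages are summed over kernels and over incoming edges. `src`/`dst` play the role of graph.edges() (with fn.v_sub_u, v is the destination), and `kernel_w` has the (K, |E|) layout of `out`. The function name is illustrative, and GraphConv's degree normalization is omitted.

```python
import numpy as np


def kernel_weighted_aggregate(src, dst, num_nodes, y, kernel_w, W, b):
    """Equation 1 over an edge list, with K sets of edge weights.

    src, dst: (|E|,) source u and destination v of each edge
    y:        (|V|, M) node features
    kernel_w: (K, |E|) kernel weights, same layout as `out`
    W:        (K, M, P_out) one weight matrix per kernel (hypothetical layout)
    b:        (P_out,) bias
    """
    # Project each edge's source features with every kernel's matrix: (K, |E|, P_out)
    msgs = np.einsum('em,kmp->kep', y[src], W)
    # Scale by the per-kernel edge weight and sum over kernels: (|E|, P_out)
    msgs = np.einsum('ke,kep->ep', kernel_w, msgs)
    # Sum incoming messages at each destination node (unbuffered scatter-add)
    z = np.zeros((num_nodes, W.shape[2]))
    np.add.at(z, dst, msgs)
    return z + b


# Same toy graph as above: edges 0->1, 1->2, 2->0
rng = np.random.default_rng(1)
src, dst = np.array([0, 1, 2]), np.array([1, 2, 0])
K, M, P_out = 2, 4, 5
y = rng.normal(size=(3, M))
kernel_w = rng.uniform(size=(K, 3))
W = rng.normal(size=(K, M, P_out))
b = rng.normal(size=P_out)
z = kernel_weighted_aggregate(src, dst, 3, y, kernel_w, W, b)
print(z.shape)  # (3, 5)
```

This is equivalent to running K weighted aggregations, one per kernel, each with its own edge weights `out[k]` and weight matrix W[k], and adding the results. In DGL the per-kernel step could likely be expressed with `graph.update_all(fn.u_mul_e(...), fn.sum(...))` on per-kernel edge data; the einsum form just handles all K kernels at once.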

Hi

Thank you! I ended up doing something similar. I also want to learn the mu and sigma parameters, so that they are updated through backprop. Does the following seem correct?

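# kernel means, shape (1, K, out_feats); nn.Parameter already defaults to requires_grad=True
self.mu = nn.Parameter(
            th.Tensor(1, num_kernels, out_feats), requires_grad=True)
# kernel scales, shape (K, 1)
self.sigma = nn.Parameter(
            th.Tensor(num_kernels, 1), requires_grad=True)

# th.Tensor(...) allocates uninitialized memory, so initialize both explicitly
nn.init.uniform_(self.sigma, a=0.0, b=1.0)
nn.init.xavier_uniform_(self.mu)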

Yes, they look good to me.
