Confusion about the aggregate operation in implementation of LGNN

Hello, I have some confusion about the “aggregate operation” in the LGNN tutorial and its recommended implementation. In the original paper, the term that involves the aggregate operation is \sum^{J-1}_{j=0}(A^{2^{j}}x^{k})_{i}, and DGL implements it as follows:

def aggregate(self, g, z):
    z_list = []
    g.ndata['z'] = z
    g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
    z_list.append(g.ndata['z'])
    for i in range(self.radius - 1):
        for j in range(2 ** i):
            g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
        z_list.append(g.ndata['z'])
    return z_list

For radius=3, the implementation above seems to perform only 4 message-passing steps (1 outside the loop, 3 inside it). However, according to the original paper we need A, A^2, and A^4 for J = 3 (J is equivalent to radius), which I think requires 7 (1 + 2 + 4) message-passing steps. Should it be changed to the following?

def aggregate(self, g, z):
    z_list = []
    g.ndata['z'] = z
    for i in range(self.radius):
        for j in range(2 ** i):
            g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
        z_list.append(g.ndata['z'])
    return z_list

Is there something wrong with my understanding?
Thanks in advance!

I think for radius=3, the code snippet performs 1 message-passing step outside of the loop and 2^0 + 2^1 + 2^2 inside the loop. Note that to get A^4X we don’t need to start from scratch, since it is simply A^2(A^2X). That said, I’m not sure the original implementation is correct either.
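One way to check which powers of A each version stores is to replace the DGL call with a dense matrix product. The sketch below assumes an undirected graph, so that one copy-and-sum step is the same as z <- A @ z. The helper names (hop, aggregate_tutorial, aggregate_proposed, matched_powers) and the small path graph are hypothetical and used only for illustration.

```python
import numpy as np

def hop(A, z):
    # Assumed stand-in for g.update_all(copy_src, sum) on an undirected graph:
    # every node sums its neighbours' features, i.e. z <- A @ z.
    return A @ z

def aggregate_tutorial(A, z, radius):
    # Mirrors the tutorial loop: one hop up front, then 2**i further hops per level.
    z_list = []
    z = hop(A, z)
    z_list.append(z)
    for i in range(radius - 1):
        for _ in range(2 ** i):
            z = hop(A, z)
        z_list.append(z)
    return z_list

def aggregate_proposed(A, z, radius):
    # Mirrors the proposed loop: 2**i hops per level, with no hop before the loop.
    z_list = []
    for i in range(radius):
        for _ in range(2 ** i):
            z = hop(A, z)
        z_list.append(z)
    return z_list

def matched_powers(A, z, z_list, max_power=8):
    # For each stored entry, find the exponent p such that entry == A^p @ z.
    found = []
    for entry in z_list:
        match = None
        for p in range(1, max_power + 1):
            if np.allclose(entry, np.linalg.matrix_power(A, p) @ z):
                match = p
                break
        found.append(match)
    return found

# Path graph 0-1-2-3 with an arbitrary feature vector.
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]])
z = np.array([1, 2, 3, 4])

print(matched_powers(A, z, aggregate_tutorial(A, z, 3)))  # [1, 2, 4]
print(matched_powers(A, z, aggregate_proposed(A, z, 3)))  # [1, 3, 7]
```

Because z is carried over between levels, the hops accumulate: in this sketch the tutorial-style loop reaches A, A^2, A^4 with 4 hops, while the proposed loop lands on A, A^3, A^7.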

Thank you for the reply! For radius=3, I think the loop becomes:

for i in range(3 - 1):
    for j in range(2 ** i):
        g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
    z_list.append(g.ndata['z'])

If we add a counter to record the number of message-passing steps, like this:

cnt = 0
for i in range(3 - 1):
    for j in range(2 ** i):
        # g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
        cnt = cnt + 1
    # z_list.append(g.ndata['z'])
print(cnt)

Running this snippet prints 3, which means the loop actually performs 3 message-passing steps. Is there something wrong?

Yeah, in that case this should be for i in range(3).
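For reference, the same counting trick can be applied to both versions of aggregate discussed in this thread. This is only a sketch: count_hops_tutorial and count_hops_proposed are hypothetical helpers that count update_all calls without touching a graph.

```python
def count_hops_tutorial(radius):
    # One update_all call before the loop, then 2**i calls for i in range(radius - 1).
    cnt = 1
    for i in range(radius - 1):
        for _ in range(2 ** i):
            cnt += 1
    return cnt

def count_hops_proposed(radius):
    # No call before the loop, then 2**i calls for i in range(radius).
    cnt = 0
    for i in range(radius):
        for _ in range(2 ** i):
            cnt += 1
    return cnt

print(count_hops_tutorial(3))  # 4
print(count_hops_proposed(3))  # 7
```

Since g.ndata['z'] is not reset between levels, these totals are also the highest power of A that each version reaches.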
