Am I missing something about fn.copy_u?

Hi everyone,

I am trying to check that copy_u does what I want, but it doesn’t seem to. I’d like to sum the features (here, the masses) of the 5 nearest neighbours of each node. However, although the edges seem to be correct, when I sum up the features I don’t get the same answer as when I do it in a for loop (see my dirty_graph function). Should I be using a different function instead of copy_u?

import numpy as np
import torch
import dgl
import dgl.function as fn
from scipy.spatial import cKDTree
def generate_fake_data(n_particles):
    pos = np.random.random((n_particles, 3)).astype(np.float32)
    mass = np.random.normal(loc=13, scale=3, size=n_particles).astype(np.float32)
    return pos, mass
def build_test_graph_distance(positions,
        features,
        distance=1.,
        n_neighbors=5):
    tree = cKDTree(positions)
    # k = n_neighbors + 1 because every point is returned as its own nearest neighbour;
    # neighbours farther than `distance` come back with index len(positions)
    distances, edges = tree.query(positions, k=n_neighbors + 1,
                                  distance_upper_bound=distance)
    edges = [(idx, dest) for idx, destination in enumerate(edges) for dest in destination
             if ((idx != dest) & (dest < len(positions)))]
    src, dst = zip(*edges)
    # num_nodes makes sure every point becomes a node, even one left without edges
    G = dgl.graph((torch.tensor(src), torch.tensor(dst)), num_nodes=len(positions))
    # ndata expects a framework tensor (PyTorch here), not a NumPy array
    G.ndata["masses"] = torch.from_numpy(features)
    G.update_all(fn.copy_u('masses', 'm'), fn.sum('m', 'sum_mass'))
    labels = G.ndata['sum_mass']
    return G, labels
def dirty_graph(positions, features, n_neighbors=5):
    labels = []
    for i in range(len(positions)):
        difference = positions - positions[i]
        distance = np.linalg.norm(difference, axis=-1)
        # position 0 of the sort is the point itself (distance 0), so skip it
        idx = np.argsort(distance)[1:n_neighbors+1]
        labels.append(features[idx].sum())
    return labels
if __name__ == '__main__':
    pos, mass = generate_fake_data(10)
    graph, labels = build_test_graph_distance(pos, mass)
    dirty_labels = dirty_graph(pos, mass)
    np.testing.assert_almost_equal(labels.numpy(), dirty_labels, decimal=1)
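For comparison, the quantity dirty_graph is meant to produce can be computed without a graph or a Python loop. The sketch below is a vectorized NumPy version under that assumption; the name knn_mass_sum is hypothetical (not part of DGL, NumPy or SciPy), and because it builds the full n×n distance matrix it is only meant as a check on small inputs.

```python
import numpy as np

def knn_mass_sum(positions, masses, n_neighbors=5):
    """Hypothetical helper: sum of masses over each point's n_neighbors
    nearest neighbours, excluding the point itself."""
    # Pairwise Euclidean distances, shape (n, n). O(n^2) memory: small test sets only.
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # A point is not its own neighbour: push the diagonal to +inf so it sorts last.
    np.fill_diagonal(dist, np.inf)
    # Column indices of the n_neighbors smallest distances in each row.
    nbrs = np.argsort(dist, axis=1)[:, :n_neighbors]
    # masses[nbrs] has shape (n, n_neighbors); summing over axis 1 gives one value per point.
    return masses[nbrs].sum(axis=1)

if __name__ == "__main__":
    # Four points on a line at x = 0, 1, 3, 7.
    pos = np.array([[0., 0, 0], [1, 0, 0], [3, 0, 0], [7, 0, 0]])
    mass = np.array([1., 2., 3., 4.])
    print(knn_mass_sum(pos, mass, n_neighbors=2).tolist())  # [5.0, 4.0, 3.0, 5.0]
```

Unlike build_test_graph_distance, this ignores the distance bound, so it matches the loop exactly and the graph version only when all 5 neighbours lie within distance 1.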

Hi, I just tried your code and cannot reproduce the warning. Did you check the values? Is it possible that warnings are due to numerical issues?

Interesting. I get the following output

Mismatch: 80%
Max absolute difference: 61.832886
Max relative difference: 0.91820663
 x: array([ 70.4,  58. ,  52.2,  75. ,  29.1,  62.9, 129.2,  62.2, 125.5,
        40.4], dtype=float32)
 y: array([70.4, 73.4, 65.9, 75. , 71.6, 59.5, 67.3, 74.1, 69.9, 67.7],

Weird. Can you update the code snippet to include the code used to print the output above?

It’s the same code snippet. This line at the end of the main block,

np.testing.assert_almost_equal(labels.numpy(), dirty_labels, decimal=1)

is what prints the report when the two arrays are not equal. If they are equal it prints nothing, which I assume is what happens for you? Could you try with more samples? (That means changing the value of 10 passed to generate_fake_data.)

Ok. I know what’s going on here.

edges = [(idx, dest) for idx, destination in enumerate(edges) for dest in destination if ((idx!=dest)&(dest<len(positions)))]

should be changed to

edges = [(src, idx) for idx, sources in enumerate(edges) for src in sources if ((idx != src) & (src<len(positions)))]

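Since we want each node to sum the masses of its nearest neighbors, the nearest neighbors must be the source nodes of the messages rather than the destination nodes: fn.copy_u('masses', 'm') copies the source node’s feature onto each edge, and fn.sum('m', 'sum_mass') adds those messages up at each edge’s destination node.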
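To see why the direction matters, the message passing can be emulated in plain NumPy. The helper copy_u_sum below is a hypothetical stand-in, not a DGL function: it mimics update_all(fn.copy_u(...), fn.sum(...)) by putting feat[src] on every edge and adding the messages up at dst. The small neighbour table knn plays the role of the edges array returned by tree.query, with the self and out-of-range entries already removed.

```python
import numpy as np

def copy_u_sum(src, dst, feat, num_nodes):
    """Hypothetical emulation of copy_u + sum: each edge carries feat[src],
    and the messages are summed at dst."""
    out = np.zeros(num_nodes, dtype=feat.dtype)
    # np.add.at accumulates correctly even when the same dst index appears many times.
    np.add.at(out, dst, feat[src])
    return out

mass = np.array([10., 1., 2., 4.])
# Row i lists the two nearest neighbours of node i.
knn = np.array([[1, 2], [0, 2], [1, 3], [2, 1]])
node = np.repeat(np.arange(len(knn)), knn.shape[1])  # [0, 0, 1, 1, 2, 2, 3, 3]
nbr = knn.ravel()                                    # [1, 2, 0, 2, 1, 3, 2, 1]

# Original edge list (idx, dest): node -> neighbour, so each node sums the masses
# of the nodes that list *it* as a neighbour.
wrong = copy_u_sum(node, nbr, mass, len(knn))
# Corrected edge list (src, idx): neighbour -> node, so each node sums its own neighbours.
right = copy_u_sum(nbr, node, mass, len(knn))

print(wrong.tolist())  # [1.0, 16.0, 15.0, 2.0]
print(right.tolist())  # [3.0, 12.0, 5.0, 3.0]
```

With the corrected ordering every node receives exactly one message per neighbour. With the original ordering the number of messages a node receives depends on how many points list it as a neighbour, which fits the very large and very small graph sums in the output above.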


That’s great! I completely misunderstood the meaning of source and destination. Thank you :smiley: