Here is the code where I select the values with -1; they end up in “idsToRemove”.
Furthermore, I attempted to use “isolated_nodes” from “How to remove isolated nodes in a DGL graph?”, but I still can’t remove “allow_zero_in_degree=True” from my model definition; the code only runs with it set. So it seems I have two problems.
def _process_single( self, designPath ):
    """Build the DGL graph for one design directory and prune unusable nodes.

    Reads ``preProcessedGatesToHeat.csv`` (node table, indexed by ``id`` and
    sorted by it) and ``DGLedges.csv`` (``Src``/``Dst`` columns) from
    ``designPath``, builds ``self.graph``, attaches the feature and label
    columns as node data, and then removes:

    1. every node whose ``rawFeatName`` value is not > 0 or whose
       ``labelName`` value is not >= 0 (NaN fails both comparisons and is
       therefore removed as well);
    2. every node left with zero in-degree or zero out-degree, repeated
       until no such node remains.

    ``listFeats``, ``labelName``, ``rawFeatName``, ``featName`` and
    ``secondLabel`` are names defined outside this method.

    Args:
        designPath: Directory containing the two CSV files. It is combined
            with file names via ``/`` — presumably a ``pathlib.Path``;
            confirm against the caller.

    Returns:
        The pruned graph, which is also stored in ``self.graph``.
    """
    nodes_data = pd.read_csv( designPath / 'preProcessedGatesToHeat.csv', index_col = 'id' )
    nodes_data = nodes_data.sort_index()
    edges_data = pd.read_csv( designPath / 'DGLedges.csv' )
    edges_src = torch.from_numpy( edges_data['Src'].to_numpy() )
    edges_dst = torch.from_numpy( edges_data['Dst'].to_numpy() )

    # A row is kept only when both conditions hold; everything else is removed.
    df = nodes_data[ listFeats + [ labelName ] ]
    keepMask = ( ( df[ rawFeatName ] > 0 ) & ( df[ labelName ] >= 0 ) ).to_numpy()
    removedNodesMask = torch.tensor( ~keepMask )
    # DGL node IDs are row positions (node data below is attached by position),
    # so select positions instead of the DataFrame's 'id' labels. Both agree
    # when the ids are exactly 0..N-1.
    idsToRemove = torch.nonzero( removedNodesMask ).squeeze( 1 )

    self.graph = dgl.graph( ( edges_src, edges_dst ), num_nodes = nodes_data.shape[0] )
    self.graph.ndata[ featName ] = torch.tensor( nodes_data[ listFeats ].values )
    self.graph.ndata[ labelName ] = torch.from_numpy( nodes_data[ labelName ].to_numpy() )
    self.graph.ndata[ secondLabel ] = torch.from_numpy( nodes_data[ secondLabel ].to_numpy() )

    print("\n---> BEFORE REMOVED NODES:")
    print( "\tself.graph.nodes()", self.graph.nodes().shape, "\n", self.graph.nodes() )
    print( "\tself.graph.ndata\n", self.graph.ndata )

    self.graph.remove_nodes( idsToRemove )

    # Removing nodes also removes their edges, which can leave further nodes
    # without incoming/outgoing edges. A single pass therefore does not
    # guarantee a graph free of zero-in-degree nodes; repeat until stable.
    # NOTE(review): on an acyclic graph this fixpoint removes every node. If
    # that is not acceptable, adding self-loops (dgl.add_self_loop) is the
    # alternative way to satisfy the zero-in-degree check — confirm which
    # one the model needs.
    while True:
        isolated_nodes = ( ( self.graph.in_degrees() == 0 ) | ( self.graph.out_degrees() == 0 ) ).nonzero().squeeze(1)
        print( "isolated_nodes:", isolated_nodes.shape, "\n", isolated_nodes )
        if isolated_nodes.numel() == 0:
            break
        self.graph.remove_nodes( isolated_nodes )

    print("\n---> AFTER REMOVED NODES:")
    print( "\tself.graph.nodes()", self.graph.nodes().shape, "\n", self.graph.nodes() )
    print( "\tself.graph.ndata\n", self.graph.ndata )
    return self.graph