TensorFlow @tf.function wrapper problem

I’m following the minigc tutorial and have ported the “Classifier” class to TensorFlow (attached below).

The training loop (not shown) runs successfully in eager mode. However, when I add the @tf.function decorator to the call() method, I get errors.

My guess is that because call() takes a non-tensor input, a DGL graph, it has to be executed in eager mode. I have looked at the DGL examples for TensorFlow and noticed that @tf.function is not used there either.

Is it possible to use DGL NN modules without eager mode? I ask because using @tf.function is recommended for performance.

import tensorflow as tf
from tensorflow.keras import layers

import dgl
from dgl.nn.tensorflow.conv import GraphConv

class Classifier(layers.Layer):

    def __init__(self, in_dim, hidden_size, num_classes):
        super(Classifier, self).__init__()

        self.conv1 = GraphConv(in_dim, hidden_size)
        self.conv2 = GraphConv(hidden_size, hidden_size)

        self.logits = layers.Dense(num_classes)

    @tf.function
    def call(self, data_in):
        '''
        In:
            data_in:    tuple containing the DGL graph and the features
        Inter:
            inputs:     shape (#V, 1)
        '''
        g = data_in[0][0]
        feat = data_in[0][1]
        inputs = tf.cast(feat, tf.float32)
        inputs = tf.expand_dims(inputs, 1)

        h = self.conv1(g, inputs)
        h = tf.keras.activations.relu(h)

        h = self.conv2(g, h)
        h = tf.keras.activations.relu(h)

        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')

        return self.logits(hg)

Hi,

This is not supported at the moment. DGL is currently designed for imperative execution, whereas @tf.function uses symbolic execution. Also, DGLGraph is a non-tensor object. We haven’t figured out a good solution yet. Do you have any ideas on this?

Hello,

Thank you. I have found a graph NN library written specifically for TensorFlow, called Spektral, which uses a tf.SparseTensor to represent the graph. It is therefore compatible with @tf.function.
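
For anyone who lands here later, below is a rough sketch of the sparse-tensor idea in plain TensorFlow. It uses neither DGL nor Spektral's own API. The names SparseGraphClassifier and normalized_adjacency are made up for illustration, and the row-normalized adjacency with self-loops is a simplifying assumption, so this is not a drop-in replacement for GraphConv. The point is only that a graph passed as a tf.SparseTensor can be traced by @tf.function like any other tensor argument.

```python
import tensorflow as tf


def normalized_adjacency(edges, num_nodes):
    """Hypothetical helper: row-normalized adjacency with self-loops,
    returned as a tf.SparseTensor (edges are treated as undirected)."""
    pairs = {(i, i) for i in range(num_nodes)}
    for u, v in edges:
        pairs.add((u, v))
        pairs.add((v, u))
    indices = sorted(pairs)  # row-major order, as sparse ops expect
    degree = [0] * num_nodes
    for u, _ in indices:
        degree[u] += 1
    values = [1.0 / degree[u] for u, _ in indices]
    return tf.SparseTensor(indices=indices, values=values,
                           dense_shape=[num_nodes, num_nodes])


class SparseGraphClassifier(tf.Module):
    """Two rounds of neighbour averaging, mean pooling, linear read-out."""

    def __init__(self, in_dim, hidden_size, num_classes):
        super().__init__()
        # Variables are created eagerly so tracing never has to create them.
        self.w1 = tf.Variable(tf.random.normal([in_dim, hidden_size], stddev=0.1))
        self.w2 = tf.Variable(tf.random.normal([hidden_size, hidden_size], stddev=0.1))
        self.w_out = tf.Variable(tf.random.normal([hidden_size, num_classes], stddev=0.1))
        self.b_out = tf.Variable(tf.zeros([num_classes]))

    @tf.function
    def __call__(self, adj, feat):
        # adj: [V, V] tf.SparseTensor, feat: [V] per-node scalar feature
        h = tf.expand_dims(tf.cast(feat, tf.float32), 1)  # [V, 1]
        h = tf.nn.relu(tf.sparse.sparse_dense_matmul(adj, h) @ self.w1)
        h = tf.nn.relu(tf.sparse.sparse_dense_matmul(adj, h) @ self.w2)
        # Graph representation: average of all node representations.
        hg = tf.reduce_mean(h, axis=0, keepdims=True)  # [1, hidden_size]
        return hg @ self.w_out + self.b_out


# Usage: a 3-node path graph 0 - 1 - 2 with one scalar feature per node.
adj = normalized_adjacency([(0, 1), (1, 2)], num_nodes=3)
model = SparseGraphClassifier(in_dim=1, hidden_size=8, num_classes=2)
logits = model(adj, tf.constant([1.0, 2.0, 3.0]))
```

This handles a single graph per call. Batching several graphs would need something extra, for example a block-diagonal adjacency plus a segment-wise mean, which is left out here.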