TreeLSTM where nodes are sequences?

Hi!

First of all, thanks for this library :slight_smile:

Second, I am trying to use the TreeLSTM implementation on a dataset where each node is a sequence. Is that possible? I guess that would require two LSTMs/RNNs, one for the tree structure and one for the sequences? (See the sketch after the code in a) below.)

I have implemented a mini example of what I am trying to do:

a) a small modification to the original Tree-LSTM example (no embeddings, no mask, no batching)

    import torch as th
    import torch.nn as nn
    import dgl

    class TreeLSTMCell(nn.Module):
        # N-ary cell: U_iou and U_f assume every internal node has exactly two children
        def __init__(self, x_size, h_size):
            super(TreeLSTMCell, self).__init__()
            self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
            self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
            self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
            self.U_f = nn.Linear(2 * h_size, 2 * h_size)

        def message_func(self, edges):
            return {'h': edges.src['h'], 'c': edges.src['c']}

        def reduce_func(self, nodes):
            # concatenate the hidden states of the two children
            h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
            # one forget gate per child
            f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
            c = th.sum(f * nodes.mailbox['c'], 1)
            return {'iou': self.U_iou(h_cat), 'c': c}

        def apply_node_func(self, nodes):
            iou = nodes.data['iou'] + self.b_iou
            i, o, u = th.chunk(iou, 3, 1)
            i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
            c = i * u + nodes.data['c']
            h = o * th.tanh(c)
            return {'h': h, 'c': c}

    class ChildSumTreeLSTMCell(nn.Module):
        # Child-Sum cell: children's hidden states are summed, so any number of children works
        def __init__(self, x_size, h_size):
            super(ChildSumTreeLSTMCell, self).__init__()
            self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
            self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
            self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
            self.U_f = nn.Linear(h_size, h_size)

        def message_func(self, edges):
            return {'h': edges.src['h'], 'c': edges.src['c']}

        def reduce_func(self, nodes):
            h_tild = th.sum(nodes.mailbox['h'], 1)  # sum the children's hidden states
            f = th.sigmoid(self.U_f(nodes.mailbox['h']))  # one forget gate per child
            c = th.sum(f * nodes.mailbox['c'], 1)
            return {'iou': self.U_iou(h_tild), 'c': c}

        def apply_node_func(self, nodes):
            iou = nodes.data['iou'] + self.b_iou
            i, o, u = th.chunk(iou, 3, 1)
            i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
            c = i * u + nodes.data['c']
            h = o * th.tanh(c)
            return {'h': h, 'c': c}

    class TreeLSTM(nn.Module):
        def __init__(self,
                     num_vocabs,
                     x_size,
                     h_size,
                     num_classes,
                     dropout,
                     cell_type='nary',
                     pretrained_emb=None):
            super(TreeLSTM, self).__init__()
            # num_vocabs and pretrained_emb are unused in this embedding-free variant
            self.x_size = x_size
            self.dropout = nn.Dropout(dropout)
            self.linear = nn.Linear(h_size, num_classes)
            cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
            self.cell = cell(x_size, h_size)

        def forward(self, batch, g, h, c):
            """Compute tree-lstm prediction given a batch.

            Parameters
            ----------
            batch : dgl.data.SSTBatch
                The data batch (unused in this variant).
            g : dgl.DGLGraph
                Tree for computation.
            h : Tensor
                Initial hidden state.
            c : Tensor
                Initial cell state.

            Returns
            -------
            logits : Tensor
                The prediction of each node.
            """
            embeds = g.ndata['x']  # raw node data; no embedding lookup in this variant
            g.ndata['iou'] = self.cell.W_iou(embeds)  # W_iou expects (num_nodes, x_size) input
            g.ndata['h'] = h
            g.ndata['c'] = c
            # propagate messages through the tree in topological order
            dgl.prop_nodes_topo(g, self.cell.message_func, self.cell.reduce_func,
                                apply_node_func=self.cell.apply_node_func)
            # compute logits
            h = self.dropout(g.ndata.pop('h'))
            logits = self.linear(h)
            return logits
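
For the "two LSTMs" idea, one option would be to first collapse each node's sequence into a single vector with an ordinary nn.LSTM, and only then feed it to W_iou. This is just a minimal sketch assuming fixed-length sequences; SequenceEncoder and its wiring are hypothetical, not part of the DGL example:

    class SequenceEncoder(nn.Module):
        # second LSTM: runs over each node's sequence and keeps the final hidden state
        def __init__(self, x_size, out_size):
            super(SequenceEncoder, self).__init__()
            self.lstm = nn.LSTM(x_size, out_size, batch_first=True)

        def forward(self, seqs):
            # seqs: (num_nodes, seq_len, x_size)
            _, (h_n, _) = self.lstm(seqs)
            return h_n.squeeze(0)  # (num_nodes, out_size)

TreeLSTM.forward would then compute embeds = self.seq_encoder(g.ndata['x']) instead of reading g.ndata['x'] directly, with W_iou sized to take out_size inputs.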

b) a driver script that builds a toy tree and runs the model

    import collections

    import dgl
    import numpy as np
    import torch
    from scipy.sparse import coo_matrix
    from torch.utils.data import DataLoader, Dataset

    from tree_lstm import TreeLSTM  # the classes from part a), saved as tree_lstm.py

    class MyDataset(Dataset):
        def __init__(self, graphs):
            self.graphs = graphs

        def __getitem__(self, index):
            return self.graphs[index]

        def __len__(self):  # number of graphs in the dataset
            return len(self.graphs)

    def load_graph_3():
        row = np.array([0, 1, 0])  # edge sources
        col = np.array([1, 3, 2])  # edge destinations
        data = np.array([4, 5, 9])  # edge weights (unused by the model)
        a = coo_matrix((data, (row, col)), shape=(4, 4))  # just the tree topology
        g = dgl.from_scipy(a)
        g.ndata['x'] = torch.zeros((4, 6, 5))  # the node data are sequences! (nodes x seq_len x features)
        g.ndata['y'] = torch.Tensor([1, 2, 3, 4])  # labels
        g.ndata['mask'] = torch.ones(4)
        return g

    def main():
        device = "cpu"
        dgl_graph = load_graph_3()
        SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

        GraphDataset = MyDataset([dgl_graph])

        def batcher(device):
            def batcher_dev(batch):
                batch_trees = dgl.batch(batch)
                return SSTBatch(graph=batch_trees,
                                mask=batch_trees.ndata['mask'].to(device),
                                wordid=batch_trees.ndata['x'].to(device),
                                label=batch_trees.ndata['y'])
            return batcher_dev

        train_loader = DataLoader(dataset=GraphDataset,
                                  batch_size=1,
                                  collate_fn=batcher(device),
                                  shuffle=False,
                                  num_workers=0)

        model = TreeLSTM(vocab_size,
                         x_size,
                         h_size,
                         num_classes,
                         dropout,
                         cell_type='childsum',
                         pretrained_emb=None).to(device)

        epochs = 2
        for epoch in range(epochs):
            model.train()
            for step, batch in enumerate(train_loader):  # TODO: make this work with a single tree
                g = batch.graph.to(device)  # all graphs in the batch
                n = g.number_of_nodes()  # number of nodes (leaves and internal nodes)
                h = torch.zeros((n, h_size)).to(device)  # initial hidden state
                c = torch.zeros((n, h_size)).to(device)  # initial cell state
                logits = model(batch, g, h, c)
                print(logits.shape)

    if __name__ == "__main__":
        vocab_size = 21
        x_size = 5  # must match the last (feature) dimension of ndata['x']
        h_size = 20
        dropout = 0.1
        num_classes = 21
        main()

So basically: is it possible to have one LSTM over the tree and another over the sequences inside the nodes?
Is there a GRU version?

Thanks in advance :slight_smile:


It should be possible, but you cannot assign the sequences to G.ndata the way the example does, because DGL does not support a sequence node data type. One workaround is to keep the node sequences in a separate list-of-lists and save only a node ID tensor to G.ndata. At each step, use the node IDs from edges.src to slice out the corresponding node sequences.
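
A minimal sketch of that workaround, assuming fixed-length sequences so they fit in one tensor (node_seqs, seq_lstm, and seq_h are illustrative names, not DGL API):

    import torch as th
    import torch.nn as nn
    import dgl

    N, L, x_size, h_size = 4, 6, 5, 20

    g = dgl.graph(([0, 1, 0], [1, 3, 2]), num_nodes=N)
    g.ndata['id'] = th.arange(N)          # store only node IDs in the graph

    node_seqs = th.zeros((N, L, x_size))  # the sequences live outside the graph
    seq_lstm = nn.LSTM(x_size, h_size, batch_first=True)

    def message_func(edges):
        # edges.src['id'] holds the IDs of the source nodes of this edge batch,
        # which we use to slice their sequences out of the external storage
        seqs = node_seqs[edges.src['id']]  # (num_edges, L, x_size)
        _, (h_n, _) = seq_lstm(seqs)       # encode each sequence with a second LSTM
        return {'seq_h': h_n.squeeze(0)}   # (num_edges, h_size)

A message_func along these lines could replace the ones in the cells above, with the reduce_func reading 'seq_h' from the mailbox.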

@minjie Thanks for your reply! Hmmm, ok I see. For now, I will try another model where the nodes do not have to be sequences, and then decide whether I want the TreeLSTM with sequences in the nodes, hehe.

I have identifiers for my nodes (not in the above example though), so that part should be fine. I guess I would have to dig deeper into message_func() and edges.src, because I have no clue, heheh. Thanks again!

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.