Hi!

First of all, thanks for this library.

Second, I am trying to use the TreeLSTM implementation on a dataset where each node is a sequence. Is that possible? I guess that would require two LSTMs/RNNs: one for the tree structure and one for the sequences inside the nodes (I sketch this idea right after the code in a) below).
I have implemented a mini example of what I am trying to do.

a) a small modification of the original Tree-LSTM example (no embeddings, no mask, no batching):
```python
import torch as th
import torch.nn as nn
import dgl


class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u + nodes.data['c']
        h = o * th.tanh(c)
        return {'h': h, 'c': c}


class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(h_size, h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        h_tild = th.sum(nodes.mailbox['h'], 1)
        f = th.sigmoid(self.U_f(nodes.mailbox['h']))
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_tild), 'c': c}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u + nodes.data['c']
        h = o * th.tanh(c)
        return {'h': h, 'c': c}


class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 cell_type='nary',
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        # num_vocabs and pretrained_emb are unused here since I removed the embedding layer
        self.x_size = x_size
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
        self.cell = cell(x_size, h_size)

    def forward(self, batch, g, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        g : dgl.DGLGraph
            Tree for computation.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        embeds = g.ndata['x']  # just the raw node data, no embedding lookup
        g.ndata['iou'] = self.cell.W_iou(embeds)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate from leaves to root in topological order
        dgl.prop_nodes_topo(g,
                            self.cell.message_func,
                            self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)
        # compute logits from the final hidden states
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits
```
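What I imagine for the sequence part is roughly the sketch below: run a plain `nn.LSTM` over each node's sequence first and use its final hidden state as that node's feature vector, which the Tree-LSTM then consumes. This is only my guess at the wiring; the `SequenceEncoder` name is mine, not anything provided by DGL:

```python
import torch
import torch.nn as nn


class SequenceEncoder(nn.Module):
    """Hypothetical helper: collapse each node's sequence into one vector."""
    def __init__(self, x_size, h_size):
        super(SequenceEncoder, self).__init__()
        self.lstm = nn.LSTM(x_size, h_size, batch_first=True)

    def forward(self, seqs):
        # seqs: (num_nodes, seq_len, x_size) -> (num_nodes, h_size)
        _, (h_n, _) = self.lstm(seqs)
        return h_n.squeeze(0)  # drop the num_layers dimension


# In TreeLSTM.forward I would then do something like
#     embeds = self.seq_encoder(g.ndata['x'])   # (n, seq_len, x_size) -> (n, h_size)
#     g.ndata['iou'] = self.cell.W_iou(embeds)
# with W_iou changed to nn.Linear(h_size, 3 * h_size, bias=False),
# so everything downstream stays 2-D as the cells expect.
```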
b) the model and a minimal training driver:
```python
import collections

import dgl
import torch
from torch.utils.data import DataLoader, Dataset

from tree_lstm import TreeLSTM  # the code from a)


class MyDataset(Dataset):
    def __init__(self, graphs):
        self.graphs = graphs

    def __getitem__(self, index):  # a[i]
        return self.graphs[index]

    def __len__(self):  # length seen by the train loader
        return len(self.graphs)


def load_graph_3():
    import numpy as np
    from scipy.sparse import coo_matrix
    row = np.array([0, 1, 0])   # edge sources
    col = np.array([1, 3, 2])   # edge destinations
    data = np.array([4, 5, 9])  # edge data/weights
    a = coo_matrix((data, (row, col)), shape=(4, 4))  # just the tree topology
    g = dgl.from_scipy(a)
    g.ndata['x'] = torch.zeros((4, 6, 5))  # the node features are SEQUENCES: 4 nodes, 6 steps, 5 features
    g.ndata['y'] = torch.Tensor([1, 2, 3, 4])  # labels
    g.ndata['mask'] = torch.ones(4)
    return g


def main():
    device = "cpu"
    dgl_graph = load_graph_3()
    SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
    graph_dataset = MyDataset([dgl_graph])

    def batcher(device):
        def batcher_dev(batch):
            batch_trees = dgl.batch(batch)
            return SSTBatch(graph=batch_trees,
                            mask=batch_trees.ndata['mask'].to(device),
                            wordid=batch_trees.ndata['x'].to(device),
                            label=batch_trees.ndata['y'])
        return batcher_dev

    train_loader = DataLoader(dataset=graph_dataset,
                              batch_size=1,
                              collate_fn=batcher(device),
                              shuffle=False,
                              num_workers=0)
    model = TreeLSTM(vocab_size,  # module-level globals, set below
                     x_size,
                     h_size,
                     num_classes,
                     dropout,
                     cell_type='childsum',
                     pretrained_emb=None).to(device)
    epochs = 2
    for epoch in range(epochs):
        model.train()
        for step, batch in enumerate(train_loader):  # TODO: make this work with a single tree
            g = batch.graph.to(device)  # all graphs in the batch
            n = g.number_of_nodes()     # number of nodes (leaves and internal)
            h = torch.zeros((n, h_size)).to(device)  # initial hidden state
            c = torch.zeros((n, h_size)).to(device)  # initial cell state
            logits = model(batch, g, h, c)
            print(logits.shape)


if __name__ == "__main__":
    vocab_size = 21
    x_size = 5  # must equal the last dimension of ndata['x']
    h_size = 20
    dropout = 0.1
    num_classes = 21
    main()
```
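For completeness, here is why I think the forward pass cannot work as-is on these 3-D features (my own reasoning, so I may be missing something): `nn.Linear` only acts on the last dimension, so `W_iou` turns the `(4, 6, 5)` sequence tensor into a 3-D `(4, 6, 60)` `iou`, while `h` and `c` stay `(4, 20)`:

```python
import torch
import torch.nn as nn

# W_iou as the cell builds it for x_size=5, h_size=20
W_iou = nn.Linear(5, 3 * 20, bias=False)
iou = W_iou(torch.zeros(4, 6, 5))  # nn.Linear maps the last dim only
print(iou.shape)  # torch.Size([4, 6, 60]) -- 3-D, but 'h' and 'c' are (4, 20)
```

That is what makes me think the sequences have to be collapsed to one vector per node before the tree propagation.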
So basically: is it possible to have one LSTM over the tree structure and another LSTM over the sequences inside the nodes?

Also, is there a GRU version of the Tree-LSTM cells?
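If there isn't one, I imagine a child-sum Tree-GRU cell could be written in the same message-passing style. The sketch below is entirely my own guess (not a DGL module), and it assumes the node features are already 2-D `(n, x_size)` and that `g.ndata['h_sum']` is zero-initialized before propagation:

```python
import torch as th
import torch.nn as nn


class ChildSumTreeGRUCell(nn.Module):
    """Hypothetical child-sum Tree-GRU cell, analogous to ChildSumTreeLSTMCell."""
    def __init__(self, x_size, h_size):
        super(ChildSumTreeGRUCell, self).__init__()
        self.W_rz = nn.Linear(x_size, 2 * h_size, bias=False)
        self.U_rz = nn.Linear(h_size, 2 * h_size)
        self.W_h = nn.Linear(x_size, h_size, bias=False)
        self.U_h = nn.Linear(h_size, h_size)

    def message_func(self, edges):
        return {'h': edges.src['h']}

    def reduce_func(self, nodes):
        # child-sum aggregation of the children's hidden states
        return {'h_sum': th.sum(nodes.mailbox['h'], 1)}

    def apply_node_func(self, nodes):
        # leaves receive no messages, so 'h_sum' must start at zeros
        x, h_sum = nodes.data['x'], nodes.data['h_sum']
        r, z = th.chunk(self.W_rz(x) + self.U_rz(h_sum), 2, 1)
        r, z = th.sigmoid(r), th.sigmoid(z)
        h_tilde = th.tanh(self.W_h(x) + self.U_h(r * h_sum))
        h = z * h_sum + (1 - z) * h_tilde  # GRU-style interpolation
        return {'h': h}
```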
Thanks in advance