Hi guys, in my last post Deep walk Indexing error. Confusing Behavior - #5 by Zihan I described a case where random_walk returns -1 in DeepWalk's batch_walk, which then causes an index error in DeepWalk's nn.Embedding. To resolve this, Rhett recommended that I give pack_traces a try.
I tried the following implementation:
```python
import dgl
from dgl.sampling import pack_traces
from dgl.nn.pytorch import DeepWalk

class DisConnectedDeepWalk(DeepWalk):
    def sample(self, indices):
        return dgl.sampling.random_walk(self.g, indices, length=self.walk_length - 1)

...  # (model, dataloader, optimizer setup elided)

for epoch in range(1):
    for batch_walk_tuple in dataloader:
        traces, types = batch_walk_tuple
        print("traces.shape: ", traces.shape)
        print("types.shape: ", types.shape)
        batch_walk, _, _, _ = pack_traces(traces, types)
        print("batch_walk.shape: ", batch_walk.shape)
        # The packed length varies (e.g. 300, 400, ...), so we cannot fix a
        # batch size such as 128 via batch_walk.view(128, -1).
        reshaped_batch_walk = batch_walk.view(1, -1)
        print("reshaped_batch_walk.shape: ", reshaped_batch_walk.shape)
        loss = model(reshaped_batch_walk)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
However, pack_traces concatenates the traces into a single sequence.
Here is a minimal example:
[[1, 30, 40, -1, -1],
 [2, 32, -1, -1, -1]]
The packed sequence [1, 30, 40, 2, 32] cannot be used for training directly, because it is not a real walk: 1 and 2 are both starting points of separate walks, so packing them together produces a walk that never happened.
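To make the issue concrete, here is a small pure-Python sketch (no DGL needed) that mimics the packing and shows how skip-gram-style context pairs end up crossing the walk boundary:

```python
# Padded traces as returned by random_walk on a disconnected graph:
# -1 marks positions where a walk could not continue.
traces = [
    [1, 30, 40, -1, -1],
    [2, 32, -1, -1, -1],
]

# pack_traces-style packing: drop the -1 padding and concatenate everything.
packed = [v for trace in traces for v in trace if v != -1]
print(packed)  # [1, 30, 40, 2, 32]

# Skip-gram training on the packed sequence generates context pairs that
# cross the walk boundary. With a window size of 1:
window = 1
pairs = [
    (packed[i], packed[j])
    for i in range(len(packed))
    for j in range(max(0, i - window), min(len(packed), i + window + 1))
    if i != j
]
# (40, 2) is a bogus pair: node 40 and node 2 belong to different walks
# and may even lie in different connected components.
print((40, 2) in pairs)  # True
```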
The actual problem to solve is that the walks have different lengths, and pack_traces by itself doesn't solve that.
The simplest, most brute-force way to get it running is to use pack_traces with batch_size = 1, i.e. to train the embeddings one walk at a time. This handles walks of inconsistent lengths, but the quality of embeddings trained with batch_size = 1 is not ideal.
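For reference, the one-walk-at-a-time workaround only needs each trace trimmed at its first padding value before it is fed to the model (`trim_walk` is a hypothetical helper, shown here on plain lists):

```python
def trim_walk(walk, pad=-1):
    """Cut a padded walk at the first pad value (hypothetical helper)."""
    return walk[:walk.index(pad)] if pad in walk else walk

walks = [[1, 30, 40, -1, -1], [2, 32, -1, -1, -1]]
for w in walks:
    real = trim_walk(w)
    # With batch_size = 1 each trimmed walk becomes its own batch of
    # shape (1, len(real)); no padding value ever reaches nn.Embedding.
    print(real)
    # loss = model(torch.tensor(real).view(1, -1))  # assumed model call
```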
A straightforward idea is to collect all the walks, sort them by length, and then split them into batches so that the walk lengths within each batch are equal. However, this is clearly infeasible for large graphs.
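A variant of that idea that avoids the global sort is to bucket walks by length on the fly, emitting a batch whenever a bucket fills up. This is only a sketch of the bucketing idea (`bucket_by_length` is a hypothetical helper, operating on plain lists):

```python
from collections import defaultdict

def bucket_by_length(walks, batch_size, pad=-1):
    """Group walks of equal (trimmed) length into fixed-size batches.

    Streams the walks, so it never needs a global sort over the graph.
    (Hypothetical helper sketching the bucketing idea.)
    """
    buckets = defaultdict(list)
    for w in walks:
        real = w[:w.index(pad)] if pad in w else w
        buckets[len(real)].append(real)
        if len(buckets[len(real)]) == batch_size:
            yield buckets.pop(len(real))
    # Flush leftovers as smaller batches.
    for leftover in buckets.values():
        yield leftover

walks = [[1, 30, 40, -1, -1],
         [2, 32, -1, -1, -1],
         [5, 6, 7, -1, -1],
         [8, 9, -1, -1, -1]]
for batch in bucket_by_length(walks, batch_size=2):
    print([len(w) for w in batch])  # every walk in a batch has equal length
```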
Any advice? Thanks in advance!