[Help][PinSAGE] How best to get the embeddings of all item nodes?

Hi team,
I have computed the embeddings of all item nodes as shown below. Is this approach correct? I plan to build an embedding-based recall system on top of them later.
Here is the code I use to save the embeddings for all item nodes:
```python
# Evaluate
model.eval()
with torch.no_grad():
    # Not used below: dataloader_test already iterates over all item IDs
    item_batches = torch.arange(g.num_nodes(item_ntype)).split(args.batch_size)

    h_item_batches = []
    for blocks in tqdm.tqdm(dataloader_test):
        # Move the sampled blocks of this batch to the GPU
        for i in range(len(blocks)):
            blocks[i] = blocks[i].to(device)

        # One embedding row per seed item in this batch
        h_item_batches.append(model.get_repr(blocks, item_emb))
    # Row i is item ID i, assuming dataloader_test does not shuffle
    h_item = torch.cat(h_item_batches, 0)

    # torch.save(h_item, "embeddings.pth" + "_" + str(epoch_id))

    h_pd = pd.DataFrame(h_item.cpu().numpy())
    h_pd.to_csv("embeddings" + "_" + str(epoch_id) + ".csv")

    print(evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size))
```
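Side note on the CSV export above (a sketch, assuming pandas' default `to_csv` options): pandas writes a header row with the column names `0, 1, ...` and writes the row index as the first column, so the file should be read back with `pd.read_csv(path, index_col=0)`. The same layout can be parsed with the standard library alone; `load_embeddings_csv` below is a hypothetical helper, and the two-row sample file is made up for illustration.

```python
import csv
import io
from typing import Dict, List, TextIO


def load_embeddings_csv(f: TextIO) -> Dict[int, List[float]]:
    """Parse a CSV written by DataFrame.to_csv with default options (hypothetical helper).

    Assumed layout: a header row (",0,1,...,d-1"), then one row per item whose
    first column is the row index (the item ID, if the dataloader did not shuffle)
    and whose remaining columns are the embedding dimensions.
    """
    reader = csv.reader(f)
    next(reader)  # skip the header row with the column names
    embeddings: Dict[int, List[float]] = {}
    for row in reader:
        if not row:
            continue  # tolerate blank lines, e.g. a trailing newline
        item_id = int(row[0])
        embeddings[item_id] = [float(x) for x in row[1:]]
    return embeddings


# Usage with a tiny in-memory file in the assumed layout (made-up values).
sample = ",0,1,2\n0,0.1,0.2,0.3\n1,-0.5,0.0,1.5\n"
emb = load_embeddings_csv(io.StringIO(sample))
print(emb[1])  # [-0.5, 0.0, 1.5]
```

A binary format (the commented-out `torch.save`, or `numpy.save`) avoids the text round trip and is usually smaller; CSV is mainly convenient for inspection.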

My questions:
I noticed that in `dataloader_test`, the neighbor sampler is still invoked before `get_repr()` is called:

```python
def collate_test(self, samples):
    batch = torch.LongTensor(samples)
    blocks = self.sampler.sample_blocks(batch)
    assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)
    return blocks
```

Is it necessary to turn off sampling (or to aggregate over the embeddings of all linked nodes) when calling `get_repr()` for inference?
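As far as I understand the DGL PinSAGE example, the sampler is part of how the model defines a node's neighborhood, not only a training-time shortcut: for each item, `dgl.sampling.PinSAGESampler` runs random walks with restart (item → user → item), counts how often other items are visited, and keeps the top-k most visited items as neighbors, storing the visit counts as edge weights. The model is trained on these neighborhoods, so reusing the same sampler at inference keeps its input consistent. The sketch below illustrates that selection rule on a toy bipartite graph; the graph, the function name `pinsage_neighbors`, and its parameters are illustrative assumptions, not the DGL API.

```python
import random
from collections import Counter
from typing import Dict, List, Tuple

# Toy item-user interactions (hypothetical data): item -> users, user -> items.
ITEM_TO_USERS: Dict[str, List[str]] = {
    "i0": ["u0", "u1"],
    "i1": ["u0"],
    "i2": ["u1", "u2"],
    "i3": ["u2"],
}
USER_TO_ITEMS: Dict[str, List[str]] = {}
for item, users in ITEM_TO_USERS.items():
    for user in users:
        USER_TO_ITEMS.setdefault(user, []).append(item)


def pinsage_neighbors(seed: str, num_walks: int, num_traversals: int,
                      termination_prob: float, k: int,
                      rng: random.Random) -> List[Tuple[str, int]]:
    """Return the k items most visited by random walks from `seed`, with visit counts.

    Each walk starts at `seed` and hops item -> user -> item, stopping after
    `num_traversals` hops or earlier with probability `termination_prob`.
    """
    visits: Counter = Counter()
    for _ in range(num_walks):
        item = seed
        for _ in range(num_traversals):
            user = rng.choice(ITEM_TO_USERS[item])
            item = rng.choice(USER_TO_ITEMS[user])
            if item != seed:
                visits[item] += 1  # the seed itself is not counted as a neighbor
            if rng.random() < termination_prob:
                break
    # Visit counts play the role of neighbor weights.
    return visits.most_common(k)


print(pinsage_neighbors("i0", num_walks=200, num_traversals=3,
                        termination_prob=0.5, k=2, rng=random.Random(0)))
```

In the DGL sampler the corresponding settings include `num_random_walks` and `num_neighbors`; with more walks, the top-k neighbors (and therefore the embeddings) should vary less between runs.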

Is there a simpler way to get the embeddings of all items for downstream use?
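For a downstream recall system, once `h_item` (one row per item ID) has been computed, retrieval is usually a nearest-neighbour search over those vectors. As far as I can tell, the example's `evaluation.evaluate_nn` ranks items by dot product; the plain-Python sketch below shows a brute-force version of that idea. `top_k_similar` and the toy vectors are hypothetical, and for a large catalogue an approximate nearest-neighbour library such as Faiss would typically replace the brute-force loop.

```python
import heapq
import math
from typing import List, Sequence, Tuple


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def top_k_similar(query_id: int, embeddings: Sequence[Sequence[float]],
                  k: int, normalize: bool = False) -> List[Tuple[int, float]]:
    """Brute-force top-k items by dot product with item `query_id` (hypothetical helper).

    Row i of `embeddings` is assumed to be the embedding of item ID i.
    With normalize=True the score becomes cosine similarity.
    """
    vecs = [l2_normalize(v) if normalize else list(v) for v in embeddings]
    q = vecs[query_id]
    scores = (
        (i, sum(a * b for a, b in zip(q, v)))
        for i, v in enumerate(vecs)
        if i != query_id  # do not recommend the query item itself
    )
    return heapq.nlargest(k, scores, key=lambda pair: pair[1])


h_item = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
print(top_k_similar(0, h_item, k=2))  # [(1, 0.9), (2, 0.0)]
```

Whether to normalize is a modelling choice: the raw dot product matches how the scores are computed during training, while cosine similarity ignores differences in embedding magnitude.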

Any help would be appreciated.
Thanks!

Hi! I can’t speak to the best way to do it, but here’s how I’m doing it in my code:

```python
import os
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# build_dataset, build_model and build_graph_sampler are helpers from my own code.
# The function name and signature are illustrative (hypothetical wrapper).
def generate_item_embeddings(cfg, checkpoint_path, mode, output_path=None):
    # LOAD GRAPH
    g, train_g, val_g = build_dataset(cfg)

    # LOAD MODEL
    model = build_model(g, cfg)
    model_state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(model_state['model_state'])

    # DATALOADER: sample neighborhoods from the full graph or the training graph only
    if mode == 'fullg':
        neighbor_sampler, collator = build_graph_sampler(g, cfg)
    elif mode == 'traing':
        neighbor_sampler, collator = build_graph_sampler(train_g, cfg)
    else:
        raise ValueError("No graph mode detected - please specify either fullg/traing")

    # No shuffling, so output row i is the embedding of item ID i
    dataloader_test = DataLoader(
        torch.arange(g.number_of_nodes(cfg.item_type)),
        batch_size=32,
        collate_fn=collator.collate_test,
        num_workers=1)

    device = torch.device('cuda:0')
    model = model.to(device).eval()

    # GENERATE REPRESENTATIONS
    all_features_cpu = []
    with torch.no_grad():
        for blocks in tqdm(dataloader_test):
            blocks = [block.to(device) for block in blocks]
            features = model.get_repr(blocks)
            all_features_cpu.append(features.cpu().numpy())
    all_features_array = np.concatenate(all_features_cpu, axis=0)

    # SAVE
    if output_path:
        os.makedirs(output_path, exist_ok=True)
        file_path = os.path.join(output_path, 'embeddings_as_array_{}.pkl'.format(mode))
        print("***Saving Track Embeddings to {}***".format(file_path))
        with open(file_path, "wb") as f:
            pickle.dump(all_features_array, f)

    return all_features_array
```
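Both snippets rely on the test dataloader iterating over `torch.arange(num_items)` without shuffling, so concatenating the per-batch outputs puts item ID i in row i. The stdlib sketch below mimics that bookkeeping; `batched`, `embed_all`, and `fake_embed` are hypothetical names, with `fake_embed` standing in for `model.get_repr(blocks)`. The same two checks (one output row per seed, and total rows equal to the number of items) are cheap to add to the real code.

```python
from typing import Callable, List, Sequence


def batched(ids: Sequence[int], batch_size: int) -> List[List[int]]:
    """Split IDs into consecutive batches, like torch.arange(n).split(batch_size)."""
    return [list(ids[i:i + batch_size]) for i in range(0, len(ids), batch_size)]


def embed_all(num_items: int, batch_size: int,
              embed_batch: Callable[[List[int]], List[List[float]]]) -> List[List[float]]:
    """Concatenate per-batch outputs in order, so row i belongs to item ID i."""
    rows: List[List[float]] = []
    for batch in batched(range(num_items), batch_size):
        out = embed_batch(batch)
        assert len(out) == len(batch), "expected one output row per seed node"
        rows.extend(out)
    assert len(rows) == num_items
    return rows


# Hypothetical stand-in for model.get_repr: item i is embedded as [i, 2*i].
def fake_embed(batch: List[int]) -> List[List[float]]:
    return [[float(i), 2.0 * i] for i in batch]


h_item = embed_all(num_items=70, batch_size=32, embed_batch=fake_embed)
print(len(h_item), h_item[65])  # 70 [65.0, 130.0]
```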
