Hi Dgl community,
I need some assistance with the following code. Currently, it works when the number of features for customers, merchants, and transactions
are all three. However, when I include additional features only for customers
(such as df[[‘customer_id’, ‘city_id’, ‘lat’,‘long’]]), an error occurs. Could you advise on how to modify the code to resolve this error?
# Define a Heterograph Conv model
class RGCN(nn.Module):
def __init__(self, in_feats, hid_feats, hid_feats2, out_feats, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(in_feats, hid_feats)
for rel in rel_names}, aggregate='sum')
self.conv2 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(hid_feats, hid_feats2)
for rel in rel_names}, aggregate='sum')
self.conv3 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(hid_feats2, out_feats)
for rel in rel_names}, aggregate='sum')
def forward(self, graph, inputs):
# inputs are features of nodes
h = self.conv1(graph, inputs)
h = {k: F.relu(v) for k, v in h.items()}
h = self.conv2(graph, h)
h = {k: F.relu(v) for k, v in h.items()}
h = self.conv3(graph, h)
return h
#Here is my issue
#I would like to use different number of features such as df[[‘customer_id’, ‘city_id’, ‘lat’,‘long’]]
customer_feat = df[['customer_id', 'city_id', 'lat']]
merchant_feat = df[['merchant_id', 'merch_long', 'merch_lat']]
trans_feat = df[['transaction_id', 'is_fraud','amt']]
raph_data = {
('customer', 'make', 'transaction'): (df.customer_id.to_numpy(), df.transaction_id.to_numpy()),
('merchant', 'create', 'transaction'): (df.merchant_id.to_numpy(), df.transaction_id.to_numpy()),
('transaction', 'make-by', 'customer'): (df.transaction_id.to_numpy(), df.customer_id.to_numpy()),
('transaction', 'create_by', 'merchant'): (df.transaction_id.to_numpy(), df.merchant_id.to_numpy())
}
hetero_graph = dgl.heterograph(graph_data)
print(hetero_graph)
hetero_graph.nodes['customer'].data['feature'] = torch.tensor(customer_feat.groupby(['customer_id']).max().to_numpy())
hetero_graph.nodes['merchant'].data['feature'] = torch.tensor(merchant_feat.groupby(['merchant_id']).max().to_numpy())
hetero_graph.nodes['transaction'].data['feature'] = torch.tensor(trans_feat.groupby(['transaction_id']).max().to_numpy())
hetero_graph.nodes['transaction'].data['label'] = torch.tensor(trans_feat.sort_values(by=['transaction_id']).is_fraud.to_numpy())
#randomly generate training masks on user nodes and click edges
hetero_graph.nodes['transaction'].data['train_mask'] = torch.cat((torch.ones(encoded_df1_up.shape[0], dtype=torch.bool), torch.zeros(df2.shape[0], dtype=torch.bool)), 0)
model = RGCN(2, 20, 10, len(pd.unique(df.is_fraud)), hetero_graph.etypes)
c_feats = hetero_graph.nodes['customer'].data['feature'].float()
m_feats = hetero_graph.nodes['merchant'].data['feature'].float()
t_feats = hetero_graph.nodes['transaction'].data['feature'].float()
labels = hetero_graph.nodes['transaction'].data['label']
train_mask = hetero_graph.nodes['transaction'].data['train_mask']
test_mask = train_mask.logical_not()
node_features = {'customer': c_feats, 'merchant': m_feats, 'transaction': t_feats}
h_dict = model(hetero_graph, {'customer': c_feats, 'merchant': m_feats, 'transaction': t_feats})
h_cus = h_dict['customer']
h_mer = h_dict['merchant']
h_tran = h_dict['transaction']
opt = torch.optim.Adam(model.parameters(), lr=0.001)
count = 0
for epoch in range(100):
model.train()
# forward propagation by using all nodes and extracting the user embeddings
logits = model(hetero_graph, node_features)['transaction']
# compute loss
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
# Compute validation accuracy. Omitted in this example.
# backward propagation
opt.zero_grad()
loss.backward()
opt.step()
count+=1
# if count % 10 == 0:
# print(loss.item())
model.eval()
with torch.no_grad():
pred = model(hetero_graph, node_features)['transaction']
result = hetero_graph.nodes['transaction'].data['label']
train_acc = (pred[test_mask].argmax(1) == labels[test_mask]).float().mean()
test_acc = (pred[test_mask].argmax(1) == labels[test_mask]).float().mean()
print('*'* 50)
print('Loss', loss.item())
print('Test Accuracy: %.2f%%' % (test_acc.item() * 100))
print('Recall', recall_score(labels[test_mask], pred[test_mask].argmax(1)) * 100)
print('F1Score', f1_score(labels[test_mask], pred[test_mask].argmax(1)) * 100)
print('ROC', roc_auc_score(labels[test_mask], pred[test_mask].argmax(1)) * 100)
Current error message