From 3b7e9cd02587e15bcb8d4699b5a555da58b4313c Mon Sep 17 00:00:00 2001 From: Alex Morrise Date: Fri, 1 Sep 2023 13:02:33 -0700 Subject: [PATCH 1/9] adds joint model for link and node prediction -- toy working example --- graphistry/compute/gnn_utils.py | 162 ++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 graphistry/compute/gnn_utils.py diff --git a/graphistry/compute/gnn_utils.py b/graphistry/compute/gnn_utils.py new file mode 100644 index 000000000..52a56ce43 --- /dev/null +++ b/graphistry/compute/gnn_utils.py @@ -0,0 +1,162 @@ +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv +from torch_geometric.data import DataLoader +from torch_geometric.datasets import Planetoid +import torch.optim as optim +from collections import defaultdict + +# Base GNN Model +class BaseGNN(nn.Module): + def __init__(self, in_channels, out_channels): + super(BaseGNN, self).__init__() + self.conv1 = GCNConv(in_channels, 128) + self.conv2 = GCNConv(128, out_channels) + + def forward(self, data): + x, edge_index = data.x, data.edge_index + x = self.conv1(x, edge_index) + x = F.relu(x) + x = self.conv2(x, edge_index) + return x + +# Node Prediction Model +class NodePredictionModel(nn.Module): + def __init__(self, base_model, n_classes=10): + super(NodePredictionModel, self).__init__() + self.base_model = base_model + self.classifier = nn.Linear(base_model.conv2.out_channels, n_classes) #classes for node prediction + + def forward(self, data): + x = self.base_model(data) + x = F.relu(x) + x = self.classifier(x) + return F.log_softmax(x, dim=1) + +# Link Prediction Model +class LinkPredictionModel(nn.Module): + def __init__(self, base_model): + super(LinkPredictionModel, self).__init__() + self.base_model = base_model + + def forward(self, data, edge_index_pos, edge_index_neg): + x = self.base_model(data) + x_pos = torch.cat([x[edge_index_pos[0]], x[edge_index_pos[1]]], dim=1) + x_neg = torch.cat([x[edge_index_neg[0]], x[edge_index_neg[1]]], dim=1) + return x_pos, x_neg + +# Joint Model for Node and Link Prediction +class JointModel(nn.Module): + def __init__(self, in_channels): + super(JointModel, self).__init__() + self.base_model = BaseGNN(in_channels, 128) + self.node_model = NodePredictionModel(self.base_model) + self.link_model = LinkPredictionModel(self.base_model) + + def forward(self, data, edge_index_pos, edge_index_neg): + node_pred = self.node_model(data) + link_pred_pos, link_pred_neg = self.link_model(data, edge_index_pos, edge_index_neg) + return node_pred, link_pred_pos, link_pred_neg + + +def joint_loss(node_pred, node_labels, link_pred_pos, link_pred_neg): + node_loss = F.nll_loss(node_pred, node_labels) + + link_labels_pos = torch.ones(link_pred_pos.shape[0]).to(link_pred_pos.device) + link_labels_neg = torch.zeros(link_pred_neg.shape[0]).to(link_pred_neg.device) + + link_loss = F.binary_cross_entropy_with_logits(link_pred_pos, link_labels_pos) + \ + F.binary_cross_entropy_with_logits(link_pred_neg, link_labels_neg) + + return node_loss + link_loss + + +def sample_edges(edge_index, num_samples): + """ + Sample positive and negative edges considering degree dominance. + + Parameters: + edge_index (Tensor): The edge index tensor. + num_samples (int): The number of positive/negative samples needed. + + Returns: + edge_index_pos (Tensor): Tensor for positive samples. + edge_index_neg (Tensor): Tensor for negative samples. + """ + # Create an edge list and a degree dictionary + edge_list = edge_index.t().tolist() + degree_dict = defaultdict(int) + + for u, v in edge_list: + degree_dict[u] += 1 + degree_dict[v] += 1 + + # Sort the edge list based on node degree (sum of degrees of both nodes) + sorted_edge_list = sorted(edge_list, key=lambda x: degree_dict[x[0]] + degree_dict[x[1]]) + + # Split into high-degree and low-degree edges + mid_point = len(sorted_edge_list) // 2 + high_degree_edges = sorted_edge_list[mid_point:] + low_degree_edges = sorted_edge_list[:mid_point] + + # Sample equally from high-degree and low-degree edges for positive samples + positive_samples = random.sample(high_degree_edges, num_samples // 2) + random.sample(low_degree_edges, num_samples // 2) + random.shuffle(positive_samples) + + # Generate negative samples ensuring they are not in the graph + negative_samples = set() + while len(negative_samples) < num_samples: + u, v = random.choice(list(degree_dict.keys())), random.choice(list(degree_dict.keys())) + if u != v and (u, v) not in edge_list and (v, u) not in edge_list: + # Take into account the degree to balance high and low degree nodes + if random.random() < (degree_dict[u] + degree_dict[v]) / (2 * sum(degree_dict.values())): + negative_samples.add((u, v)) + + # Convert to PyTorch tensors + edge_index_pos = torch.tensor(positive_samples, dtype=torch.long).t().contiguous() + edge_index_neg = torch.tensor(list(negative_samples), dtype=torch.long).t().contiguous() + + return edge_index_pos, edge_index_neg + +if __name__ == '__main__': + # use Planetoid's CORA dataset as an example. + dataset = Planetoid(root='/tmp/Cora', name='Cora') + data = dataset[0] + + # Create data loaders (use your own data loaders if you have custom datasets) + train_loader = DataLoader([data], batch_size=32, shuffle=True) + + # Initialize model and optimizer + joint_model = JointModel(dataset.num_features) + optimizer = optim.Adam(joint_model.parameters(), lr=0.003) + + # Split edges for positive and negative samples using degree dominance + edge_index_pos, edge_index_neg = sample_edges(data.edge_index, num_samples=2 * data.num_edges) + + # Training Loop + joint_model.train() + for epoch in range(100): + for batch in train_loader: + optimizer.zero_grad() + + node_pred, link_pred_pos, link_pred_neg = joint_model(batch, edge_index_pos, edge_index_neg) + + loss = joint_loss(node_pred, batch.y, link_pred_pos, link_pred_neg) + + loss.backward() + optimizer.step() + + print(f'Epoch {epoch+1}, Loss: {loss.item()}') + + # Evaluation Loop (simplified) + joint_model.eval() + with torch.no_grad(): + correct = 0 + for batch in train_loader: + node_pred, _, _ = joint_model(batch, edge_index_pos, edge_index_neg) + pred = node_pred.argmax(dim=1) + correct += pred.eq(batch.y).sum().item() + + print(f'Node Classification Accuracy: {correct / len(train_loader.dataset)}') From b3bf045fd12c59a6fb808ba8c5b0fb23c9ac7231 Mon Sep 17 00:00:00 2001 From: Alex Morrise Date: Fri, 1 Sep 2023 13:38:59 -0700 Subject: [PATCH 2/9] better val print out --- graphistry/compute/gnn_utils.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/graphistry/compute/gnn_utils.py b/graphistry/compute/gnn_utils.py index 52a56ce43..1223c0629 100644 --- a/graphistry/compute/gnn_utils.py +++ b/graphistry/compute/gnn_utils.py @@ -6,6 +6,8 @@ from torch_geometric.data import DataLoader from torch_geometric.datasets import Planetoid import torch.optim as optim + +from sklearn.metrics import roc_auc_score from collections import defaultdict # Base GNN Model @@ -127,6 +129,8 @@ def sample_edges(edge_index, num_samples): # Create data loaders (use your own data loaders if you have custom datasets) train_loader = DataLoader([data], batch_size=32, shuffle=True) + val_loader = DataLoader([data], batch_size=32, shuffle=False) + # Initialize model and optimizer joint_model = JointModel(dataset.num_features) @@ -134,6 +138,7 @@ def sample_edges(edge_index, num_samples): # Split edges for positive and negative samples using degree dominance edge_index_pos, edge_index_neg = sample_edges(data.edge_index, num_samples=2 * data.num_edges) + edge_index_pos_val, edge_index_neg_val = sample_edges(data.edge_index, 100) # Training Loop joint_model.train() @@ -154,9 +159,18 @@ def sample_edges(edge_index, num_samples): joint_model.eval() with torch.no_grad(): correct = 0 - for batch in train_loader: - node_pred, _, _ = joint_model(batch, edge_index_pos, edge_index_neg) - pred = node_pred.argmax(dim=1) - correct += pred.eq(batch.y).sum().item() + for batch in val_loader: + node_pred_val, link_pred_pos_val, link_pred_neg_val = joint_model(batch, edge_index_pos_val, edge_index_neg_val) + val_loss = joint_loss(node_pred_val, batch.y, link_pred_pos_val, link_pred_neg_val) + + # Node prediction metrics + pred = node_pred_val.argmax(dim=1) + node_correct = pred.eq(batch.y).sum().item() + node_accuracy = node_correct / len(val_loader.dataset) - print(f'Node Classification Accuracy: {correct / len(train_loader.dataset)}') + # Link prediction metrics + link_labels = torch.cat([torch.ones(link_pred_pos_val.shape[0]), torch.zeros(link_pred_neg_val.shape[0])]) + link_preds = torch.cat([link_pred_pos_val, link_pred_neg_val]) + roc_score = roc_auc_score(link_labels.detach().cpu(), link_preds.detach().cpu()) + + print(f'Epoch {epoch+1}, Validation Loss: {val_loss.item()}, Node Classification Accuracy: {node_accuracy}, Link Prediction ROC: {roc_score}') From 5c69c19683f7412058ee11532fbb3b492ea09f2e Mon Sep 17 00:00:00 2001 From: Alex Morrise Date: Sat, 2 Sep 2023 11:22:45 -0700 Subject: [PATCH 3/9] adds validation hooks, fixes net with linear layer output, does 99% accuracy on nodes, and 67% AUC for link prediction, even in this simple model. For Cody/Tanmoy when I am in France --- graphistry/compute/gnn_utils.py | 164 ++++++++++++++++++-------------- 1 file changed, 91 insertions(+), 73 deletions(-) diff --git a/graphistry/compute/gnn_utils.py b/graphistry/compute/gnn_utils.py index 1223c0629..bb89ecf01 100644 --- a/graphistry/compute/gnn_utils.py +++ b/graphistry/compute/gnn_utils.py @@ -42,11 +42,16 @@ class LinkPredictionModel(nn.Module): def __init__(self, base_model): super(LinkPredictionModel, self).__init__() self.base_model = base_model - + self.link_classifier = nn.Linear(256, 1) + def forward(self, data, edge_index_pos, edge_index_neg): x = self.base_model(data) x_pos = torch.cat([x[edge_index_pos[0]], x[edge_index_pos[1]]], dim=1) x_neg = torch.cat([x[edge_index_neg[0]], x[edge_index_neg[1]]], dim=1) + + x_pos = self.link_classifier(x_pos).squeeze(-1) + x_neg = self.link_classifier(x_neg).squeeze(-1) + return x_pos, x_neg # Joint Model for Node and Link Prediction @@ -75,102 +80,115 @@ def joint_loss(node_pred, node_labels, link_pred_pos, link_pred_neg): return node_loss + link_loss -def sample_edges(edge_index, num_samples): +def sample_edges(edge_index, num_samples, balance_degree=False): """ - Sample positive and negative edges considering degree dominance. - + Sample positive and negative edges from a graph represented by its edge index. + Parameters: edge_index (Tensor): The edge index tensor. - num_samples (int): The number of positive/negative samples needed. - + num_samples (int): Number of negative samples to generate. + balance_degree (bool): Whether to use degree-based sampling (need to add, as first attempt was sloooow) + Returns: - edge_index_pos (Tensor): Tensor for positive samples. - edge_index_neg (Tensor): Tensor for negative samples. + pos_samples (Tensor): Tensor of positive edge samples. + neg_samples (Tensor): Tensor of negative edge samples. """ - # Create an edge list and a degree dictionary - edge_list = edge_index.t().tolist() - degree_dict = defaultdict(int) - - for u, v in edge_list: - degree_dict[u] += 1 - degree_dict[v] += 1 - - # Sort the edge list based on node degree (sum of degrees of both nodes) - sorted_edge_list = sorted(edge_list, key=lambda x: degree_dict[x[0]] + degree_dict[x[1]]) - - # Split into high-degree and low-degree edges - mid_point = len(sorted_edge_list) // 2 - high_degree_edges = sorted_edge_list[mid_point:] - low_degree_edges = sorted_edge_list[:mid_point] - - # Sample equally from high-degree and low-degree edges for positive samples - positive_samples = random.sample(high_degree_edges, num_samples // 2) + random.sample(low_degree_edges, num_samples // 2) - random.shuffle(positive_samples) - - # Generate negative samples ensuring they are not in the graph - negative_samples = set() - while len(negative_samples) < num_samples: - u, v = random.choice(list(degree_dict.keys())), random.choice(list(degree_dict.keys())) - if u != v and (u, v) not in edge_list and (v, u) not in edge_list: - # Take into account the degree to balance high and low degree nodes - if random.random() < (degree_dict[u] + degree_dict[v]) / (2 * sum(degree_dict.values())): - negative_samples.add((u, v)) - - # Convert to PyTorch tensors - edge_index_pos = torch.tensor(positive_samples, dtype=torch.long).t().contiguous() - edge_index_neg = torch.tensor(list(negative_samples), dtype=torch.long).t().contiguous() + print("Step 1: Determine number of nodes in the graph") + num_nodes = edge_index.max().item() + 1 + + print("Step 2: Select positive samples") + # Positive samples + pos_samples = edge_index.t() - return edge_index_pos, edge_index_neg + if pos_samples.shape[0] < num_samples: + raise ValueError(f"Not enough edges in graph to sample {num_samples} positive samples.") + + pos_samples = pos_samples[:num_samples] + + print(f"Positive samples:\n{len(pos_samples)}") + + print("Step 3: Create an adjacency set for fast lookup") + # Create an adjacency set for fast lookup + adjacency_set = set([(u.item(), v.item()) for u, v in pos_samples]) + + print("Step 4: Perform negative sampling") + # Negative sampling + neg_samples = set() + + print("Sampling candidate negative samples...") + while len(neg_samples) < num_samples: + u, v = np.random.randint(0, num_nodes, 2) + if u != v and (u, v) not in adjacency_set and (v, u) not in adjacency_set: + print(f"Adding negative sample: ({u}, {v})") if u+v %10 ==0 else None + neg_samples.add((u, v)) + + neg_samples = torch.tensor(list(neg_samples), dtype=torch.long) + + print(f"Negative samples:\n{len(neg_samples)}") + + return pos_samples.t(), neg_samples.t() if __name__ == '__main__': # use Planetoid's CORA dataset as an example. + + # Check if GPU is available + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Load dataset and create train/validation Data objects dataset = Planetoid(root='/tmp/Cora', name='Cora') - data = dataset[0] + data = dataset[0].to(device) - # Create data loaders (use your own data loaders if you have custom datasets) - train_loader = DataLoader([data], batch_size=32, shuffle=True) - val_loader = DataLoader([data], batch_size=32, shuffle=False) + train_data = data # For this example, same data for training and validation (usually these should be different) + val_data = data + train_loader = DataLoader([train_data], batch_size=32, shuffle=True) + val_loader = DataLoader([val_data], batch_size=32, shuffle=False) - # Initialize model and optimizer - joint_model = JointModel(dataset.num_features) - optimizer = optim.Adam(joint_model.parameters(), lr=0.003) + joint_model = JointModel(dataset.num_features).to(device) + optimizer = optim.Adam(joint_model.parameters(), lr=0.01) - # Split edges for positive and negative samples using degree dominance - edge_index_pos, edge_index_neg = sample_edges(data.edge_index, num_samples=2 * data.num_edges) - edge_index_pos_val, edge_index_neg_val = sample_edges(data.edge_index, 100) + edge_index_pos_train, edge_index_neg_train = sample_edges(train_data.edge_index, train_data.num_edges) + edge_index_pos_val, edge_index_neg_val = sample_edges(val_data.edge_index, 100) + + # Make sure to move sampled edges to the same device as your model + edge_index_pos_train, edge_index_neg_train = edge_index_pos_train.to(device), edge_index_neg_train.to(device) + edge_index_pos_val, edge_index_neg_val = edge_index_pos_val.to(device), edge_index_neg_val.to(device) # Training Loop joint_model.train() for epoch in range(100): for batch in train_loader: optimizer.zero_grad() - - node_pred, link_pred_pos, link_pred_neg = joint_model(batch, edge_index_pos, edge_index_neg) - + + batch = batch.to(device) + + node_pred, link_pred_pos, link_pred_neg = joint_model(batch, edge_index_pos_train, edge_index_neg_train) + loss = joint_loss(node_pred, batch.y, link_pred_pos, link_pred_neg) - + loss.backward() optimizer.step() - - print(f'Epoch {epoch+1}, Loss: {loss.item()}') - - # Evaluation Loop (simplified) - joint_model.eval() - with torch.no_grad(): - correct = 0 - for batch in val_loader: - node_pred_val, link_pred_pos_val, link_pred_neg_val = joint_model(batch, edge_index_pos_val, edge_index_neg_val) - val_loss = joint_loss(node_pred_val, batch.y, link_pred_pos_val, link_pred_neg_val) - - # Node prediction metrics - pred = node_pred_val.argmax(dim=1) - node_correct = pred.eq(batch.y).sum().item() - node_accuracy = node_correct / len(val_loader.dataset) - + + print(f'Epoch {epoch+1}, Train Loss: {loss.item()}') + + # Validation Loop + joint_model.eval() + with torch.no_grad(): + for batch in val_loader: + batch = batch.to(device) + + node_pred_val, link_pred_pos_val, link_pred_neg_val = joint_model(batch, edge_index_pos_val, edge_index_neg_val) + val_loss = joint_loss(node_pred_val, batch.y, link_pred_pos_val, link_pred_neg_val) + + # Node prediction metrics + pred = node_pred_val.argmax(dim=1) + node_correct = pred.eq(batch.y).sum().item() + node_accuracy = node_correct / pred.shape[0] + # Link prediction metrics - link_labels = torch.cat([torch.ones(link_pred_pos_val.shape[0]), torch.zeros(link_pred_neg_val.shape[0])]) + link_labels = torch.cat([torch.ones(link_pred_pos_val.shape[0]), torch.zeros(link_pred_neg_val.shape[0])]).to(device) link_preds = torch.cat([link_pred_pos_val, link_pred_neg_val]) roc_score = roc_auc_score(link_labels.detach().cpu(), link_preds.detach().cpu()) - print(f'Epoch {epoch+1}, Validation Loss: {val_loss.item()}, Node Classification Accuracy: {node_accuracy}, Link Prediction ROC: {roc_score}') + print(f'-- Validation Loss: {val_loss.item()}, Node Classification Accuracy: {node_accuracy}, Link Prediction ROC: {roc_score}') + print() \ No newline at end of file From 81da009eb03fefc82a3a4ee6591fcdd967199b66 Mon Sep 17 00:00:00 2001 From: Alex Morrise Date: Mon, 25 Sep 2023 13:16:04 -0700 Subject: [PATCH 4/9] adds learnable node features --- graphistry/compute/gnn_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/graphistry/compute/gnn_utils.py b/graphistry/compute/gnn_utils.py index bb89ecf01..a9ed031f8 100644 --- a/graphistry/compute/gnn_utils.py +++ b/graphistry/compute/gnn_utils.py @@ -24,6 +24,23 @@ def forward(self, data): x = self.conv2(x, edge_index) return x +# Base GNN Model with Learnable Node Parameters +class BaseGNNLearnableNodeParams(nn.Module): + def __init__(self, num_nodes, in_channels, out_channels): + super(BaseGNNLearnableNodeParams, self).__init__() + self.num_nodes = num_nodes + self.node_features = nn.Parameter(torch.randn(num_nodes, in_channels)) + self.conv1 = GCNConv(in_channels, 128) + self.conv2 = GCNConv(128, out_channels) + + def forward(self, data): + x = self.node_features + edge_index = data.edge_index + x = self.conv1(x, edge_index) + x = F.relu(x) + x = self.conv2(x, edge_index) + return x + # Node Prediction Model class NodePredictionModel(nn.Module): def __init__(self, base_model, n_classes=10): From 21e2304bcfc6e9c25fffdf9fb9e11ab20fe5cb6c Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 28 Feb 2024 16:41:21 +0800 Subject: [PATCH 5/9] import numpy --- graphistry/compute/gnn_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphistry/compute/gnn_utils.py b/graphistry/compute/gnn_utils.py index a9ed031f8..f454d4383 100644 --- a/graphistry/compute/gnn_utils.py +++ b/graphistry/compute/gnn_utils.py @@ -6,6 +6,7 @@ from torch_geometric.data import DataLoader from torch_geometric.datasets import Planetoid import torch.optim as optim +import numpy as np from sklearn.metrics import roc_auc_score from collections import defaultdict @@ -208,4 +209,4 @@ def sample_edges(edge_index, num_samples, balance_degree=False): roc_score = roc_auc_score(link_labels.detach().cpu(), link_preds.detach().cpu()) print(f'-- Validation Loss: {val_loss.item()}, Node Classification Accuracy: {node_accuracy}, Link Prediction ROC: {roc_score}') - print() \ No newline at end of file + print() From 7f8e1e4d6af10e2630f1bed33560a6d70b5aefd2 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 28 Feb 2024 16:44:37 +0800 Subject: [PATCH 6/9] lint --- graphistry/compute/gnn_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/graphistry/compute/gnn_utils.py b/graphistry/compute/gnn_utils.py index f454d4383..d7dc5780a 100644 --- a/graphistry/compute/gnn_utils.py +++ b/graphistry/compute/gnn_utils.py @@ -47,7 +47,7 @@ class NodePredictionModel(nn.Module): def __init__(self, base_model, n_classes=10): super(NodePredictionModel, self).__init__() self.base_model = base_model - self.classifier = nn.Linear(base_model.conv2.out_channels, n_classes) #classes for node prediction + self.classifier = nn.Linear(base_model.conv2.out_channels, n_classes) # classes for node prediction def forward(self, data): x = self.base_model(data) @@ -92,8 +92,7 @@ def joint_loss(node_pred, node_labels, link_pred_pos, link_pred_neg): link_labels_pos = torch.ones(link_pred_pos.shape[0]).to(link_pred_pos.device) link_labels_neg = torch.zeros(link_pred_neg.shape[0]).to(link_pred_neg.device) - link_loss = F.binary_cross_entropy_with_logits(link_pred_pos, link_labels_pos) + \ - F.binary_cross_entropy_with_logits(link_pred_neg, link_labels_neg) + link_loss = F.binary_cross_entropy_with_logits(link_pred_pos, link_labels_pos) + F.binary_cross_entropy_with_logits(link_pred_neg, link_labels_neg) return node_loss + link_loss @@ -137,7 +136,7 @@ def sample_edges(edge_index, num_samples, balance_degree=False): while len(neg_samples) < num_samples: u, v = np.random.randint(0, num_nodes, 2) if u != v and (u, v) not in adjacency_set and (v, u) not in adjacency_set: - print(f"Adding negative sample: ({u}, {v})") if u+v %10 ==0 else None + print(f"Adding negative sample: ({u}, {v})") if u + v %10 == 0 else None neg_samples.add((u, v)) neg_samples = torch.tensor(list(neg_samples), dtype=torch.long) @@ -146,6 +145,7 @@ def sample_edges(edge_index, num_samples, balance_degree=False): return pos_samples.t(), neg_samples.t() + if __name__ == '__main__': # use Planetoid's CORA dataset as an example. From f9d3adae2c3bcb2a15940540fc11cbd182d5434f Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 28 Feb 2024 16:49:29 +0800 Subject: [PATCH 7/9] lint --- graphistry/compute/gnn_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphistry/compute/gnn_utils.py b/graphistry/compute/gnn_utils.py index d7dc5780a..834516513 100644 --- a/graphistry/compute/gnn_utils.py +++ b/graphistry/compute/gnn_utils.py @@ -136,7 +136,7 @@ def sample_edges(edge_index, num_samples, balance_degree=False): while len(neg_samples) < num_samples: u, v = np.random.randint(0, num_nodes, 2) if u != v and (u, v) not in adjacency_set and (v, u) not in adjacency_set: - print(f"Adding negative sample: ({u}, {v})") if u + v %10 == 0 else None + print(f"Adding negative sample: ({u}, {v})") if u + v % 10 == 0 else None neg_samples.add((u, v)) neg_samples = torch.tensor(list(neg_samples), dtype=torch.long) From 0d8cc5d8f8a0642f03c238d9b887dadd72c9e972 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 28 Feb 2024 16:55:40 +0800 Subject: [PATCH 8/9] add torch_geometric to setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c81db1b09..976456cc8 100755 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def unique_flatten_dict(d): 'umap-learn': ['umap-learn', 'dirty-cat==0.2.0', 'scikit-learn>=1.0'], } # https://github.com/facebookresearch/faiss/issues/1589 for faiss-cpu 1.6.1, #'setuptools==67.4.0' removed -base_extras_heavy['ai'] = base_extras_heavy['umap-learn'] + ['scipy', 'dgl', 'torch<2', 'sentence-transformers', 'faiss-cpu', 'joblib'] +base_extras_heavy['ai'] = base_extras_heavy['umap-learn'] + ['scipy', 'dgl', 'torch<2', 'torch-geometric', 'sentence-transformers', 'faiss-cpu', 'joblib'] base_extras = {**base_extras_light, **base_extras_heavy} From 4d057cdf3856804019a1c819eb0e3ec3d1d7c592 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 28 Feb 2024 17:22:46 +0800 Subject: [PATCH 9/9] add torch_geometric to mypy.ini --- mypy.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mypy.ini b/mypy.ini index 898e00114..c1a9c228e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -83,6 +83,9 @@ ignore_missing_imports = True [mypy-torch.*] ignore_missing_imports = True +[mypy-torch_geometric.*] +ignore_missing_imports = True + [mypy-umap.*] ignore_missing_imports = True