-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGraphGdel_main.py
More file actions
484 lines (417 loc) · 21.3 KB
/
GraphGdel_main.py
File metadata and controls
484 lines (417 loc) · 21.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
import os
import argparse
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import csv
import random
import time
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected, negative_sampling
from utils.read_gene_sequences import read_gene_sequences
from utils.read_metabolite_smiles import read_metabolite_smiles
from utils.read_gene_metabolite_relationships import read_gene_metabolite_relationships
from utils.create_gene_metabolite_dataset import create_gene_metabolite_dataset
from utils.build_metabolite_graph import load_graph_data as load_graph_data_from_builder
from utils.evaluations import *
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from model.GeneLSTM import GeneLSTM
from model.MetaLSTM import MetaLSTM
from model.GCN import GCN
from model.CombinedModel import CombinedModel
###################################################
def parse_args():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(description="Run Gene Deletion Strategy Prediction with GraphGdel")
parser.add_argument('--CBM', type=str, default='e_coli_core',
help="The CBM model to use (e.g., iML1515, iMM904, e_coli_core)")
parser.add_argument('--use_cpu', type=int, choices=[0, 1], default=0,
help="Set to 1 to force CPU, 0 to use CUDA if available")
parser.add_argument('--split_seed', type=int, default=None,
help="Random seed for train/val split (for t-test collection)")
parser.add_argument('--output_metrics_csv', type=str, default=None,
help="If set, append one row (CBM, method, split_seed, accuracy, macro_precision, macro_recall, macro_f1, macro_auc) to this CSV")
parser.add_argument('--train', type=int, choices=[0, 1], default=0,
help="Set to 1 to run pre-training and save checkpoint to train_disc (architecture matches current graph)")
parser.add_argument('--train_epochs', type=int, default=50,
help="Number of epochs for pre-training when --train 1")
return parser.parse_args()
def main():
# Parse arguments
args = parse_args()
CBM = args.CBM
use_cpu = args.use_cpu
print(f"Using CBM: {CBM}")
# Set the device
if use_cpu == 1:
device = torch.device("cpu")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Set the bound for fixed_gene exclusion
bound = 0.9
# Data partitioning: repeated stratified 10-fold CV on full dataset; 5 repeats x 10 folds (seeds 1001..1050)
N_FOLDS = 10
FIRST_SPLIT_SEED = 1001 # seeds 1001..1050 = 5 repeats x 10 folds
# Hyperparameters
vocab_embedding_dim = 32 # The size of the embedding vector for each token word
hidden_dim = 64 # The number of units in the LSTM and GCN hidden layer as output
gcn_hidden_channels = 64
batch_size = 8
epochs = 1
train_epochs = args.train_epochs
# Set the random seed for reproducibility
random_seed = 1001
random.seed(random_seed) # For Python's built-in random module
np.random.seed(random_seed) # For numpy-based randomness
# Get the directory where the script is located
script_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
# Define paths relative to the script's location
gene_folder_path = os.path.join(script_dir, 'Data', CBM, 'Gene_AA')
metabolite_smiles_path = os.path.join(script_dir, 'Data', CBM, 'Target metabolites_smile.csv')
relationship_folder_path = os.path.join(script_dir, 'Data', CBM, 'Label')
graph_path = os.path.join(script_dir, 'Data', CBM, 'SMILES_node_feature_final.csv')
output_csv_path = os.path.join(script_dir, 'Data', CBM, 'Results', 'all_metabolites_predictions_temp.csv')
train_disc_dir = os.path.join(script_dir, 'train_disc')
checkpoint_path = os.path.join(train_disc_dir, f'new_code_model_{CBM}.sav')
###################################################
# Reading data (exactly like source code)
print("Loading data...")
gene_names, gene_sequences = read_gene_sequences(gene_folder_path)
metabolite_names, smiles_features = read_metabolite_smiles(metabolite_smiles_path)
relationships = read_gene_metabolite_relationships(relationship_folder_path)
# Tokenize data exactly like source code
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
# Tokenizing gene sequences (exactly like source code)
gene_tokenizer = Tokenizer(char_level=True)
gene_tokenizer.fit_on_texts(gene_sequences)
gene_vocab_size = len(gene_tokenizer.word_index) + 1
max_len = max(len(seq) for seq in gene_sequences)
gene_sequences = pad_sequences(gene_tokenizer.texts_to_sequences(gene_sequences), maxlen=max_len)
# Tokenizing SMILES sequences (exactly like source code)
smiles_tokenizer = Tokenizer(char_level=True)
smiles_tokenizer.fit_on_texts(smiles_features)
smiles_vocab_size = len(smiles_tokenizer.word_index) + 1
smiles_max_len = max(len(smiles) for smiles in smiles_features)
smiles_features = pad_sequences(smiles_tokenizer.texts_to_sequences(smiles_features), maxlen=smiles_max_len)
# Extract unique metabolites (needed for graph builder when JSON is used)
unique_metabolites = list(relationships.keys())
data_dir = os.path.join(script_dir, 'Data')
graph_data, num_nodes, metabolite_names_order = load_graph_data_from_builder(
graph_path, CBM, data_dir=data_dir,
unique_metabolites=unique_metabolites,
top_k=10,
nodes_to_restore=["pyr_c"],
)
graph_data = graph_data.to(device)
if metabolite_names_order is not None:
graph_node_index_map = {name: i for i, name in enumerate(metabolite_names_order)}
print("Graph loaded from CBM JSON (manuscript pipeline).")
else:
graph_node_index_map = None
print("Graph: fallback random (no CBM JSON found).")
# Full dataset used for repeated stratified 10-fold CV (no separate test set)
train_val_pool = list(unique_metabolites)
# Stratification label per metabolite (majority deletion class) for 10-fold stratified CV
metabolite_labels = np.array([
1 if np.mean(relationships[meta][1]) >= 0.5 else 0
for meta in train_val_pool
])
# split_seed: 1001..1050 -> 5 repeats x 10 folds; default 1001 = first fold
split_seed = getattr(args, 'split_seed', None)
if split_seed is None:
split_seed = FIRST_SPLIT_SEED
args.split_seed = split_seed
repeat_idx = (split_seed - FIRST_SPLIT_SEED) // N_FOLDS
fold_idx = (split_seed - FIRST_SPLIT_SEED) % N_FOLDS
kf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=repeat_idx)
X_idx = np.arange(len(train_val_pool))
folds = list(kf.split(X_idx, metabolite_labels))
train_idx, val_idx = folds[fold_idx]
train_metabolites = [train_val_pool[i] for i in train_idx]
val_metabolites = [train_val_pool[i] for i in val_idx]
# Print which metabolites are used for training and validation (this fold)
print("Metabolites used for training (this fold):")
print(train_metabolites[:5], "..." if len(train_metabolites) > 5 else "")
print("\nMetabolites used for validation (this fold):")
print(val_metabolites[:5], "..." if len(val_metabolites) > 5 else "")
# Initialize counters for gene relationships (for all data)
gene_count_all = {}
gene_total_all = {}
# Calculate gene deletion statistics for all data (train + val)
for meta_name in unique_metabolites: # Loop through all metabolites, not just train
gene_names, labels = relationships[meta_name]
for gene_name, label in zip(gene_names, labels):
if gene_name not in gene_count_all:
gene_count_all[gene_name] = 0
gene_total_all[gene_name] = 0
gene_count_all[gene_name] += label
gene_total_all[gene_name] += 1
# Calculate deletion ratios
gene_deletion_ratios = {}
for gene_name in gene_total_all:
if gene_total_all[gene_name] > 0:
gene_deletion_ratios[gene_name] = gene_count_all[gene_name] / gene_total_all[gene_name]
# Filter genes based on deletion ratio
filtered_genes = [gene for gene, ratio in gene_deletion_ratios.items() if ratio <= bound]
print(f"Number of genes after filtering: {len(filtered_genes)}")
# Create genes_to_exclude set (same as notebook)
genes_to_exclude = set(gene_names) - set(filtered_genes)
print(f"Genes to exclude during evaluation: {len(genes_to_exclude)}")
# Create datasets
print("Creating datasets...")
# Custom collate function for PyTorch Geometric Data objects
def custom_collate(batch):
gene_seqs = torch.stack([item[0] for item in batch])
smiles_seqs = torch.stack([item[1] for item in batch])
labels = torch.stack([item[2] for item in batch])
node_indices = torch.stack([item[3] for item in batch])
# For graph data, we'll use the first one since they're all the same
graph = batch[0][4]
return gene_seqs, smiles_seqs, labels, node_indices, graph
# Vocabulary sizes already calculated above (exactly like source code)
# Create datasets (data already tokenized like source code)
train_dataset = create_gene_metabolite_dataset(
train_metabolites, relationships, gene_sequences, smiles_features,
filtered_genes, graph_data, device,
gene_names=gene_names,
metabolite_names=metabolite_names,
graph_node_index_map=graph_node_index_map,
)
val_dataset = create_gene_metabolite_dataset(
val_metabolites, relationships, gene_sequences, smiles_features,
filtered_genes, graph_data, device,
gene_names=gene_names,
metabolite_names=metabolite_names,
graph_node_index_map=graph_node_index_map,
)
# Custom collate function for PyTorch Geometric Data objects
def custom_collate(batch):
gene_seqs = torch.stack([item[0] for item in batch])
smiles_seqs = torch.stack([item[1] for item in batch])
labels = torch.stack([item[2] for item in batch])
node_indices = torch.stack([item[3] for item in batch])
# For graph data, we'll use the first one since they're all the same
graph = batch[0][4]
return gene_seqs, smiles_seqs, labels, node_indices, graph
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)
# Initialize models
print("Initializing models...")
gene_lstm = GeneLSTM(vocab_size=gene_vocab_size,
vocab_embedding_dim=vocab_embedding_dim,
hidden_dim=hidden_dim)
smiles_lstm = MetaLSTM(vocab_size=smiles_vocab_size,
vocab_embedding_dim=vocab_embedding_dim,
hidden_dim=hidden_dim)
gcn_model = GCN(in_channels=graph_data.x.size(-1), # Use actual graph feature dimension
hidden_channels=gcn_hidden_channels,
out_channels=hidden_dim,
num_nodes=num_nodes)
model = CombinedModel(gene_lstm, smiles_lstm, gcn_model, hidden_dim)
model = model.to(device)
# Load pretrained model (or run pre-training to produce a matching checkpoint)
if args.train:
print(f"Pre-training for {train_epochs} epochs and saving checkpoint to {checkpoint_path}")
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for epoch in range(train_epochs):
running_loss = 0.0
for batch in train_loader:
gene_seqs, smiles_seqs, labels, node_indices, graph = batch
gene_seqs = gene_seqs.to(device)
smiles_seqs = smiles_seqs.to(device)
labels = labels.to(device).unsqueeze(1)
node_indices = node_indices.to(device)
graph = graph.to(device)
optimizer.zero_grad()
out, _, _ = model(gene_seqs, smiles_seqs, graph, node_indices)
loss = criterion(out, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if (epoch + 1) % 10 == 0 or epoch == 0:
print(f" Epoch {epoch+1}/{train_epochs} loss: {running_loss/len(train_loader):.4f}")
os.makedirs(train_disc_dir, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)
print(f"Saved checkpoint to {checkpoint_path} (matches current graph: {num_nodes} nodes).")
else:
try:
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint, strict=False)
print(f"Loaded pretrained model {checkpoint_path}")
except Exception as e:
print(f"Warning: Could not load pretrained model due to architecture mismatch: {e}")
print("Running evaluation with randomly initialized weights (demonstration mode)")
print("To create a matching checkpoint, run with --train 1 (e.g. quick_run with train mode).")
# Evaluation using the same method as source code
print("Evaluating model...")
model.eval()
# Copy exact evaluation function from notebook
def find_best_threshold(y_true, y_probs):
thresholds = np.linspace(0.01, 0.99, 99)
best_threshold = 0.5
best_f1 = 0.0
for thresh in thresholds:
y_pred = (y_probs >= thresh).astype(int)
f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
if f1 > best_f1:
best_f1 = f1
best_threshold = thresh
return best_threshold, best_f1
def evaluate_per_metabolite(
model,
metabolite_names,
gene_names,
gene_sequences,
smiles_features,
relationships,
graph,
device,
genes_to_exclude=set(),
gene_name_to_index=None,
threshold_search=True,
val_metabolites=None,
graph_node_index_map=None,
):
model.eval()
graph = graph.to(device)
if graph_node_index_map is not None:
metabolite_to_node_idx = graph_node_index_map
else:
metabolite_to_node_idx = {met: idx for idx, met in enumerate(metabolite_names)}
# Use exact same data processing as source code (no tokenization)
gene_seq_dict = {
gene: torch.tensor(gene_sequences[gene_name_to_index[gene]], dtype=torch.float32, device=device).unsqueeze(0)
for gene in gene_names if gene in gene_name_to_index
}
smiles_dict = {
met: torch.tensor(smiles_features[idx], dtype=torch.float32, device=device).unsqueeze(0)
for idx, met in enumerate(metabolite_names)
}
per_metabolite_info = {}
for metabolite in metabolite_names:
if metabolite not in relationships:
continue
if metabolite not in metabolite_to_node_idx:
continue
gene_list, label_list = relationships[metabolite]
filtered = [(g, l) for g, l in zip(gene_list, label_list)
if g not in genes_to_exclude and g in gene_seq_dict]
if len(filtered) == 0:
continue
genes_filtered, labels_filtered = zip(*filtered)
labels_np = np.array(labels_filtered)
node_idx = metabolite_to_node_idx[metabolite]
node_idx_tensor = torch.tensor([node_idx], dtype=torch.long, device=device)
preds_probs = []
for gene in genes_filtered:
gene_seq = gene_seq_dict[gene]
smiles_seq = smiles_dict[metabolite]
with torch.no_grad():
output, _, _ = model(gene_seq, smiles_seq, graph, node_idx_tensor)
pred_prob = output.squeeze().cpu().item()
preds_probs.append(pred_prob)
preds_probs_np = np.array(preds_probs)
if threshold_search:
best_thresh, best_f1 = find_best_threshold(labels_np, preds_probs_np)
else:
best_thresh = 0.5
best_f1 = f1_score(labels_np, (preds_probs_np >= best_thresh).astype(int), average='macro', zero_division=0)
pred_labels = (preds_probs_np >= best_thresh).astype(int)
per_metabolite_info[metabolite] = {
'true': labels_np,
'pred_probs': preds_probs_np,
'pred_labels': pred_labels,
'genes': genes_filtered,
'best_threshold': best_thresh,
'best_f1': best_f1
}
print(f"Metabolite: {metabolite}")
print(f" Number of Non-fixed Genes: {len(labels_np)}")
print(f" Best Threshold: {best_thresh:.3f} with F1: {best_f1:.3f}")
print(f" Accuracy: {accuracy_score(labels_np, pred_labels)*100:.2f}%")
print(f" Precision: {precision_score(labels_np, pred_labels, average='macro', zero_division=0)*100:.2f}%")
print(f" Recall: {recall_score(labels_np, pred_labels, average='macro', zero_division=0)*100:.2f}%")
print(f" F1 Score: {best_f1*100:.2f}%")
print(f" True labels: {np.array(labels_np[:10])}")
print(f" Pred probabilities: {np.round(preds_probs_np[:10], 2)}")
print(f" Pred labels: {np.array(pred_labels[:10])}\n")
# --- Summary of metrics across metabolites ---
acc_list = []
prec_list = []
rec_list = []
f1_list = []
for metabolite, info in per_metabolite_info.items():
if val_metabolites is not None and metabolite not in val_metabolites:
continue
labels_np = info['true']
pred_labels = info['pred_labels']
acc_list.append(accuracy_score(labels_np, pred_labels)*100)
prec_list.append(precision_score(labels_np, pred_labels, average='macro', zero_division=0)*100)
rec_list.append(recall_score(labels_np, pred_labels, average='macro', zero_division=0)*100)
f1_list.append(f1_score(labels_np, pred_labels, average='macro', zero_division=0)*100)
print("=== Overall Performance Summary ===")
mean_acc = np.mean(acc_list) if acc_list else 0.0
mean_prec = np.mean(prec_list) if prec_list else 0.0
mean_rec = np.mean(rec_list) if rec_list else 0.0
mean_f1 = np.mean(f1_list) if f1_list else 0.0
print(f"Overall Accuracy: {mean_acc:.2f}%")
print(f"Macro-Averaged Precision: {mean_prec:.2f}%")
print(f"Macro-Averaged Recall: {mean_rec:.2f}%")
print(f"Macro-Averaged F1 Score: {mean_f1:.2f}%")
all_true, all_probs = [], []
for metabolite, info in per_metabolite_info.items():
if val_metabolites is not None and metabolite not in val_metabolites:
continue
all_true.extend(info['true'].tolist())
all_probs.extend(info['pred_probs'].tolist())
try:
from sklearn.metrics import roc_auc_score
mean_auc = roc_auc_score(all_true, all_probs) * 100 if all_true else 0.0
except Exception:
mean_auc = 0.0
return (mean_acc, mean_prec, mean_rec, mean_f1, mean_auc)
# Run the exact notebook evaluation
val_for_csv = val_metabolites if getattr(args, 'output_metrics_csv', None) else None
summary = evaluate_per_metabolite(
model=model,
metabolite_names=metabolite_names,
gene_names=gene_names,
gene_sequences=gene_sequences,
smiles_features=smiles_features,
relationships=relationships,
graph=graph_data,
device=device,
genes_to_exclude=set(genes_to_exclude),
gene_name_to_index={gene: idx for idx, gene in enumerate(gene_names)},
val_metabolites=val_for_csv,
graph_node_index_map=graph_node_index_map,
)
if args.output_metrics_csv and summary is not None:
import csv
write_header = not os.path.exists(args.output_metrics_csv)
with open(args.output_metrics_csv, 'a', newline='') as f:
w = csv.writer(f)
if write_header:
w.writerow(['CBM', 'method', 'split_seed', 'accuracy', 'macro_precision', 'macro_recall', 'macro_f1', 'macro_auc'])
w.writerow([CBM, 'Proposed', args.split_seed, round(summary[0], 4), round(summary[1], 4), round(summary[2], 4), round(summary[3], 4), round(summary[4], 4)])
# Save results
print(f"Saving results to {output_csv_path}")
# Implementation for saving results would go here
print("GraphGdel training and evaluation completed!")
if __name__ == "__main__":
main()