Building a High-Recall Toxicity Screener

A journal of pivots: Solving data desyncs, disregarding theoretical metrics, and identifying structural blind spots in Graph Neural Networks.

Production Scan: Distributed Ray Serve Inference

00 // The Problem: The Cost of a Miss

In pharmaceutical auditing, a False Positive is an annoyance (a safe drug is flagged for review), but a False Negative is a disaster (a toxic drug is marked safe). Standard machine learning models optimize for overall accuracy, often sacrificing safety by missing rare but deadly edge cases.

The Engineering Objective: Architect a "High-Recall" Graph Neural Network (GNN) that aggressively flags potential toxins. The goal was to push recall above 0.95, keeping missed toxins to an absolute minimum, while integrating "White-Box" explainability (Atomic Heatmaps) to justify every flag to human auditors.

01 // The Mathematical Core

The screener utilizes neighborhood aggregation to update atomic feature vectors based on local chemical environments:

\[x_i^{(l+1)} = \sigma \left( \Theta^\top \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\hat{d}_j \hat{d}_i}} x_j^{(l)} \right)\]

The same rule in matrix form: symmetric (spectral) normalization of the adjacency keeps updates stable across molecules of different sizes:

\[X^{(l+1)} = \sigma \left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} X^{(l)} W^{(l)} \right)\]
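To make the normalization concrete: with self-loops, a terminal atom \(i\) with a single bond has \(\hat{d}_i = 2\), and a neighbor \(j\) with three bonds has \(\hat{d}_j = 4\), so the message from \(j\) is scaled by

\[\frac{1}{\sqrt{\hat{d}_j \hat{d}_i}} = \frac{1}{\sqrt{4 \cdot 2}} \approx 0.354\]

which keeps the magnitude of updates comparable whether an atom sits on a short chain or inside a dense ring system.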

The final fixed-size representation concatenates global mean and global max pooling, capturing both general structural properties and localized "hotspots":

\[X_{graph} = \left[ \frac{1}{|V|} \sum_{i \in V} x_i^{(L)} \parallel \max_{i \in V} x_i^{(L)} \right]\]

Mathematical Legend

Node-Level (Eq 1)

  • x_i^{(l)} : Feature vector of atom i at layer l.
  • σ : Non-linear activation function (ReLU).
  • Θ : Learnable filter parameters (weights).
  • 𝒩(i) : Neighbors of atom i (including self-loop).
  • dΜ‚ : Degree of node (used for normalization).

Matrix & Graph (Eq 2 & 3)

  • Ã : Adjacency matrix with self-loops (A + I).
  • D̃ : Diagonal degree matrix of Ã.
  • V : Set of all atoms in the molecule.
  • || : Concatenation (joining Mean & Max pools).
  • X_{graph} : Final fixed-length molecular embedding.
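The three equations map directly onto PyTorch Geometric primitives. A minimal sketch of that correspondence (two convolution layers here for brevity; the production model in the sections below uses three):

import torch
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool

class MiniToxGCN(torch.nn.Module):
    """Minimal sketch: Eq 1/2 via GCNConv, Eq 3 via mean/max pooling."""
    def __init__(self, num_features=9, hidden=64):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden)   # symmetric-normalized aggregation (Eq 1 & 2)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = torch.nn.Linear(hidden * 2, 1)    # classifier over [mean || max] (Eq 3)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()         # sigma(D^-1/2 A~ D^-1/2 X W)
        x = self.conv2(x, edge_index)
        x_graph = torch.cat(
            [global_mean_pool(x, batch),             # global structural average
             global_max_pool(x, batch)],             # localized "hotspot" signal
            dim=1)
        return self.lin(x_graph)                     # one toxicity logit per molecule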

02 // Data Acquisition & Integrity

Establishing a reproducible pipeline began with automated data acquisition. I implemented a script to download the official Tox21 dataset and verify the raw CSV structures before processing.

00_download_csv.py
import requests
import os
import gzip
import shutil

# 1. Setup Paths
url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz"
output_dir = "./data/raw"
final_path = os.path.join(output_dir, "tox21.csv")
compressed_path = final_path + ".gz"
os.makedirs(output_dir, exist_ok=True)

# 2. Download
print(f"⬇️ Downloading Tox21 dataset from {url}...")
response = requests.get(url, stream=True)
if response.status_code == 200:
    with open(compressed_path, 'wb') as f:
        f.write(response.content)
    print("   Download complete.")
else:
    print(f"❌ Failed to download. Status code: {response.status_code}")
    exit()

# 3. Extract
print("📦 Extracting CSV...")
try:
    with gzip.open(compressed_path, 'rb') as f_in:
        with open(final_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    print(f"✅ Success! File saved to: {final_path}")
    # Optional: Clean up the .gz file
    os.remove(compressed_path)
except Exception as e:
    print(f"❌ Error during extraction: {e}")
Download Output
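The download script above fetches and extracts the file; the structural check mentioned in the paragraph can be as small as the following sketch, which only asserts the columns the ingestion step actually relies on ('smiles' and the target task column 'NR-AhR'):

import pandas as pd

# Sketch: sanity-check the raw CSV before running the ingestion pipeline.
df = pd.read_csv("./data/raw/tox21.csv")

required = ["smiles", "NR-AhR"]  # columns consumed downstream by 01_ingest_csv.py
missing = [c for c in required if c not in df.columns]
assert not missing, f"Missing expected columns: {missing}"

print(f"Rows: {len(df)} | Columns: {len(df.columns)}")
print(df["NR-AhR"].value_counts(dropna=False))  # label balance, including unlabeled (NaN) rows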

03 // Solving the Metadata Desync

Standard GNN pipelines often lose track of molecular SMILES strings during batching. I re-engineered the ingestion pipeline to attach metadata directly to the graph objects, ensuring every prediction (success or failure) is traceable.

01_ingest_csv.py
import pandas as pd
import torch
import numpy as np
from rdkit import Chem
from torch_geometric.data import Data
import os

# --- CONFIGURATION ---
RAW_PATH = "./data/raw/tox21.csv"
PROCESSED_PATH = "./data/processed/tox21_graphs.pt"

# --- HELPER FUNCTIONS ---
def one_hot_encoding(x, permitted_list):
    if x not in permitted_list:
        x = permitted_list[-1]
    binary_encoding = [int(x == possible_value) for possible_value in permitted_list]
    return binary_encoding

def get_atom_features(atom):
    # 9 Features per atom
    permitted_atoms = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'I', 'Unknown']
    atom_type = one_hot_encoding(atom.GetSymbol(), permitted_atoms)
    return torch.tensor(atom_type, dtype=torch.float)

def smiles_to_graph(smiles, label, original_index):
    """Converts a SMILES string to a PyTorch Geometric Graph with Metadata."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return None

    # 1. Node Features
    features = []
    for atom in mol.GetAtoms():
        features.append(get_atom_features(atom))
    x = torch.stack(features)

    # 2. Edges (Connectivity)
    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i])  # Undirected
    if not edge_indices:
        return None  # Skip single atoms
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()

    # 3. Label
    y = torch.tensor([float(label)], dtype=torch.float)

    # 4. Create Data Object
    data = Data(x=x, edge_index=edge_index, y=y)

    # --- 🚨 THE FIX: ATTACH METADATA DIRECTLY TO THE GRAPH ---
    data.smiles = smiles
    data.original_index = original_index
    return data

def run_ingest():
    print("🚀 Starting Ingestion Pipeline...")
    if not os.path.exists(RAW_PATH):
        print(f"❌ Error: {RAW_PATH} not found. Run 00_download_csv.py first!")
        return

    df = pd.read_csv(RAW_PATH)
    print(f"📄 Loaded CSV with {len(df)} rows.")

    # Select the specific task (e.g., NR-AhR)
    # If the row is NaN for this task, we skip it
    target_task = 'NR-AhR'

    data_list = []
    skipped = 0
    for index, row in df.iterrows():
        smiles = row['smiles']
        label = row[target_task]

        # Skip if missing label or bad smile
        if pd.isna(label) or pd.isna(smiles):
            skipped += 1
            continue

        graph = smiles_to_graph(smiles, label, index)
        if graph:
            data_list.append(graph)
        else:
            skipped += 1

        if index % 1000 == 0:
            print(f"   Processed {index} rows...")

    # Save
    os.makedirs(os.path.dirname(PROCESSED_PATH), exist_ok=True)
    torch.save(data_list, PROCESSED_PATH)
    print(f"✅ Saved {len(data_list)} graphs to {PROCESSED_PATH}")
    print(f"🗑️ Skipped {skipped} invalid/empty rows.")

if __name__ == "__main__":
    run_ingest()
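A quick way to confirm the fix works end-to-end: PyTorch Geometric collates custom non-tensor attributes such as data.smiles into per-graph lists during batching, so every item in a batch can still be traced back to its source row. A minimal check:

import torch
from torch_geometric.loader import DataLoader

# Sketch: verify that the attached metadata survives batching.
graphs = torch.load("./data/processed/tox21_graphs.pt", weights_only=False)
loader = DataLoader(graphs[:128], batch_size=32, shuffle=False)

batch = next(iter(loader))
print(type(batch.smiles), len(batch.smiles))       # list of 32 SMILES strings
print(batch.smiles[0], batch.original_index[0])    # traceable back to the raw CSV row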

04 // 36-Model Grid Search

I executed a sweep across 36 unique configurations (3 learning rates × 4 hidden-channel widths × 3 positive-class loss weights) to identify the optimal training baseline.

02_grid_search.py
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch.nn import Linear
from sklearn.metrics import roc_auc_score
import mlflow
import itertools
import numpy as np
import os
import warnings

# --- 🚨 SILENCE WARNINGS ---
# 1. Suppresses the "torch-scatter not found" and WinError messages from Python
warnings.filterwarnings("ignore", category=UserWarning)
# 2. Ensures child processes (like DataLoader workers) also ignore warnings
os.environ['PYTHONWARNINGS'] = 'ignore'
# 3. Optional: Specifically for Windows Ray Serve issues if you encounter them later
os.environ["TORCH_GEOMETRIC_OFFLINE"] = "1"

# --- CONFIGURATION ---
EPOCHS_PER_RUN = 30  # Keep this reasonable so the sweep finishes today
BATCH_SIZE = 64
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- THE GRID ---
# We will test EVERY combination of these
PARAM_GRID = {
    "lr": [0.01, 0.001, 0.0001],
    "hidden_channels": [32, 64, 128, 256],
    "pos_weight": [10.0, 20.0, 30.0]
}

# --- MODEL DEFINITION (Flexible) ---
class ToxicityGCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels):
        super(ToxicityGCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels * 2, 1)

    def forward(self, x, edge_index, batch):
        x = x.float()
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)
        return self.lin(x)

def train_one_configuration(lr, hidden, weight, train_loader, test_loader):
    """Trains one specific model configuration and returns the Best AUC."""
    run_name = f"LR={lr}_Hidden={hidden}_W={weight}"
    print(f"\n🧪 Starting Run: {run_name}")

    with mlflow.start_run(run_name=run_name):
        # 1. Log Params
        mlflow.log_param("lr", lr)
        mlflow.log_param("hidden_channels", hidden)
        mlflow.log_param("pos_weight", weight)

        # 2. Setup
        model = ToxicityGCN(num_features=9, hidden_channels=hidden).to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        pos_weight_tensor = torch.tensor([weight]).to(DEVICE)
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

        best_run_auc = 0.0

        # 3. Train Loop
        for epoch in range(1, EPOCHS_PER_RUN + 1):
            model.train()
            for data in train_loader:
                data = data.to(DEVICE)
                optimizer.zero_grad()
                out = model(data.x, data.edge_index, data.batch)
                loss = criterion(out, data.y.view(-1, 1))
                loss.backward()
                optimizer.step()

            # 4. Test Loop
            model.eval()
            y_true, y_pred = [], []
            with torch.no_grad():
                for data in test_loader:
                    data = data.to(DEVICE)
                    out = model(data.x, data.edge_index, data.batch)
                    y_true.extend(data.y.cpu().numpy())
                    y_pred.extend(torch.sigmoid(out).cpu().numpy())
            try:
                auc = roc_auc_score(y_true, y_pred)
            except:
                auc = 0.5

            # Log Metrics
            mlflow.log_metric("auc", auc, step=epoch)

            # Track Best
            if auc > best_run_auc:
                best_run_auc = auc

        print(f"   🏁 Finished {run_name} | Best AUC: {best_run_auc:.4f}")
        return best_run_auc

def run_grid_search():
    print("📂 Loading Data...")
    dataset_list = torch.load("./data/processed/tox21_graphs.pt", weights_only=False)
    dataset_list = [d for d in dataset_list if d.num_nodes > 0]

    # Split
    split = int(len(dataset_list) * 0.8)
    train_loader = DataLoader(dataset_list[:split], batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(dataset_list[split:], batch_size=BATCH_SIZE)

    # --- GENERATE COMBINATIONS ---
    keys, values = zip(*PARAM_GRID.items())
    combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
    print(f"🚀 Launching Grid Search over {len(combinations)} configurations...")

    mlflow.set_experiment("Tox21_Grid_Search")

    best_overall_auc = 0.0
    best_config = {}
    for i, config in enumerate(combinations):
        print(f"--- Progress: {i+1}/{len(combinations)} ---")
        auc = train_one_configuration(
            lr=config['lr'],
            hidden=config['hidden_channels'],
            weight=config['pos_weight'],
            train_loader=train_loader,
            test_loader=test_loader
        )
        if auc > best_overall_auc:
            best_overall_auc = auc
            best_config = config

    print("\n" + "="*50)
    print(f"🏆 GRID SEARCH COMPLETE")
    print(f"🥇 Best AUC: {best_overall_auc:.4f}")
    print(f"🥇 Best Params: {best_config}")
    print("="*50)

if __name__ == "__main__":
    run_grid_search()
Grid Search Output

05 // Training: Precision through Optimization

Initially, a learning rate of \(0.01\) caused the model to oscillate, requiring a massive pos_weight to catch toxic signals. By slowing the learning rate to \(0.001\), the model achieved smoother convergence, correctly flagging toxic Dinitrotoluene while identifying Aspirin as safe using a standard \(10.0\) weight.
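The pos_weight values in the grid (10-30) were hand-picked. A common starting heuristic, shown here only as a sketch and not what the sweep actually used, is the negative-to-positive label ratio of the training set:

import torch

# Sketch: derive a starting pos_weight from class imbalance (a heuristic, not the tuned value).
graphs = torch.load("./data/processed/tox21_graphs.pt", weights_only=False)
labels = torch.tensor([g.y.item() for g in graphs])
n_pos = int((labels == 1).sum())
n_neg = int((labels == 0).sum())

pos_weight = torch.tensor([n_neg / max(n_pos, 1)])
print(f"Positives: {n_pos} | Negatives: {n_neg} | suggested pos_weight ~ {pos_weight.item():.1f}")

# Same loss construction as in 03_train_mlflow.py, just with the derived weight.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)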

03_train_mlflow.py
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch.nn import Linear
from sklearn.metrics import roc_auc_score
import mlflow
import os

# --- CONFIGURATION ---
BATCH_SIZE = 64
HIDDEN_CHANNELS = 64
LEARNING_RATE = 0.001
EPOCHS = 30
POS_WEIGHT_FIXED = 10.0
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- MODEL DEFINITION ---
class ToxicityGCN(torch.nn.Module):
    def __init__(self, num_features):
        super(ToxicityGCN, self).__init__()
        self.conv1 = GCNConv(num_features, HIDDEN_CHANNELS)
        self.conv2 = GCNConv(HIDDEN_CHANNELS, HIDDEN_CHANNELS)
        self.conv3 = GCNConv(HIDDEN_CHANNELS, HIDDEN_CHANNELS)
        self.lin = Linear(HIDDEN_CHANNELS * 2, 1)

    def forward(self, x, edge_index, batch):
        x = x.float()  # Fix: Ensure float precision
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)
        return self.lin(x)

# --- TRAINING PIPELINE ---
def run_training():
    print("📂 Loading Data...")
    # Load with security check disabled for local file
    dataset_list = torch.load("./data/processed/tox21_graphs.pt", weights_only=False)

    # Fix: Remove Ghost Molecules (0 nodes)
    dataset_list = [d for d in dataset_list if d.num_nodes > 0]

    # Split
    split = int(len(dataset_list) * 0.8)
    train_loader = DataLoader(dataset_list[:split], batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(dataset_list[split:], batch_size=BATCH_SIZE)

    # Setup Model
    model = ToxicityGCN(num_features=9).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Weighted Loss Calculation
    pos_weight = torch.tensor([POS_WEIGHT_FIXED]).to(DEVICE)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # --- MLFLOW START ---
    mlflow.set_experiment("Tox21_GCN_Experiment")
    with mlflow.start_run():
        print("🚀 Starting Run with MLflow Tracking...")

        # 1. Log Hyperparams
        mlflow.log_param("hidden_channels", HIDDEN_CHANNELS)
        mlflow.log_param("lr", LEARNING_RATE)
        mlflow.log_param("batch_size", BATCH_SIZE)
        mlflow.log_param("optimizer", "Adam")

        # Track Best Performance
        best_auc = 0.0
        best_model_path = "./data/tox_model_best.pth"

        for epoch in range(1, EPOCHS + 1):
            model.train()
            total_loss = 0
            for data in train_loader:
                data = data.to(DEVICE)
                optimizer.zero_grad()
                out = model(data.x, data.edge_index, data.batch)
                loss = criterion(out, data.y.view(-1, 1))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / len(train_loader)

            # Test Step
            model.eval()
            y_true, y_pred = [], []
            with torch.no_grad():
                for data in test_loader:
                    data = data.to(DEVICE)
                    out = model(data.x, data.edge_index, data.batch)
                    y_true.extend(data.y.cpu().numpy())
                    y_pred.extend(torch.sigmoid(out).cpu().numpy())
            try:
                test_auc = roc_auc_score(y_true, y_pred)
            except:
                test_auc = 0.5

            print(f"Epoch {epoch:02d} | Loss: {avg_loss:.4f} | AUC: {test_auc:.4f}")

            # 2. Log Metrics per Epoch
            mlflow.log_metric("loss", avg_loss, step=epoch)
            mlflow.log_metric("auc", test_auc, step=epoch)

            # 3. Model Checkpoint (Save ONLY if better)
            if test_auc > best_auc:
                best_auc = test_auc
                torch.save(model.state_dict(), best_model_path)
                print(f"   🏆 New Best Model Saved! (AUC: {best_auc:.4f})")
                mlflow.log_metric("best_auc", best_auc, step=epoch)

        # 4. Save Final Artifacts to MLflow
        # We upload the BEST model, not the last one
        if os.path.exists(best_model_path):
            mlflow.log_artifact(best_model_path)
            # Create a copy as the standard 'tox_model.pth' for the deploy script to find easily
            torch.save(torch.load(best_model_path, weights_only=False), "./data/tox_model.pth")
            print(f"✅ Training Complete. Best Model (AUC: {best_auc:.4f}) logged to MLflow.")
        else:
            print("❌ Warning: No model saved (did AUC ever improve?)")

if __name__ == "__main__":
    run_training()
MLflow Training Logs

06 // Theoretical PR-Curves vs. Safety Audits

The F1-optimal point on a PR curve ignores the asymmetric cost of a missed toxin. I set aside the theoretical "optimum" in favor of a raw threshold analysis.

04-5_evaluate_PR_curve.py
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch.nn import Linear
import numpy as np
import os

# --- 1. SETUP & MODEL DEFINITION ---
# (Must match your training script exactly)
HIDDEN_CHANNELS = 64
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ToxicityGCN(torch.nn.Module):
    def __init__(self, num_features=9):
        super(ToxicityGCN, self).__init__()
        self.conv1 = GCNConv(num_features, HIDDEN_CHANNELS)
        self.conv2 = GCNConv(HIDDEN_CHANNELS, HIDDEN_CHANNELS)
        self.conv3 = GCNConv(HIDDEN_CHANNELS, HIDDEN_CHANNELS)
        self.lin = Linear(HIDDEN_CHANNELS * 2, 1)

    def forward(self, x, edge_index, batch):
        x = x.float()
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)
        return self.lin(x)

# --- 2. GENERATE PREDICTIONS ---
def get_predictions():
    print("📂 Loading Data and Model...")
    # Load Data
    dataset_list = torch.load("./data/processed/tox21_graphs.pt", weights_only=False)
    dataset_list = [d for d in dataset_list if d.num_nodes > 0]

    # Use the last 20% (Test Set)
    split = int(len(dataset_list) * 0.8)
    test_data = dataset_list[split:]
    loader = DataLoader(test_data, batch_size=64, shuffle=False)

    # Load Model
    model = ToxicityGCN().to(DEVICE)
    # Try to load best model first, fall back to standard
    if os.path.exists("./data/tox_model_best.pth"):
        model.load_state_dict(torch.load("./data/tox_model_best.pth", map_location=DEVICE, weights_only=False))
        print("   Using: tox_model_best.pth")
    else:
        model.load_state_dict(torch.load("./data/tox_model.pth", map_location=DEVICE, weights_only=False))
        print("   Using: tox_model.pth")
    model.eval()

    y_true = []
    y_scores = []
    print("⚡ Running Inference...")
    with torch.no_grad():
        for data in loader:
            data = data.to(DEVICE)
            out = model(data.x, data.edge_index, data.batch)
            # Get Probabilities (0.0 to 1.0)
            probs = torch.sigmoid(out).cpu().numpy()
            labels = data.y.cpu().numpy()
            y_scores.extend(probs)
            y_true.extend(labels)
    return np.array(y_true), np.array(y_scores)

# --- 3. PLOT PRECISION-RECALL CURVE ---
if __name__ == "__main__":
    y_true, y_scores = get_predictions()

    # Calculate Precision and Recall for all thresholds
    precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
    avg_precision = average_precision_score(y_true, y_scores)
    print(f"\n📊 Average Precision (AP): {avg_precision:.4f}")

    # --- STRATEGY: Find the Best F1 Score ---
    # F1 = Harmonic mean of Precision and Recall
    # This helps us find the "sweet spot" threshold automatically
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    best_idx = np.argmax(f1_scores)
    best_thresh = thresholds[best_idx]

    print(f"⭐ Optimal Threshold (Max F1): {best_thresh:.4f}")
    print(f"   Recall at this point: {recall[best_idx]:.4f}")
    print(f"   Precision at this point: {precision[best_idx]:.4f}")

    # Plot
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, marker='.', label=f'GCN (AP = {avg_precision:.2f})')
    # Mark the optimal point
    plt.scatter(recall[best_idx], precision[best_idx], marker='o', color='red', label='Optimal Threshold', zorder=5)
    plt.title('Precision-Recall Curve')
    plt.xlabel('Recall (Percentage of Toxic Molecules Found)')
    plt.ylabel('Precision (Percentage of Alerts that are actually Toxic)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('precision_recall_curve.png')
    print("\n✅ Graph saved to precision_recall_curve.png")

    # Show the plot if you have a display, otherwise just save
    try:
        plt.show()
    except:
        pass
PR Curve Graph
F1 Point Analysis
04_evaluate_thresholds.py
import torch
import numpy as np
import pandas as pd
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch.nn import Linear
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
import os

# --- 1. SETUP ---
BATCH_SIZE = 64
HIDDEN_CHANNELS = 64
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ToxicityGCN(torch.nn.Module):
    def __init__(self, num_features=9):
        super(ToxicityGCN, self).__init__()
        self.conv1 = GCNConv(num_features, HIDDEN_CHANNELS)
        self.conv2 = GCNConv(HIDDEN_CHANNELS, HIDDEN_CHANNELS)
        self.conv3 = GCNConv(HIDDEN_CHANNELS, HIDDEN_CHANNELS)
        self.lin = Linear(HIDDEN_CHANNELS * 2, 1)

    def forward(self, x, edge_index, batch):
        x = x.float()
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)
        return self.lin(x)

def evaluate_thresholds():
    print("📂 Loading Data...")
    dataset_list = torch.load("./data/processed/tox21_graphs.pt", weights_only=False)
    dataset_list = [d for d in dataset_list if d.num_nodes > 0]

    # Use the Test Split (Last 20%)
    split = int(len(dataset_list) * 0.8)
    test_data = dataset_list[split:]

    # Create a balanced evaluation set (up to 100 Toxic, 100 Safe) to see the trade-offs clearly
    toxic_graphs = [d for d in test_data if d.y.item() == 1][:100]
    safe_graphs = [d for d in test_data if d.y.item() == 0][:100]
    eval_set = toxic_graphs + safe_graphs
    loader = DataLoader(eval_set, batch_size=32, shuffle=False)

    # Load Model
    model = ToxicityGCN().to(DEVICE)
    model_path = "./data/tox_model_best.pth"
    if not os.path.exists(model_path):
        model_path = "./data/tox_model.pth"
    model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=False))
    model.eval()

    # Run Inference
    y_true = []
    y_probs = []
    print(f"⚡ Running Inference on {len(eval_set)} molecules ({len(toxic_graphs)} Toxic, {len(safe_graphs)} Safe)...")
    with torch.no_grad():
        for data in loader:
            data = data.to(DEVICE)
            out = model(data.x, data.edge_index, data.batch)
            probs = torch.sigmoid(out).cpu().numpy()
            y_probs.extend(probs)
            y_true.extend(data.y.cpu().numpy())

    y_true = np.array(y_true)
    y_probs = np.array(y_probs).flatten()

    # --- TEST THRESHOLDS ---
    results = []
    thresholds = [0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90]

    print("\n📊 --- THRESHOLD ANALYSIS ---")
    print(f"{'Threshold':<10} | {'Accuracy':<10} | {'Precision':<10} | {'Recall':<10} | {'False Pos':<10} | {'False Neg':<10}")
    print("-" * 75)

    for t in thresholds:
        y_pred = (y_probs > t).astype(int)
        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, zero_division=0)
        rec = recall_score(y_true, y_pred)
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
        results.append({
            "Threshold": t, "Accuracy": acc, "Precision": prec,
            "Recall": rec, "FP": fp, "FN": fn
        })
        print(f"{t:<10.2f} | {acc:<10.4f} | {prec:<10.4f} | {rec:<10.4f} | {fp:<10} | {fn:<10}")

    # Report the threshold with the best overall accuracy on the balanced set
    best_row = max(results, key=lambda x: x['Accuracy'])
    print("-" * 75)
    print(f"\n🏆 Best Accuracy Threshold: {best_row['Threshold']}")
    print(f"   (Correctly classified {int(best_row['Accuracy']*100)}% of the balanced set)")

if __name__ == "__main__":
    evaluate_thresholds()
Threshold Audit Table

Selecting a 0.30 threshold secured a 0.9600 Recall.
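The same decision can be made programmatically. Instead of maximizing F1, pick the highest threshold that still clears the recall floor; a sketch, assuming the y_true / y_scores arrays produced by 04-5_evaluate_PR_curve.py:

import numpy as np
from sklearn.metrics import precision_recall_curve

def pick_safety_threshold(y_true, y_scores, recall_floor=0.95):
    """Return the highest threshold whose recall still meets the safety floor."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
    precision, recall = precision[:-1], recall[:-1]   # drop the final (recall=0) sentinel point
    meets_floor = recall >= recall_floor
    if not meets_floor.any():
        return None  # no threshold satisfies the recall target
    idx = np.argmax(np.where(meets_floor, thresholds, -np.inf))
    return thresholds[idx], recall[idx], precision[idx]

A 0.95 floor on the audit set above is consistent with the selected 0.30 threshold and its 0.9600 recall.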

07 // Explainable Ray Serve API

The model is deployed via Ray Serve to handle REST requests at scale. I integrated Captum's Integrated Gradients to provide atomic-level heatmaps for every prediction.

05_deploy.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch.nn import Linear
from ray import serve
from starlette.requests import Request
import io
import base64
import numpy as np

# Explainability Imports
from captum.attr import IntegratedGradients
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from torch_geometric.data import Data

# --- 1. HELPER FUNCTIONS (Must match Training/Explain logic) ---
def one_hot_encoding(x, permitted_list):
    if x not in permitted_list:
        x = permitted_list[-1]
    binary_encoding = [int(x == possible_value) for possible_value in permitted_list]
    return binary_encoding

def get_atom_features(atom):
    permitted_atoms = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'I', 'Unknown']
    atom_type = one_hot_encoding(atom.GetSymbol(), permitted_atoms)
    return torch.tensor(atom_type, dtype=torch.float)

def smiles_to_graph(smiles):
    """Production Graph Converter"""
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return None

    # Node Features
    features = []
    for atom in mol.GetAtoms():
        features.append(get_atom_features(atom))
    x = torch.stack(features)

    # Edges
    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i])
    if not edge_indices:
        return None
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index)

# --- 2. MODEL ARCHITECTURE ---
class ToxicityGCN(torch.nn.Module):
    def __init__(self, num_features=9):
        super(ToxicityGCN, self).__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.conv3 = GCNConv(64, 64)
        self.lin = Linear(64 * 2, 1)

    def forward(self, x, edge_index, batch=None):
        x = x.float()
        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)
        return self.lin(x)

# --- 3. RAY DEPLOYMENT ---
@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 1})
class ToxicityDeployment:
    def __init__(self):
        print("⚡ Ray Serve: Loading Production Model...")
        self.device = torch.device("cpu")
        self.model = ToxicityGCN().to(self.device)
        # Load Best Model
        model_path = "./data/tox_model_best.pth"
        self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=False))
        self.model.eval()

    def generate_heatmap(self, data, smiles):
        """Runs Integrated Gradients on the fly"""
        single_batch = torch.zeros(data.x.shape[0], dtype=torch.long, device=self.device)

        # Wrapper for Captum
        def model_forward(x_expanded):
            num_original = data.x.shape[0]
            num_repeats = x_expanded.shape[0] // num_original
            if num_repeats > 1:
                edge_indices = [data.edge_index + i * num_original for i in range(num_repeats)]
                exp_edge_index = torch.cat(edge_indices, dim=1)
                exp_batch = single_batch.repeat(num_repeats)
            else:
                exp_edge_index, exp_batch = data.edge_index, single_batch
            return self.model(x_expanded, exp_edge_index, exp_batch)

        # Attribute
        ig = IntegratedGradients(model_forward)
        attributions = ig.attribute(data.x.float().requires_grad_(), target=0, n_steps=20)
        weights = attributions.sum(dim=1).abs().detach().numpy()
        if weights.max() > 0:
            weights /= weights.max()

        # Draw
        mol = Chem.MolFromSmiles(smiles)
        highlight_atoms = {i: (1.0, 1.0 - float(w), 1.0 - float(w)) for i, w in enumerate(weights) if w > 0.3}
        drawer = rdMolDraw2D.MolDraw2DCairo(400, 300)
        drawer.DrawMolecule(mol, highlightAtoms=list(highlight_atoms.keys()), highlightAtomColors=highlight_atoms)
        drawer.FinishDrawing()
        return base64.b64encode(drawer.GetDrawingText()).decode('utf-8')

    async def __call__(self, http_request: Request):
        req = await http_request.json()
        smiles = req.get("smiles")

        # 1. Convert
        data = smiles_to_graph(smiles)
        if data is None:
            return {"error": "Invalid SMILES string"}
        data = data.to(self.device)

        # 2. Predict
        with torch.no_grad():
            logits = self.model(data.x, data.edge_index, torch.zeros(data.x.shape[0], dtype=torch.long))
            prob = torch.sigmoid(logits).item()

        # 3. Explain
        heatmap_b64 = self.generate_heatmap(data, smiles)

        # 4. Decision
        THRESHOLD = 0.30
        is_toxic = prob > THRESHOLD

        return {
            "molecule": smiles,
            "toxicity_score": f"{prob:.4f}",
            "prediction": "TOXIC" if is_toxic else "SAFE",
            "image_base64": heatmap_b64
        }

toxicity_app = ToxicityDeployment.bind()

if __name__ == "__main__":
    import ray
    # Manually set object store memory to 100MB
    ray.init(object_store_memory=100 * 1024 * 1024)
    serve.run(toxicity_app)

I then implemented a batch testing script to verify production performance and log real-time inference results.

06_test_batch.py
import requests
import base64
import os

URL = "http://localhost:8000/"
OUTPUT_DIR = "./logs/images"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# A Mix of Safe and Toxic Molecules
chemicals = [
    {"name": "Dinitrotoluene (Explosive)", "smiles": "Cc1ccc(cc1[N+](=O)[O-])[N+](=O)[O-]", "type": "TOXIC"},
    {"name": "Citric Acid (Lemon)", "smiles": "OC(=O)CC(O)(C(=O)O)CC(=O)O", "type": "SAFE"},
    {"name": "Benzene (Carcinogen)", "smiles": "c1ccccc1", "type": "TOXIC"},
    {"name": "Aspirin (Medicine)", "smiles": "CC(=O)Oc1ccccc1C(=O)O", "type": "SAFE"},
    {"name": "Parathion (Pesticide)", "smiles": "CCOP(=S)(OCC)Oc1ccc(cc1)[N+](=O)[O-]", "type": "TOXIC"},
    {"name": "Glucose (Sugar)", "smiles": "C(C1C(C(C(C(O1)O)O)O)O)O", "type": "SAFE"}
]

print(f"🧪 Testing {len(chemicals)} chemicals against API...\n")

for chem in chemicals:
    try:
        response = requests.post(URL, json={"smiles": chem['smiles']})
        res_json = response.json()

        score = float(res_json['toxicity_score'])
        pred = res_json['prediction']

        # Decode and Save Image
        img_data = base64.b64decode(res_json['image_base64'])
        filename = f"{OUTPUT_DIR}/{chem['name'].split()[0]}.png"
        with open(filename, "wb") as f:
            f.write(img_data)

        # Console Output
        status_icon = "✅" if pred == chem['type'] else "⚠️"
        print(f"{status_icon} {chem['name']:<25} | Score: {score:.4f} | Pred: {pred} (Saved to {filename})")
    except Exception as e:
        print(f"❌ Error processing {chem['name']}: {e}")

print(f"\n📂 All images saved to: {OUTPUT_DIR}")
Batch Inference Log
Streamlit UI Scan
AI Heatmap Visualization

Batch Testing Case Studies

True Positives (Correctly Flagged)

True Negatives (Correctly Classified Safe)

08 // Identifying the Topological Floor

Audit scripts identified a "Topological Floor": in small, simple graphs (7-15 atoms), the structural signal washes out and toxic molecules can slip under the alert threshold.

07_find_false_negative.py
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch.nn import Linear
from rdkit import Chem
from rdkit.Chem import Draw

# --- CONFIGURATION ---
HIDDEN_CHANNELS = 64
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
THRESHOLD = 0.25

# --- MODEL DEFINITION ---
class ToxicityGCN(torch.nn.Module):
    def __init__(self, num_features=9):
        super(ToxicityGCN, self).__init__()
        self.conv1 = GCNConv(num_features, HIDDEN_CHANNELS)
        self.conv2 = GCNConv(HIDDEN_CHANNELS, HIDDEN_CHANNELS)
        self.conv3 = GCNConv(HIDDEN_CHANNELS, HIDDEN_CHANNELS)
        self.lin = Linear(HIDDEN_CHANNELS * 2, 1)

    def forward(self, x, edge_index, batch):
        x = x.float()
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)
        return self.lin(x)

def find_failures():
    print("📂 Loading Data...")
    dataset_list = torch.load("./data/processed/tox21_graphs.pt", weights_only=False)
    # Filter Ghost Molecules
    dataset_list = [d for d in dataset_list if d.num_nodes > 0]

    # Split
    split_idx = int(len(dataset_list) * 0.8)
    test_data = dataset_list[split_idx:]
    print(f"🔍 Scanning {len(test_data)} test molecules...")

    # Load Model
    model = ToxicityGCN().to(DEVICE)
    model.load_state_dict(torch.load("./data/tox_model_best.pth", map_location=DEVICE, weights_only=False))
    model.eval()

    failures = []
    with torch.no_grad():
        for i, data in enumerate(test_data):
            data = data.to(DEVICE)
            # Add batch dimension
            if not hasattr(data, 'batch') or data.batch is None:
                data.batch = torch.zeros(data.x.shape[0], dtype=torch.long, device=DEVICE)

            out = model(data.x, data.edge_index, data.batch)
            prob = torch.sigmoid(out).item()

            # FALSE NEGATIVE CHECK
            if data.y.item() == 1.0 and prob <= THRESHOLD:
                failures.append({
                    "smiles": data.smiles,  # <--- DIRECT ACCESS! NO LOOKUP NEEDED!
                    "prob": prob,
                    "atoms": data.num_nodes
                })

    print("\n🚩 --- FALSE NEGATIVE REPORT ---")
    if not failures:
        print("🎉 No False Negatives found!")
    else:
        for f in failures:
            print(f"📉 Score: {f['prob']:.4f} | Atoms: {f['atoms']}")
            print(f"📜 SMILES: {f['smiles']}")
            mol = Chem.MolFromSmiles(f['smiles'])
            if mol:
                filename = f"culprit_{f['atoms']}atoms.png"
                Draw.MolToFile(mol, filename)
                print(f"🖼️ Saved image to {filename}")
            print("-" * 30)

if __name__ == "__main__":
    find_failures()
False Negative Extraction

False Negative Gallery
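To quantify the floor, the failures collected by 07_find_false_negative.py can be grouped by molecule size. A sketch, assuming find_failures() is adapted to return its failures list:

from collections import Counter

def size_histogram(failures, bin_width=5):
    """Sketch: bin false negatives by atom count to expose the 'floor' band."""
    bins = Counter()
    for f in failures:
        lo = (f["atoms"] // bin_width) * bin_width
        bins[lo] += 1
    for lo in sorted(bins):
        label = f"{lo}-{lo + bin_width - 1} atoms"
        print(f"{label:<15} {'#' * bins[lo]} ({bins[lo]})")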

09 // Model Limitations & Future Outlook

  • Subset Training Constraints: The current weights were optimized on a single Tox21 task (NR-AhR) rather than the full multi-task dataset.
  • The Arsenic Blind Spot: Rare elements such as arsenic fall into the featurizer's 'Unknown' bucket, leaving the model with little signal to separate them and creating confidence gaps.
  • False Negative Management: Future models must integrate global molecular descriptors to assist the topological GNN (a minimal sketch follows below).
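A minimal sketch of that direction, assuming RDKit whole-molecule descriptors are concatenated with the pooled graph embedding before the classifier head (the specific descriptors and sizes here are illustrative, not a tested configuration):

import torch
from rdkit import Chem
from rdkit.Chem import Descriptors
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool

def global_descriptors(smiles):
    # Illustrative whole-molecule features to complement the purely topological GNN.
    mol = Chem.MolFromSmiles(smiles)
    return torch.tensor([Descriptors.MolWt(mol),
                         Descriptors.MolLogP(mol),
                         Descriptors.TPSA(mol)], dtype=torch.float)

class HybridToxicityGCN(torch.nn.Module):
    def __init__(self, num_features=9, hidden=64, num_descriptors=3):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = torch.nn.Linear(hidden * 2 + num_descriptors, 1)

    def forward(self, x, edge_index, batch, descriptors):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        pooled = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=1)
        # Append per-molecule descriptors (shape: [num_graphs, num_descriptors]) to the embedding.
        return self.lin(torch.cat([pooled, descriptors], dim=1))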