"""Phase 2: train the share-prediction classifier. Setup (one-time): pip install lightgbm pandas scikit-learn Run (from this directory): python train_share_classifier.py Inputs: pairs.csv (next to this script — generated by extract_pairs.mjs) Outputs: share_model.json LightGBM tree dump for the in-plugin JS evaluator eval_report.txt precision/recall sweep + feature importance inference_test_cases.json self-test cases the plugin verifies on load Held-out split is by MODEL, not by pair, so the eval numbers reflect generalization to unseen models — not just unseen pairs from already-seen models. """ from __future__ import annotations import json import os from pathlib import Path import lightgbm as lgb import numpy as np import pandas as pd from sklearn.metrics import ( average_precision_score, precision_recall_curve, roc_auc_score, ) from sklearn.model_selection import GroupShuffleSplit ROOT = Path(__file__).resolve().parent PAIRS = ROOT / "pairs.csv" MODEL_OUT = ROOT / "share_model.json" REPORT_OUT = ROOT / "eval_report.txt" CATEGORICAL = ["a_dir", "b_dir", "a_axis", "b_axis"] LABEL = "label" GROUP = "model" # Operating thresholds we want precision/recall reported at. THRESHOLDS = [0.50, 0.70, 0.80, 0.85, 0.90, 0.95] def main() -> None: df = pd.read_csv(PAIRS) print(f"Loaded {len(df):,} pairs across {df[GROUP].nunique()} models") feature_cols = [c for c in df.columns if c not in (LABEL, GROUP)] X = df[feature_cols] y = df[LABEL].astype(int) groups = df[GROUP] # Group-aware 80/20 split — same model never in both train and test. splitter = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=42) train_idx, test_idx = next(splitter.split(X, y, groups)) X_train, X_test = X.iloc[train_idx], X.iloc[test_idx] y_train, y_test = y.iloc[train_idx], y.iloc[test_idx] print( f"Train: {len(X_train):,} pairs / {df.iloc[train_idx][GROUP].nunique()} models |" f" Test: {len(X_test):,} pairs / {df.iloc[test_idx][GROUP].nunique()} models" ) # Coerce categoricals for c in CATEGORICAL: X_train[c] = X_train[c].astype("category") X_test[c] = X_test[c].astype("category") train_set = lgb.Dataset(X_train, label=y_train, categorical_feature=CATEGORICAL) valid_set = lgb.Dataset(X_test, label=y_test, categorical_feature=CATEGORICAL, reference=train_set) params = { "objective": "binary", "metric": ["binary_logloss", "auc"], "learning_rate": 0.05, "num_leaves": 63, "min_data_in_leaf": 50, "feature_fraction": 0.9, "bagging_fraction": 0.8, "bagging_freq": 5, "verbose": -1, } model = lgb.train( params, train_set, num_boost_round=600, valid_sets=[train_set, valid_set], valid_names=["train", "valid"], callbacks=[ lgb.early_stopping(stopping_rounds=30, verbose=True), lgb.log_evaluation(period=50), ], ) # ---- Eval ---- y_pred = model.predict(X_test, num_iteration=model.best_iteration) auc = roc_auc_score(y_test, y_pred) ap = average_precision_score(y_test, y_pred) lines = [] lines.append(f"Held-out models: {df.iloc[test_idx][GROUP].nunique()}") lines.append(f"Held-out pairs: {len(X_test):,}") lines.append(f"AUC: {auc:.4f}") lines.append(f"Average prec: {ap:.4f}") lines.append("") lines.append("Threshold sweep (held-out pairs):") lines.append(f"{'thresh':>8} {'prec':>8} {'recall':>8} {'f1':>8} {'kept%':>8}") precs, recs, thr = precision_recall_curve(y_test, y_pred) for t in THRESHOLDS: # find first threshold >= t mask = (y_pred >= t) kept = mask.mean() if mask.sum() == 0: lines.append(f"{t:>8.2f} {'-':>8} {'-':>8} {'-':>8} {kept * 100:>7.2f}%") continue tp = ((y_pred >= t) & (y_test == 1)).sum() fp = ((y_pred >= t) & (y_test == 0)).sum() fn = ((y_pred < t) & (y_test == 1)).sum() prec = tp / max(tp + fp, 1) rec = tp / max(tp + fn, 1) f1 = 2 * prec * rec / max(prec + rec, 1e-9) lines.append(f"{t:>8.2f} {prec:>8.4f} {rec:>8.4f} {f1:>8.4f} {kept * 100:>7.2f}%") lines.append("") lines.append("Top features by gain:") gain = model.feature_importance(importance_type="gain") names = model.feature_name() order = np.argsort(gain)[::-1] for i in order[:20]: lines.append(f" {names[i]:>22} {gain[i]:>14.0f}") report = "\n".join(lines) REPORT_OUT.write_text(report) print() print(report) print() print(f"Wrote {REPORT_OUT}") # ---- Export trees in a JS-evaluable form ---- dump = model.dump_model(num_iteration=model.best_iteration) # Preserve feature_names and categorical info for JS eval; trim heavy fields. export = { "version": 1, "objective": "binary", "feature_names": dump["feature_names"], "categorical_features": [ dump["feature_names"][i] for i in dump.get("pandas_categorical_index", []) if i < len(dump["feature_names"]) ] if "pandas_categorical_index" in dump else CATEGORICAL, "best_iteration": int(model.best_iteration), "trees": [_compact_tree(t) for t in dump["tree_info"]], } MODEL_OUT.write_text(json.dumps(export, separators=(",", ":"))) print(f"Wrote {MODEL_OUT} ({MODEL_OUT.stat().st_size / 1024:.1f} KB, " f"{len(export['trees'])} trees)") # ---- Self-test cases for the JS evaluator ---- # Pick 50 random held-out rows, save (feature dict, expected prediction) rng = np.random.default_rng(0) sample_idx = rng.choice(len(X_test), size=min(50, len(X_test)), replace=False) test_cases = [] for i in sample_idx: row = X_test.iloc[i] features = {} for col in feature_cols: v = row[col] if col in CATEGORICAL: v = int(v) else: v = float(v) features[col] = v test_cases.append({ "features": features, "expected_prob": float(y_pred[i]), }) tests_out = ROOT / "inference_test_cases.json" tests_out.write_text(json.dumps(test_cases, separators=(",", ":"))) print(f"Wrote {tests_out} ({len(test_cases)} cases)") def _compact_tree(tree: dict) -> dict: """Strip the LightGBM tree to only fields the JS evaluator needs.""" out = {"shrinkage": tree.get("shrinkage", 1.0), "root": _compact_node(tree["tree_structure"])} return out def _compact_node(node: dict) -> dict: if "leaf_value" in node: return {"v": node["leaf_value"]} out = { "f": node["split_feature"], "t": node["threshold"], "d": node["decision_type"], # "<=" or "==" (categorical) "default_left": node.get("default_left", True), "missing_type": node.get("missing_type", "None"), "l": _compact_node(node["left_child"]), "r": _compact_node(node["right_child"]), } return out if __name__ == "__main__": main()