"""Phase 2: train the share-prediction classifier.

Setup (one-time):
    pip install lightgbm pandas scikit-learn

Run (from this directory):
    python train_share_classifier.py

Inputs:
    pairs.csv                   (next to this script — generated by extract_pairs.mjs)

Outputs:
    share_model.json            LightGBM tree dump for the in-plugin JS evaluator
    eval_report.txt             precision/recall sweep + feature importance
    inference_test_cases.json   self-test cases the plugin verifies on load

Held-out split is by MODEL, not by pair, so the eval numbers reflect
generalization to unseen models — not just unseen pairs from already-seen models.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import lightgbm as lgb
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn.metrics import (
|
|
average_precision_score,
|
|
precision_recall_curve,
|
|
roc_auc_score,
|
|
)
|
|
from sklearn.model_selection import GroupShuffleSplit
|
|
|
|
# All paths below are anchored at the directory that contains this script.
ROOT = Path(__file__).resolve().parents[0]
PAIRS = ROOT.joinpath("pairs.csv")
MODEL_OUT = ROOT.joinpath("share_model.json")
REPORT_OUT = ROOT.joinpath("eval_report.txt")

# Column roles in pairs.csv: features treated as categorical, the target
# column, and the column used to group rows for the train/test split.
CATEGORICAL = [
    "a_dir",
    "b_dir",
    "a_axis",
    "b_axis",
]
LABEL, GROUP = "label", "model"

# Probability cut-offs at which the report lists precision/recall.
THRESHOLDS = [0.50, 0.70, 0.80, 0.85, 0.90, 0.95]
|
|
|
|
|
|
def main() -> None:
    """Train the share classifier, evaluate on held-out models, export artifacts.

    Steps:
      1. Load ``PAIRS`` and split 80/20 by ``GROUP`` so no model contributes
         pairs to both train and test.
      2. Fit a LightGBM binary classifier with early stopping.
      3. Write the evaluation report to ``REPORT_OUT``.
      4. Write the compacted tree dump to ``MODEL_OUT`` and a set of
         (features, expected probability) self-test cases to
         ``inference_test_cases.json`` next to this script.

    Raises:
        KeyError: if ``PAIRS`` lacks the label, group, or a categorical column.
    """
    df = pd.read_csv(PAIRS)
    missing = [c for c in (LABEL, GROUP, *CATEGORICAL) if c not in df.columns]
    if missing:
        raise KeyError(f"{PAIRS} is missing required column(s): {missing}")
    print(f"Loaded {len(df):,} pairs across {df[GROUP].nunique()} models")

    feature_cols = [c for c in df.columns if c not in (LABEL, GROUP)]
    # Copy, then coerce categoricals once *before* splitting: this avoids
    # assigning columns onto .iloc slices (SettingWithCopyWarning) and makes
    # train and test share a single category set per column.
    X = df[feature_cols].copy()
    for c in CATEGORICAL:
        X[c] = X[c].astype("category")
    y = df[LABEL].astype(int)
    groups = df[GROUP]

    # Group-aware 80/20 split — same model never in both train and test.
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
    train_idx, test_idx = next(splitter.split(X, y, groups))
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
    n_train_models = groups.iloc[train_idx].nunique()
    n_test_models = groups.iloc[test_idx].nunique()
    print(
        f"Train: {len(X_train):,} pairs / {n_train_models} models |"
        f" Test: {len(X_test):,} pairs / {n_test_models} models"
    )

    train_set = lgb.Dataset(X_train, label=y_train, categorical_feature=CATEGORICAL)
    valid_set = lgb.Dataset(
        X_test, label=y_test, categorical_feature=CATEGORICAL, reference=train_set
    )

    params = {
        "objective": "binary",
        "metric": ["binary_logloss", "auc"],
        "learning_rate": 0.05,
        "num_leaves": 63,
        "min_data_in_leaf": 50,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.8,
        "bagging_freq": 5,
        "verbose": -1,
    }

    # NOTE(review): the held-out split doubles as the early-stopping validation
    # set, so the metrics reported below are tuned on the data they are
    # measured on and will be somewhat optimistic.
    model = lgb.train(
        params,
        train_set,
        num_boost_round=600,
        valid_sets=[train_set, valid_set],
        valid_names=["train", "valid"],
        callbacks=[
            lgb.early_stopping(stopping_rounds=30, verbose=True),
            lgb.log_evaluation(period=50),
        ],
    )

    # ---- Eval ----
    y_pred = model.predict(X_test, num_iteration=model.best_iteration)
    auc = roc_auc_score(y_test, y_pred)
    ap = average_precision_score(y_test, y_pred)

    lines = [
        f"Held-out models: {n_test_models}",
        f"Held-out pairs: {len(X_test):,}",
        f"AUC: {auc:.4f}",
        f"Average prec: {ap:.4f}",
        "",
        "Threshold sweep (held-out pairs):",
        f"{'thresh':>8} {'prec':>8} {'recall':>8} {'f1':>8} {'kept%':>8}",
    ]
    lines.extend(_threshold_sweep_rows(y_test, y_pred, THRESHOLDS))

    lines.append("")
    lines.append("Top features by gain:")
    gain = model.feature_importance(importance_type="gain")
    names = model.feature_name()
    for i in np.argsort(gain)[::-1][:20]:
        lines.append(f" {names[i]:>22} {gain[i]:>14.0f}")

    report = "\n".join(lines)
    REPORT_OUT.write_text(report, encoding="utf-8")
    print()
    print(report)
    print()
    print(f"Wrote {REPORT_OUT}")

    # ---- Export trees in a JS-evaluable form ----
    dump = model.dump_model(num_iteration=model.best_iteration)
    # Preserve feature_names and categorical info for JS eval; trim heavy fields.
    export = {
        "version": 1,
        "objective": "binary",
        "feature_names": dump["feature_names"],
        # dump_model() exposes no per-feature categorical index, so the
        # configured column list is the source of truth here.
        "categorical_features": list(CATEGORICAL),
        # LightGBM trains on pandas category *codes*; entry k of each list is
        # the raw value whose code is k.
        # NOTE(review): confirm the JS evaluator maps raw values through this
        # (or that raw values already equal their codes).
        "categorical_levels": {
            c: X[c].cat.categories.tolist() for c in CATEGORICAL
        },
        "best_iteration": int(model.best_iteration),
        "trees": [_compact_tree(t) for t in dump["tree_info"]],
    }
    MODEL_OUT.write_text(json.dumps(export, separators=(",", ":")), encoding="utf-8")
    print(f"Wrote {MODEL_OUT} ({MODEL_OUT.stat().st_size / 1024:.1f} KB, "
          f"{len(export['trees'])} trees)")

    # ---- Self-test cases for the JS evaluator ----
    test_cases = _self_test_cases(X_test, y_pred, feature_cols)
    tests_out = ROOT / "inference_test_cases.json"
    tests_out.write_text(json.dumps(test_cases, separators=(",", ":")), encoding="utf-8")
    print(f"Wrote {tests_out} ({len(test_cases)} cases)")


def _threshold_sweep_rows(y_true, y_score, thresholds) -> list:
    """Format one precision/recall/F1/kept% report row per threshold.

    A row with dashes is emitted when no score reaches the threshold.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    is_pos = y_true == 1
    is_neg = y_true == 0
    rows = []
    for t in thresholds:
        kept_mask = y_score >= t
        kept = kept_mask.mean()
        if kept_mask.sum() == 0:
            rows.append(f"{t:>8.2f} {'-':>8} {'-':>8} {'-':>8} {kept * 100:>7.2f}%")
            continue
        tp = (kept_mask & is_pos).sum()
        fp = (kept_mask & is_neg).sum()
        fn = (~kept_mask & is_pos).sum()
        # max(..., 1) / max(..., 1e-9) guard against division by zero.
        prec = tp / max(tp + fp, 1)
        rec = tp / max(tp + fn, 1)
        f1 = 2 * prec * rec / max(prec + rec, 1e-9)
        rows.append(f"{t:>8.2f} {prec:>8.4f} {rec:>8.4f} {f1:>8.4f} {kept * 100:>7.2f}%")
    return rows


def _self_test_cases(X_test, y_pred, feature_cols, n_cases=50, seed=0) -> list:
    """Sample held-out rows as (feature dict, expected probability) cases.

    Rows are drawn without replacement using a fixed seed, so repeated runs on
    the same data pick the same rows.
    """
    rng = np.random.default_rng(seed)
    picks = rng.choice(len(X_test), size=min(n_cases, len(X_test)), replace=False)
    cases = []
    for i in picks:
        row = X_test.iloc[i]
        # Categorical features are emitted as ints, everything else as floats.
        # Assumes categorical columns hold integer-like values — TODO confirm.
        features = {
            col: int(row[col]) if col in CATEGORICAL else float(row[col])
            for col in feature_cols
        }
        cases.append({"features": features, "expected_prob": float(y_pred[i])})
    return cases
|
|
|
|
|
|
def _compact_tree(tree: dict) -> dict:
    """Reduce one dumped tree to its shrinkage and compacted root node.

    ``shrinkage`` falls back to 1.0 when the dump entry does not carry it.
    """
    shrinkage = tree.get("shrinkage", 1.0)
    root = _compact_node(tree["tree_structure"])
    return {"shrinkage": shrinkage, "root": root}
|
|
|
|
|
|
def _compact_node(node: dict) -> dict:
    """Recursively shrink a dumped tree node to short-keyed form.

    Leaves become ``{"v": leaf_value}``. Split nodes keep the feature index
    (``f``), threshold (``t``), decision type (``d``), missing-value handling
    (``default_left``, ``missing_type``) and their compacted children
    (``l``, ``r``).
    """
    if "leaf_value" not in node:
        compact = {
            short: node[long]
            for short, long in (
                ("f", "split_feature"),
                ("t", "threshold"),
                ("d", "decision_type"),  # "<=" or "==" (categorical)
            )
        }
        compact["default_left"] = node.get("default_left", True)
        compact["missing_type"] = node.get("missing_type", "None")
        compact["l"] = _compact_node(node["left_child"])
        compact["r"] = _compact_node(node["right_child"])
        return compact
    return {"v": node["leaf_value"]}
|
|
|
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|