You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
MLUvSorter/ml/train_share_classifier.py

210 lines
7.1 KiB
Python

"""Phase 2: train the share-prediction classifier.

Setup (one-time):
    pip install lightgbm pandas scikit-learn

Run (from this directory):
    python train_share_classifier.py

Inputs:
    pairs.csv                  (next to this script — generated by extract_pairs.mjs)

Outputs:
    share_model.json           LightGBM tree dump for the in-plugin JS evaluator
    eval_report.txt            precision/recall sweep + feature importance
    inference_test_cases.json  self-test cases the plugin verifies on load

Held-out split is by MODEL, not by pair, so the eval numbers reflect
generalization to unseen models — not just unseen pairs from already-seen models.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.metrics import (
average_precision_score,
precision_recall_curve,
roc_auc_score,
)
from sklearn.model_selection import GroupShuffleSplit
# Every input and output of this script lives in the script's own directory.
ROOT = Path(__file__).resolve().parent
PAIRS = ROOT.joinpath("pairs.csv")
MODEL_OUT = ROOT.joinpath("share_model.json")
REPORT_OUT = ROOT.joinpath("eval_report.txt")

# Columns handed to LightGBM as categorical features.
CATEGORICAL = [
    "a_dir",
    "b_dir",
    "a_axis",
    "b_axis",
]
LABEL = "label"  # target column
GROUP = "model"  # column used to group rows for the held-out split

# Operating thresholds we want precision/recall reported at.
THRESHOLDS = [0.50, 0.70, 0.80, 0.85, 0.90, 0.95]
def _threshold_sweep_lines(y_true: np.ndarray, y_score: np.ndarray, thresholds: list) -> list:
    """Format one precision/recall/F1 report row per operating threshold.

    Args:
        y_true: 0/1 labels of the held-out pairs.
        y_score: predicted probabilities, positionally aligned with ``y_true``.
        thresholds: probability cut-offs to report at.

    Returns:
        One formatted line per threshold; thresholds that keep no pair at all
        get ``-`` placeholders instead of metrics.
    """
    rows = []
    for t in thresholds:
        mask = y_score >= t
        kept = mask.mean()
        if not mask.any():
            rows.append(f"{t:>8.2f} {'-':>8} {'-':>8} {'-':>8} {kept * 100:>7.2f}%")
            continue
        tp = int((mask & (y_true == 1)).sum())
        fp = int((mask & (y_true == 0)).sum())
        fn = int((~mask & (y_true == 1)).sum())
        # max(...) guards keep the ratios defined when a denominator is zero.
        prec = tp / max(tp + fp, 1)
        rec = tp / max(tp + fn, 1)
        f1 = 2 * prec * rec / max(prec + rec, 1e-9)
        rows.append(f"{t:>8.2f} {prec:>8.4f} {rec:>8.4f} {f1:>8.4f} {kept * 100:>7.2f}%")
    return rows


def _json_feature_value(value, *, categorical: bool):
    """Convert one feature cell to a strict-JSON-safe scalar.

    Missing values become ``None`` (JSON ``null``): ``json.dumps`` would
    otherwise emit a bare ``NaN`` token, which is not valid JSON and is
    rejected by ``JSON.parse``.
    """
    if pd.isna(value):
        return None
    return int(value) if categorical else float(value)


def main() -> None:
    """Train the share classifier, evaluate it on held-out models, export artifacts.

    Reads ``PAIRS`` and writes ``REPORT_OUT``, ``MODEL_OUT`` and
    ``inference_test_cases.json`` (all next to this script). Progress and the
    evaluation report are also printed to stdout.
    """
    df = pd.read_csv(PAIRS)
    print(f"Loaded {len(df):,} pairs across {df[GROUP].nunique()} models")

    feature_cols = [c for c in df.columns if c not in (LABEL, GROUP)]
    X = df[feature_cols]
    y = df[LABEL].astype(int)
    groups = df[GROUP]

    # Group-aware 80/20 split — same model never in both train and test.
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
    train_idx, test_idx = next(splitter.split(X, y, groups))
    # Explicit copies: the categorical coercion below assigns whole columns,
    # and assigning into an .iloc selection raises SettingWithCopyWarning.
    X_train, X_test = X.iloc[train_idx].copy(), X.iloc[test_idx].copy()
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
    n_train_models = groups.iloc[train_idx].nunique()
    n_test_models = groups.iloc[test_idx].nunique()
    print(
        f"Train: {len(X_train):,} pairs / {n_train_models} models |"
        f" Test: {len(X_test):,} pairs / {n_test_models} models"
    )

    # Coerce categoricals
    for c in CATEGORICAL:
        X_train[c] = X_train[c].astype("category")
        X_test[c] = X_test[c].astype("category")

    train_set = lgb.Dataset(X_train, label=y_train, categorical_feature=CATEGORICAL)
    valid_set = lgb.Dataset(
        X_test, label=y_test, categorical_feature=CATEGORICAL, reference=train_set
    )
    params = {
        "objective": "binary",
        "metric": ["binary_logloss", "auc"],
        "learning_rate": 0.05,
        "num_leaves": 63,
        "min_data_in_leaf": 50,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.8,
        "bagging_freq": 5,
        "verbose": -1,
    }
    # NOTE(review): early stopping monitors the same held-out set that the
    # report below is computed on, so the reported numbers are slightly
    # optimistic. Consider carving a separate group-wise validation split.
    model = lgb.train(
        params,
        train_set,
        num_boost_round=600,
        valid_sets=[train_set, valid_set],
        valid_names=["train", "valid"],
        callbacks=[
            lgb.early_stopping(stopping_rounds=30, verbose=True),
            lgb.log_evaluation(period=50),
        ],
    )

    # ---- Eval ----
    y_pred = model.predict(X_test, num_iteration=model.best_iteration)
    auc = roc_auc_score(y_test, y_pred)
    ap = average_precision_score(y_test, y_pred)

    lines = [
        f"Held-out models: {n_test_models}",
        f"Held-out pairs: {len(X_test):,}",
        f"AUC: {auc:.4f}",
        f"Average prec: {ap:.4f}",
        "",
        "Threshold sweep (held-out pairs):",
        f"{'thresh':>8} {'prec':>8} {'recall':>8} {'f1':>8} {'kept%':>8}",
    ]
    lines.extend(_threshold_sweep_lines(y_test.to_numpy(), y_pred, THRESHOLDS))

    lines.append("")
    lines.append("Top features by gain:")
    gain = model.feature_importance(importance_type="gain")
    names = model.feature_name()
    order = np.argsort(gain)[::-1]
    for i in order[:20]:
        lines.append(f" {names[i]:>22} {gain[i]:>14.0f}")

    report = "\n".join(lines)
    REPORT_OUT.write_text(report, encoding="utf-8")
    print()
    print(report)
    print()
    print(f"Wrote {REPORT_OUT}")

    # ---- Export trees in a JS-evaluable form ----
    dump = model.dump_model(num_iteration=model.best_iteration)
    # Preserve feature_names and categorical info for JS eval; trim heavy fields.
    # dump_model() has no "pandas_categorical_index" entry, so the categorical
    # feature list is always the configured CATEGORICAL columns.
    export = {
        "version": 1,
        "objective": "binary",
        "feature_names": dump["feature_names"],
        "categorical_features": list(CATEGORICAL),
        "best_iteration": int(model.best_iteration),
        "trees": [_compact_tree(t) for t in dump["tree_info"]],
    }
    MODEL_OUT.write_text(json.dumps(export, separators=(",", ":")), encoding="utf-8")
    print(f"Wrote {MODEL_OUT} ({MODEL_OUT.stat().st_size / 1024:.1f} KB, "
          f"{len(export['trees'])} trees)")

    # ---- Self-test cases for the JS evaluator ----
    # Pick 50 random held-out rows, save (feature dict, expected prediction)
    # NOTE(review): LightGBM splits pandas categorical columns on their
    # category *codes*, while the values written here are the raw cell values.
    # The two coincide only when the raw values are already 0..k-1 — confirm
    # against the JS evaluator.
    rng = np.random.default_rng(0)
    sample_idx = rng.choice(len(X_test), size=min(50, len(X_test)), replace=False)
    test_cases = []
    for i in sample_idx:
        row = X_test.iloc[i]
        features = {
            col: _json_feature_value(row[col], categorical=col in CATEGORICAL)
            for col in feature_cols
        }
        test_cases.append({
            "features": features,
            "expected_prob": float(y_pred[i]),
        })

    tests_out = ROOT / "inference_test_cases.json"
    tests_out.write_text(json.dumps(test_cases, separators=(",", ":")), encoding="utf-8")
    print(f"Wrote {tests_out} ({len(test_cases)} cases)")
def _compact_tree(tree: dict) -> dict:
    """Reduce one LightGBM ``tree_info`` entry to what the JS evaluator reads.

    Keeps the tree's ``shrinkage`` (1.0 when the dump omits it) and the
    compacted node hierarchy under ``root``.
    """
    return {
        "shrinkage": tree.get("shrinkage", 1.0),
        "root": _compact_node(tree["tree_structure"]),
    }
def _compact_node(node: dict) -> dict:
if "leaf_value" in node:
return {"v": node["leaf_value"]}
out = {
"f": node["split_feature"],
"t": node["threshold"],
"d": node["decision_type"], # "<=" or "==" (categorical)
"default_left": node.get("default_left", True),
"missing_type": node.get("missing_type", "None"),
"l": _compact_node(node["left_child"]),
"r": _compact_node(node["right_child"]),
}
return out
# Script entry point: run the training pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()