You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

56 lines
1.8 KiB
JavaScript

// Validate the JS LightGBM evaluator against LightGBM's own predictions.
// Run: node tools/blockbench-uv-packer/ml/validate_js_eval.mjs
import fs from 'node:fs';
import path from 'node:path';
import { fileURLToPath } from 'node:url';
// Directory containing this script. fileURLToPath decodes percent-escapes and
// yields a correct absolute path on both POSIX and Windows; stripping the
// leading '/' from URL.pathname only works on Windows and turns a POSIX path
// such as /home/... into a path relative to the current working directory.
const ROOT = path.dirname(fileURLToPath(import.meta.url));
// Exported model and LightGBM reference predictions, read from files next to this script.
const model = JSON.parse(fs.readFileSync(path.join(ROOT, 'share_model.json'), 'utf8'));
const tests = JSON.parse(fs.readFileSync(path.join(ROOT, 'inference_test_cases.json'), 'utf8'));
// Parsed category sets for categorical split nodes, keyed by node object so
// each '||'-separated threshold string is parsed once instead of on every visit.
const categorySetCache = new WeakMap();

/**
 * Walk one exported tree from `node` down to a leaf and return the leaf value.
 *
 * Node shape (as read below): leaves carry `v`; internal nodes carry the
 * feature index `f`, threshold `t`, decision `d` ('==' for categorical,
 * anything else for a numeric `<=` split), `default_left`, and children `l`/`r`.
 *
 * @param {object} node - Root of the (sub)tree to evaluate.
 * @param {Array<*>} x - Feature vector indexed by feature position.
 * @returns {number} Value stored on the leaf that `x` reaches.
 */
function evalTree(node, x) {
  let current = node;
  while (!('v' in current)) {
    const val = x[current.f];
    let goLeft;
    if (val === undefined || val === null || (typeof val === 'number' && Number.isNaN(val))) {
      // Missing value: follow the split's default direction.
      // NOTE(review): LightGBM also treats 0 as missing for missing_type=Zero
      // and maps NaN to 0 for missing_type=None; this assumes the exporter
      // folded that into default_left — confirm against the export script.
      goLeft = current.default_left;
    } else if (current.d === '==') {
      let cats = categorySetCache.get(current);
      if (cats === undefined) {
        cats = new Set(String(current.t).split('||').map(Number));
        categorySetCache.set(current, cats);
      }
      // LightGBM casts categorical values to int (rounding toward zero) before
      // the lookup, so 3.7 matches category 3; negative codes never match.
      goLeft = cats.has(Math.trunc(val));
    } else {
      goLeft = val <= current.t;
    }
    current = goLeft ? current.l : current.r;
  }
  return current.v;
}
/**
 * Score one feature record with the exported model.
 *
 * The record is flattened into a vector ordered by `model.feature_names`
 * (names absent from `features` become undefined). Each tree's leaf value is
 * summed into a raw score, and the logistic sigmoid of that sum is returned.
 * NOTE(review): the sigmoid presumes a binary-logistic objective — confirm it
 * matches how the model was trained.
 *
 * @param {Object<string, *>} features - Feature values keyed by feature name.
 * @returns {number} Sigmoid of the summed tree outputs.
 */
function predict(features) {
  const featureVector = model.feature_names.map((name) => features[name]);
  const rawScore = model.trees.reduce(
    (sum, tree) => sum + evalTree(tree.root, featureVector),
    0,
  );
  return 1 / (1 + Math.exp(-rawScore));
}
// Largest absolute probability deviation tolerated before the JS evaluator is
// declared out of agreement with LightGBM.
const MAX_ABS_ERR = 0.005;

if (tests.length === 0) {
  // An empty fixture would otherwise report a vacuous PASS.
  console.error('FAIL — no test cases found in inference_test_cases.json');
  process.exit(1);
}

let maxErr = 0;
let worstCase = null;
for (const tc of tests) {
  const got = predict(tc.features);
  // A NaN prediction (or a missing expected_prob) makes the difference NaN,
  // and NaN never compares greater than anything. Map it to Infinity so the
  // case is reported as the worst one and fails, instead of being skipped.
  const diff = Math.abs(got - tc.expected_prob);
  const err = Number.isNaN(diff) ? Infinity : diff;
  if (err > maxErr) {
    maxErr = err;
    worstCase = { tc, got };
  }
}

console.log(`tests: ${tests.length}`);
console.log(`max abs error: ${maxErr.toExponential(4)}`);
if (worstCase) {
  // Number() keeps the report from throwing when expected_prob is absent.
  const expected = Number(worstCase.tc.expected_prob).toFixed(6);
  console.log(`worst: expected=${expected} got=${worstCase.got.toFixed(6)}`);
}
if (maxErr > MAX_ABS_ERR) {
  console.error('FAIL — JS evaluator disagrees with LightGBM');
  process.exit(1);
}
console.log('PASS');