Add utility to add tunable search parameters and fix git commit-bench alias to parse message from chained short options and --message (bench 5739439)

This commit is contained in:
2026-05-11 17:51:05 +02:00
parent 8c8558382a
commit dbf8fcc410
2 changed files with 346 additions and 5 deletions

334
scripts/add_tunable.py Executable file
View File

@@ -0,0 +1,334 @@
#!/usr/bin/env python3
"""
Add a tunable parameter to src/heimdall/util/tunables.nim.
Edits:
1. SearchParametersObj field (creating or extending tuples as needed)
2. addTunableParameter call in initTunableParameters
3. setParameterBody case branch
4. getParameter case branch
Path syntax:
foo -> top-level scalar (foo*: int)
foo.bar -> nested in tuple[bar: int]; bar is grouped with siblings
of identical type when possible
foo.bar.baz -> deeper nesting
foo.bar[Pawn] -> array index (the array field must already exist; the
struct is left untouched)
Usage:
scripts/add_tunable.py NAME PATH MIN MAX DEFAULT [--quantized] [--float]
Examples:
scripts/add_tunable.py NewMargin newMargin 1 200 100
scripts/add_tunable.py NewMarginQuiet newMargin.quiet 1 200 100
scripts/add_tunable.py NewMarginNoisy newMargin.noisy 1 200 100
scripts/add_tunable.py NewLMRBase newLmrBase 200 800 400 --float
scripts/add_tunable.py NewDivisor newDivisor 1024 8192 4096 --quantized
scripts/add_tunable.py SEEOrdKingWeight seeWeights.ordering[King] 0 0 0
"""
import argparse
import re
import sys
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parent.parent
TUNABLES_PATH = REPO_ROOT / "src" / "heimdall" / "util" / "tunables.nim"
INDENT4 = " "
INDENT8 = " "
INDENT12 = " "
IDENT_RE = re.compile(r"^[a-zA-Z_]\w*$")
FIELD_RE = re.compile(r"^ {8}(\w+)\*:\s*(.+)$")
# -- tuple parsing/serialization --------------------------------------------
def split_top_commas(body: str) -> list[str]:
parts, current, depth = [], "", 0
for c in body:
if c == "[":
depth += 1
elif c == "]":
depth -= 1
if c == "," and depth == 0:
parts.append(current.strip())
current = ""
else:
current += c
if current.strip():
parts.append(current.strip())
return parts
def find_top_colon(s: str) -> int:
depth = 0
for i, c in enumerate(s):
if c == "[":
depth += 1
elif c == "]":
depth -= 1
elif c == ":" and depth == 0:
return i
return -1
def parse_tuple_groups(body: str) -> list[tuple[list[str], str]]:
groups: list[tuple[list[str], str]] = []
pending: list[str] = []
for part in split_top_commas(body):
ci = find_top_colon(part)
if ci == -1:
pending.append(part.strip())
continue
names = pending + [n.strip() for n in part[:ci].split(",")]
groups.append((names, part[ci + 1:].strip()))
pending = []
if pending:
raise ValueError(f"trailing names without type: {pending}")
return groups
def serialize_tuple(groups: list[tuple[list[str], str]]) -> str:
return "tuple[" + ", ".join(", ".join(ns) + ": " + t for ns, t in groups) + "]"
def merge_adjacent(groups: list[tuple[list[str], str]]) -> list[tuple[list[str], str]]:
out: list[tuple[list[str], str]] = []
for ns, t in groups:
if out and out[-1][1] == t:
out[-1] = (out[-1][0] + ns, t)
else:
out.append((list(ns), t))
return out
def make_nested_tuple(path: list[str], leaf_type: str) -> str:
if not path:
return leaf_type
return f"tuple[{path[0]}: {make_nested_tuple(path[1:], leaf_type)}]"
def extend_tuple_type(tuple_str: str, path: list[str], leaf_type: str) -> str:
if not (tuple_str.startswith("tuple[") and tuple_str.endswith("]")):
raise ValueError(f"cannot descend into non-tuple type: {tuple_str!r}")
groups = parse_tuple_groups(tuple_str[len("tuple["):-1])
head, rest = path[0], path[1:]
found_gi = next((gi for gi, (ns, _) in enumerate(groups) if head in ns), -1)
if found_gi == -1:
new_type = leaf_type if not rest else make_nested_tuple(rest, leaf_type)
merged = False
for ns, t in groups:
if t == new_type:
ns.append(head)
merged = True
break
if not merged:
groups.append(([head], new_type))
return serialize_tuple(merge_adjacent(groups))
names, group_type = groups[found_gi]
if not rest:
raise ValueError(f"field path already exists with type {group_type!r}")
new_inner = extend_tuple_type(group_type, rest, leaf_type)
if len(names) > 1:
names.remove(head)
groups.insert(found_gi + 1, ([head], new_inner))
else:
groups[found_gi] = (names, new_inner)
return serialize_tuple(merge_adjacent(groups))
# -- path parsing -----------------------------------------------------------
def parse_path(path_str: str) -> tuple[list[str], str | None]:
index: str | None = None
if "[" in path_str:
prefix, rest = path_str.split("[", 1)
if not rest.endswith("]"):
raise ValueError(f"unmatched bracket in path: {path_str!r}")
index = rest[:-1].strip()
path_str = prefix
if not index:
raise ValueError("empty array index")
components = path_str.split(".")
if not components or any(not c for c in components):
raise ValueError(f"invalid path: {path_str!r}")
for c in components:
if not IDENT_RE.match(c):
raise ValueError(f"invalid field name in path: {c!r}")
return components, index
# -- file edits -------------------------------------------------------------
def find_struct_range(lines: list[str]) -> tuple[int, int]:
start = next(
(i for i, l in enumerate(lines)
if re.match(r"^\s+SearchParametersObj\*\s*=\s*object", l)),
None,
)
if start is None:
raise RuntimeError("couldn't find SearchParametersObj definition")
i = start + 1
last_field = i
while i < len(lines):
stripped = lines[i].rstrip()
bare = stripped.lstrip()
if bare == "" or bare.startswith("#"):
i += 1
continue
if FIELD_RE.match(stripped):
last_field = i
i += 1
continue
break
return start, last_field
def modify_struct(lines: list[str], path: list[str],
leaf_type: str, has_index: bool) -> list[str]:
if has_index:
return lines
_, last_field = find_struct_range(lines)
head = path[0]
head_match: tuple[int, re.Match[str]] | None = None
for i in range(last_field + 1):
m = FIELD_RE.match(lines[i].rstrip())
if m and m.group(1) == head:
head_match = (i, m)
break
if head_match is None:
rest = path[1:]
new_type = leaf_type if not rest else make_nested_tuple(rest, leaf_type)
new_line = f"{INDENT8}{head}*: {new_type}\n"
return lines[:last_field + 1] + [new_line] + lines[last_field + 1:]
i, m = head_match
current_type = m.group(2)
rest = path[1:]
if not rest:
raise ValueError(
f"field {head!r} already exists with type {current_type!r}"
)
new_type = extend_tuple_type(current_type, rest, leaf_type)
return lines[:i] + [f"{INDENT8}{head}*: {new_type}\n"] + lines[i + 1:]
def insert_add_tunable(lines: list[str], name: str, min_v: int, max_v: int,
default: int, quantized: bool) -> list[str]:
last_idx = max(
(i for i, l in enumerate(lines)
if "addTunableParameter(" in l and not l.lstrip().startswith("#")),
default=-1,
)
if last_idx == -1:
raise RuntimeError("no addTunableParameter calls found")
extra = ", true" if quantized else ""
new_line = f'{INDENT4}addTunableParameter("{name}", {min_v}, {max_v}, {default}{extra})\n'
return lines[:last_idx + 1] + [new_line] + lines[last_idx + 1:]
def field_access(path: list[str], index: str | None) -> str:
expr = ".".join(path)
if index is not None:
expr += f"[{index}]"
return expr
def insert_case_branches(lines: list[str], name: str, path: list[str],
index: str | None, is_float: bool) -> list[str]:
expr = field_access(path, index)
setter_rhs = "value / 1000" if is_float else "value"
getter_rhs = f"int(self.{expr} * 1000)" if is_float else f"self.{expr}"
setter_branch = (
f'{INDENT8}of "{name}":\n'
f"{INDENT12}self.{expr} = {setter_rhs}\n"
)
getter_branch = (
f'{INDENT8}of "{name}":\n'
f"{INDENT12}{getter_rhs}\n"
)
positions = [
i for i, l in enumerate(lines)
if l.rstrip() == f"{INDENT8}else:"
and i + 1 < len(lines)
and "invalid tunable parameter" in lines[i + 1]
]
if len(positions) != 2:
raise RuntimeError(
f"expected 2 invalid-tunable else clauses, found {len(positions)}"
)
setter_pos, getter_pos = positions # file order: setter first, getter second
result = lines[:]
result = result[:getter_pos] + [getter_branch] + result[getter_pos:]
result = result[:setter_pos] + [setter_branch] + result[setter_pos:]
return result
# -- main -------------------------------------------------------------------
def main() -> int:
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
p.add_argument("name", help="SPSA parameter name (e.g. NewMarginQuiet)")
p.add_argument("path", help="Field path (e.g. newMargin.quiet)")
p.add_argument("min", type=int)
p.add_argument("max", type=int)
p.add_argument("default", type=int)
p.add_argument("--quantized", action="store_true",
help="Mark the parameter as quantized")
p.add_argument("--float", dest="is_float", action="store_true",
help="Leaf type is float, scaled by 1000")
args = p.parse_args()
if args.min > args.max:
p.error("min must be <= max")
if not (args.min <= args.default <= args.max):
p.error("default must be within [min, max]")
try:
path, index = parse_path(args.path)
except ValueError as e:
p.error(str(e))
leaf_type = "float" if args.is_float else "int"
original = TUNABLES_PATH.read_text()
if f'"{args.name}"' in original:
print(f"error: parameter name {args.name!r} already appears in the file",
file=sys.stderr)
return 1
lines = original.splitlines(keepends=True)
try:
lines = modify_struct(lines, path, leaf_type, has_index=(index is not None))
lines = insert_add_tunable(
lines, args.name, args.min, args.max, args.default, args.quantized
)
lines = insert_case_branches(lines, args.name, path, index, args.is_float)
except (ValueError, RuntimeError) as e:
print(f"error: {e}", file=sys.stderr)
return 1
TUNABLES_PATH.write_text("".join(lines))
rel = TUNABLES_PATH.relative_to(REPO_ROOT)
print(f"added tunable {args.name!r} to {rel} (run `git diff -- {rel}` to review)")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -13,19 +13,26 @@ fi
echo "Bench: $BENCH"
# Find -m flag and append bench to the message
# Find the message flag and append bench to the message.
# Supports: -m, --message, --message=..., and combined short flags ending in m (-am, -Sam, ...).
args=()
found_message=false
for ((i=1; i<=$#; i++)); do
i=1
while (( i <= $# )); do
arg="${!i}"
if [[ "$arg" == "-m" ]] && (( i < $# )); then
if [[ ( "$arg" == "-m" || "$arg" == "--message" || "$arg" =~ ^-[a-zA-Z]*m$ ) ]] && (( i < $# )); then
next=$((i+1))
args+=("-m" "${!next} (bench $BENCH)")
args+=("$arg" "${!next} (bench $BENCH)")
found_message=true
i=$((i+2))
continue
elif [[ "$arg" =~ ^--message=(.*)$ ]]; then
args+=("--message=${BASH_REMATCH[1]} (bench $BENCH)")
found_message=true
((i++))
else
args+=("$arg")
fi
i=$((i+1))
done
if ! $found_message; then