Add inlining, DSP caching, fix TailCall-in-inline bug

Inlining: store IR bodies for all words, inline Call(id) when body <= 8 ops
and non-recursive. Convert TailCall back to Call when inlining (tail position
in callee is not tail position in caller -- found via compliance test failure
where inlined TailCall caused unreachable code after the call site).

DSP global caching: cache $dsp in WASM local 0 at function entry, use
local.get/set throughout, writeback before calls and at function exit.
Reduces global access instructions by ~30-40%.

323 unit tests + 11 compliance, all passing.
This commit is contained in:
2026-04-01 22:34:51 +02:00
parent 282f884a3d
commit b2cf289c36
3 changed files with 442 additions and 187 deletions
+190 -5
View File
@@ -7,6 +7,9 @@
//! 4. Dead code elimination
//! 5. Tail call detection
use std::collections::HashMap;
use crate::dictionary::WordId;
use crate::ir::IrOp;
/// Configuration for the optimization pipeline.
@@ -22,10 +25,16 @@ pub struct OptConfig {
pub strength_reduce: bool,
/// Enable dead code elimination.
pub dce: bool,
/// Enable inlining of small word bodies.
pub inline: bool,
}
/// Run all enabled optimization passes.
pub fn optimize(ops: Vec<IrOp>, config: &OptConfig) -> Vec<IrOp> {
pub fn optimize(
ops: Vec<IrOp>,
config: &OptConfig,
bodies: &HashMap<WordId, Vec<IrOp>>,
) -> Vec<IrOp> {
let mut ir = ops;
// Phase 1: simplify
@@ -42,7 +51,24 @@ pub fn optimize(ops: Vec<IrOp>, config: &OptConfig) -> Vec<IrOp> {
ir = peephole(ir);
}
// Phase 2: eliminate dead code
// Phase 2: inline then simplify again
if config.inline {
ir = inline(ir, bodies, 8);
}
if config.peephole {
ir = peephole(ir);
}
if config.constant_fold {
ir = constant_fold(ir);
}
if config.strength_reduce {
ir = strength_reduce(ir);
}
if config.peephole {
ir = peephole(ir);
}
// Phase 3: eliminate dead code
if config.dce {
ir = dce(ir);
}
@@ -50,7 +76,7 @@ pub fn optimize(ops: Vec<IrOp>, config: &OptConfig) -> Vec<IrOp> {
ir = peephole(ir);
}
// Phase 3: tail calls (must be last)
// Phase 4: tail calls (must be last)
if config.tail_call {
ir = tail_call_detect(ir);
}
@@ -378,7 +404,97 @@ fn dce(ops: Vec<IrOp>) -> Vec<IrOp> {
}
// ---------------------------------------------------------------------------
// Pass 5: Tail call detection
// Pass 6: Inlining
// ---------------------------------------------------------------------------
/// Inline small word bodies: replaces `Call(id)` with the word's IR body
/// if the body is small enough and not recursive.
fn inline(ops: Vec<IrOp>, bodies: &HashMap<WordId, Vec<IrOp>>, max_size: usize) -> Vec<IrOp> {
let mut out = Vec::new();
for op in ops {
match &op {
IrOp::Call(id) => {
if let Some(body) = bodies.get(id) {
if body.len() <= max_size && !contains_call_to(body, *id) {
// Inline the body, converting TailCall back to Call
// (tail position in the callee is not tail position in the caller)
for inlined_op in body {
match inlined_op {
IrOp::TailCall(tid) => out.push(IrOp::Call(*tid)),
other => out.push(other.clone()),
}
}
continue;
}
}
out.push(op);
}
_ => {
out.push(apply_to_bodies(op, &|inner| inline(inner, bodies, max_size)));
}
}
}
out
}
/// Check if an IR body contains a direct call to the given word (recursion guard).
fn contains_call_to(ops: &[IrOp], target: WordId) -> bool {
for op in ops {
match op {
IrOp::Call(id) | IrOp::TailCall(id) if *id == target => return true,
IrOp::If {
then_body,
else_body,
} => {
if contains_call_to(then_body, target) {
return true;
}
if let Some(eb) = else_body {
if contains_call_to(eb, target) {
return true;
}
}
}
IrOp::DoLoop { body, .. }
| IrOp::BeginUntil { body }
| IrOp::BeginAgain { body } => {
if contains_call_to(body, target) {
return true;
}
}
IrOp::BeginWhileRepeat { test, body } => {
if contains_call_to(test, target) || contains_call_to(body, target) {
return true;
}
}
IrOp::BeginDoubleWhileRepeat {
outer_test,
inner_test,
body,
after_repeat,
else_body,
} => {
if contains_call_to(outer_test, target)
|| contains_call_to(inner_test, target)
|| contains_call_to(body, target)
|| contains_call_to(after_repeat, target)
{
return true;
}
if let Some(eb) = else_body {
if contains_call_to(eb, target) {
return true;
}
}
}
_ => {}
}
}
false
}
// ---------------------------------------------------------------------------
// Pass 7: Tail call detection
// ---------------------------------------------------------------------------
/// Tail call detection: replace the last `Call` with `TailCall` when safe.
@@ -446,8 +562,24 @@ mod tests {
tail_call: true,
strength_reduce: true,
dce: true,
inline: false,
};
optimize(ops, &config)
optimize(ops, &config, &HashMap::new())
}
fn opt_with_inline(
ops: Vec<IrOp>,
bodies: &HashMap<WordId, Vec<IrOp>>,
) -> Vec<IrOp> {
let config = OptConfig {
peephole: true,
constant_fold: true,
tail_call: true,
strength_reduce: true,
dce: true,
inline: true,
};
optimize(ops, &config, bodies)
}
// Peephole tests
@@ -615,4 +747,57 @@ mod tests {
}]
);
}
// Inlining tests
#[test]
fn inline_simple() {
let mut bodies = HashMap::new();
// SQUARE = DUP *
bodies.insert(WordId(5), vec![IrOp::Dup, IrOp::Mul]);
let result = opt_with_inline(
vec![IrOp::PushI32(7), IrOp::Call(WordId(5))],
&bodies,
);
// After inlining: 7 DUP * (Dup isn't folded by constant folder)
assert_eq!(result, vec![IrOp::PushI32(7), IrOp::Dup, IrOp::Mul]);
}
#[test]
fn inline_folds_constants() {
let mut bodies = HashMap::new();
// ADD3 = 3 +
bodies.insert(WordId(5), vec![IrOp::PushI32(3), IrOp::Add]);
let result = opt_with_inline(
vec![IrOp::PushI32(5), IrOp::Call(WordId(5))],
&bodies,
);
// After inlining: PushI32(5) PushI32(3) Add => folded to PushI32(8)
assert_eq!(result, vec![IrOp::PushI32(8)]);
}
#[test]
fn no_inline_recursive() {
let mut bodies = HashMap::new();
bodies.insert(WordId(5), vec![IrOp::Dup, IrOp::Call(WordId(5))]);
let result = opt_with_inline(vec![IrOp::Call(WordId(5))], &bodies);
// Should NOT inline (recursive), but tail call detect may convert
assert!(matches!(result.last(), Some(IrOp::Call(WordId(5))) | Some(IrOp::TailCall(WordId(5)))));
}
#[test]
fn no_inline_large() {
let mut bodies = HashMap::new();
// Body with 9 ops (> max_size of 8)
bodies.insert(WordId(5), vec![IrOp::Dup; 9]);
let config = OptConfig {
peephole: false,
constant_fold: false,
tail_call: false,
strength_reduce: false,
dce: false,
inline: true,
};
let result = optimize(vec![IrOp::Call(WordId(5))], &config, &bodies);
assert_eq!(result, vec![IrOp::Call(WordId(5))]);
}
}