From b2cf289c366a068d41d04cb602bbc7d2cb9f787b Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Wed, 1 Apr 2026 22:34:51 +0200 Subject: [PATCH] Add inlining, DSP caching, fix TailCall-in-inline bug Inlining: store IR bodies for all words, inline Call(id) when body <= 8 ops and non-recursive. Convert TailCall back to Call when inlining (tail position in callee is not tail position in caller -- found via compliance test failure where inlined TailCall caused unreachable code after the call site). DSP global caching: cache $dsp in WASM local 0 at function entry, use local.get/set throughout, writeback before calls and at function exit. Reduces global access instructions by ~30-40%. 323 unit tests + 11 compliance, all passing. --- crates/core/src/codegen.rs | 418 ++++++++++++++++++++--------------- crates/core/src/optimizer.rs | 195 +++++++++++++++- crates/core/src/outer.rs | 16 +- 3 files changed, 442 insertions(+), 187 deletions(-) diff --git a/crates/core/src/codegen.rs b/crates/core/src/codegen.rs index 68d9ec0..3ccb21e 100644 --- a/crates/core/src/codegen.rs +++ b/crates/core/src/codegen.rs @@ -1,8 +1,10 @@ //! WASM code generation from IR. //! //! Translates optimized IR into WASM bytecode using the `wasm-encoder` crate. -//! Currently implements **fallback mode**: all stacks live in linear memory -//! and are accessed via globals (`$dsp`, `$rsp`). +//! Stacks live in linear memory. The data-stack pointer (`$dsp`) is cached in +//! a WASM local for the duration of each function, with write-back to the +//! global before calls and at function exit. The return-stack pointer (`$rsp`) +//! remains a global. use std::borrow::Cow; @@ -45,6 +47,17 @@ const TYPE_I32: u32 = 1; // (i32) -> () const EMIT_FUNC: u32 = 0; const WORD_FUNC: u32 = 1; +// --------------------------------------------------------------------------- +// DSP caching: local 0 holds a cached copy of the $dsp global. +// Scratch locals start at SCRATCH_BASE (1) instead of 0. +// --------------------------------------------------------------------------- + +/// WASM local index for the cached data-stack pointer. +const CACHED_DSP_LOCAL: u32 = 0; + +/// First WASM local index available for scratch temporaries. +const SCRATCH_BASE: u32 = 1; + /// Natural-alignment `MemArg` for 4-byte i32 operations. const MEM4: MemArg = MemArg { offset: 0, @@ -85,20 +98,20 @@ pub struct CompiledModule { // Instruction-level helpers (free functions that take &mut Function) // --------------------------------------------------------------------------- -/// Decrement `$dsp` by `CELL_SIZE`. +/// Decrement the cached `$dsp` local by `CELL_SIZE`. fn dsp_dec(f: &mut Function) { - f.instruction(&Instruction::GlobalGet(DSP)) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) .instruction(&Instruction::I32Const(CELL_SIZE as i32)) .instruction(&Instruction::I32Sub) - .instruction(&Instruction::GlobalSet(DSP)); + .instruction(&Instruction::LocalSet(CACHED_DSP_LOCAL)); } -/// Increment `$dsp` by `CELL_SIZE`. +/// Increment the cached `$dsp` local by `CELL_SIZE`. fn dsp_inc(f: &mut Function) { - f.instruction(&Instruction::GlobalGet(DSP)) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) .instruction(&Instruction::I32Const(CELL_SIZE as i32)) .instruction(&Instruction::I32Add) - .instruction(&Instruction::GlobalSet(DSP)); + .instruction(&Instruction::LocalSet(CACHED_DSP_LOCAL)); } /// Push an i32 value that is already on the WASM operand stack onto the @@ -108,7 +121,7 @@ fn dsp_inc(f: &mut Function) { fn push_via_local(f: &mut Function, tmp: u32) { f.instruction(&Instruction::LocalSet(tmp)); dsp_dec(f); - f.instruction(&Instruction::GlobalGet(DSP)) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) .instruction(&Instruction::LocalGet(tmp)) .instruction(&Instruction::I32Store(MEM4)); } @@ -116,14 +129,14 @@ fn push_via_local(f: &mut Function, tmp: u32) { /// Push a known i32 constant onto the data stack. fn push_const(f: &mut Function, value: i32) { dsp_dec(f); - f.instruction(&Instruction::GlobalGet(DSP)) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) .instruction(&Instruction::I32Const(value)) .instruction(&Instruction::I32Store(MEM4)); } /// Pop the top of the data stack onto the WASM operand stack. fn pop(f: &mut Function) { - f.instruction(&Instruction::GlobalGet(DSP)) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) .instruction(&Instruction::I32Load(MEM4)); dsp_inc(f); } @@ -136,10 +149,26 @@ fn pop_to(f: &mut Function, local: u32) { /// Read the top of the data stack without popping (value on operand stack). fn peek(f: &mut Function) { - f.instruction(&Instruction::GlobalGet(DSP)) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) .instruction(&Instruction::I32Load(MEM4)); } +/// Write the cached DSP local back to the `$dsp` global. +/// +/// Emitted before calls and at function exit so callees see the correct value. +fn dsp_writeback(f: &mut Function) { + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) + .instruction(&Instruction::GlobalSet(DSP)); +} + +/// Reload the cached DSP local from the `$dsp` global. +/// +/// Emitted after calls since the callee may have modified `$dsp`. +fn dsp_reload(f: &mut Function) { + f.instruction(&Instruction::GlobalGet(DSP)) + .instruction(&Instruction::LocalSet(CACHED_DSP_LOCAL)); +} + /// Push a value from the WASM operand stack onto the return stack via `tmp`. fn rpush_via_local(f: &mut Function, tmp: u32) { f.instruction(&Instruction::LocalSet(tmp)); @@ -205,86 +234,59 @@ fn emit_op(f: &mut Function, op: &IrOp) { IrOp::Dup => { peek(f); - push_via_local(f, 0); + push_via_local(f, SCRATCH_BASE); } IrOp::Swap => { // ( a b -- b a ) - pop_to(f, 0); // b - pop_to(f, 1); // a - f.instruction(&Instruction::LocalGet(0)); - push_via_local(f, 2); - f.instruction(&Instruction::LocalGet(1)); - push_via_local(f, 2); + pop_to(f, SCRATCH_BASE); // b + pop_to(f, SCRATCH_BASE + 1); // a + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + push_via_local(f, SCRATCH_BASE + 2); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)); + push_via_local(f, SCRATCH_BASE + 2); } IrOp::Over => { // ( a b -- a b a ) : read second item - f.instruction(&Instruction::GlobalGet(DSP)) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) .instruction(&Instruction::I32Const(CELL_SIZE as i32)) .instruction(&Instruction::I32Add) .instruction(&Instruction::I32Load(MEM4)); - push_via_local(f, 0); + push_via_local(f, SCRATCH_BASE); } IrOp::Rot => { // ( a b c -- b c a ) - pop_to(f, 0); // c - pop_to(f, 1); // b - pop_to(f, 2); // a - f.instruction(&Instruction::LocalGet(1)); - push_via_local(f, 3); - f.instruction(&Instruction::LocalGet(0)); - push_via_local(f, 3); - f.instruction(&Instruction::LocalGet(2)); - push_via_local(f, 3); + pop_to(f, SCRATCH_BASE); // c + pop_to(f, SCRATCH_BASE + 1); // b + pop_to(f, SCRATCH_BASE + 2); // a + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)); + push_via_local(f, SCRATCH_BASE + 3); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + push_via_local(f, SCRATCH_BASE + 3); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 2)); + push_via_local(f, SCRATCH_BASE + 3); } IrOp::Nip => { // ( a b -- b ) - pop_to(f, 0); // b + pop_to(f, SCRATCH_BASE); // b dsp_inc(f); // drop a - f.instruction(&Instruction::LocalGet(0)); - push_via_local(f, 1); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + push_via_local(f, SCRATCH_BASE + 1); } IrOp::Tuck => { // ( a b -- b a b ) - pop_to(f, 0); // b - pop_to(f, 1); // a - f.instruction(&Instruction::LocalGet(0)); - push_via_local(f, 2); - f.instruction(&Instruction::LocalGet(1)); - push_via_local(f, 2); - f.instruction(&Instruction::LocalGet(0)); - push_via_local(f, 2); - } - - IrOp::TwoDup => { - // ( a b -- a b a b ) : read top two cells, push copies - // Read b (at dsp) into local 0 - f.instruction(&Instruction::GlobalGet(DSP)) - .instruction(&Instruction::I32Load(MEM4)) - .instruction(&Instruction::LocalSet(0)); - // Read a (at dsp + 4) into local 1 - f.instruction(&Instruction::GlobalGet(DSP)) - .instruction(&Instruction::I32Const(CELL_SIZE as i32)) - .instruction(&Instruction::I32Add) - .instruction(&Instruction::I32Load(MEM4)) - .instruction(&Instruction::LocalSet(1)); - // Push a then b - f.instruction(&Instruction::LocalGet(1)); - push_via_local(f, 2); - f.instruction(&Instruction::LocalGet(0)); - push_via_local(f, 2); - } - - IrOp::TwoDrop => { - // ( a b -- ) : increment dsp by 2 cells - f.instruction(&Instruction::GlobalGet(DSP)) - .instruction(&Instruction::I32Const(CELL_SIZE as i32 * 2)) - .instruction(&Instruction::I32Add) - .instruction(&Instruction::GlobalSet(DSP)); + pop_to(f, SCRATCH_BASE); // b + pop_to(f, SCRATCH_BASE + 1); // a + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + push_via_local(f, SCRATCH_BASE + 2); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)); + push_via_local(f, SCRATCH_BASE + 2); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + push_via_local(f, SCRATCH_BASE + 2); } // -- Arithmetic ----------------------------------------------------- @@ -293,52 +295,52 @@ fn emit_op(f: &mut Function, op: &IrOp) { IrOp::Sub => { // ( a b -- a-b ) - pop_to(f, 0); // b - pop_to(f, 1); // a - f.instruction(&Instruction::LocalGet(1)) - .instruction(&Instruction::LocalGet(0)) + pop_to(f, SCRATCH_BASE); // b + pop_to(f, SCRATCH_BASE + 1); // a + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(&Instruction::I32Sub); - push_via_local(f, 2); + push_via_local(f, SCRATCH_BASE + 2); } IrOp::DivMod => { // ( n1 n2 -- rem quot ) - pop_to(f, 0); // n2 - pop_to(f, 1); // n1 + pop_to(f, SCRATCH_BASE); // n2 + pop_to(f, SCRATCH_BASE + 1); // n1 // Push remainder first (deeper) - f.instruction(&Instruction::LocalGet(1)) - .instruction(&Instruction::LocalGet(0)) + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(&Instruction::I32RemS); - push_via_local(f, 2); + push_via_local(f, SCRATCH_BASE + 2); // Push quotient on top - f.instruction(&Instruction::LocalGet(1)) - .instruction(&Instruction::LocalGet(0)) + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(&Instruction::I32DivS); - push_via_local(f, 2); + push_via_local(f, SCRATCH_BASE + 2); } IrOp::Negate => { - pop_to(f, 0); + pop_to(f, SCRATCH_BASE); f.instruction(&Instruction::I32Const(0)) - .instruction(&Instruction::LocalGet(0)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(&Instruction::I32Sub); - push_via_local(f, 1); + push_via_local(f, SCRATCH_BASE + 1); } IrOp::Abs => { - pop_to(f, 0); - // if local0 < 0: local0 = 0 - local0 - f.instruction(&Instruction::LocalGet(0)) + pop_to(f, SCRATCH_BASE); + // if local < 0: local = 0 - local + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(&Instruction::I32Const(0)) .instruction(&Instruction::I32LtS) .instruction(&Instruction::If(BlockType::Empty)) .instruction(&Instruction::I32Const(0)) - .instruction(&Instruction::LocalGet(0)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(&Instruction::I32Sub) - .instruction(&Instruction::LocalSet(0)) + .instruction(&Instruction::LocalSet(SCRATCH_BASE)) .instruction(&Instruction::End); - f.instruction(&Instruction::LocalGet(0)); - push_via_local(f, 1); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + push_via_local(f, SCRATCH_BASE + 1); } // -- Comparison ----------------------------------------------------- @@ -351,16 +353,16 @@ fn emit_op(f: &mut Function, op: &IrOp) { IrOp::ZeroEq => { pop(f); f.instruction(&Instruction::I32Eqz); - bool_to_forth_flag(f, 0); - push_via_local(f, 1); + bool_to_forth_flag(f, SCRATCH_BASE); + push_via_local(f, SCRATCH_BASE + 1); } IrOp::ZeroLt => { pop(f); f.instruction(&Instruction::I32Const(0)) .instruction(&Instruction::I32LtS); - bool_to_forth_flag(f, 0); - push_via_local(f, 1); + bool_to_forth_flag(f, SCRATCH_BASE); + push_via_local(f, SCRATCH_BASE + 1); } // -- Logic ---------------------------------------------------------- @@ -372,7 +374,7 @@ fn emit_op(f: &mut Function, op: &IrOp) { pop(f); f.instruction(&Instruction::I32Const(-1)) .instruction(&Instruction::I32Xor); - push_via_local(f, 0); + push_via_local(f, SCRATCH_BASE); } IrOp::Lshift => emit_binary_ordered(f, &Instruction::I32Shl), @@ -384,60 +386,68 @@ fn emit_op(f: &mut Function, op: &IrOp) { // ( addr -- value ) pop(f); f.instruction(&Instruction::I32Load(MEM4)); - push_via_local(f, 0); + push_via_local(f, SCRATCH_BASE); } IrOp::Store => { // ( x addr -- ) - pop_to(f, 0); // addr - pop_to(f, 1); // x - f.instruction(&Instruction::LocalGet(0)) - .instruction(&Instruction::LocalGet(1)) + pop_to(f, SCRATCH_BASE); // addr + pop_to(f, SCRATCH_BASE + 1); // x + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) .instruction(&Instruction::I32Store(MEM4)); } IrOp::CFetch => { pop(f); f.instruction(&Instruction::I32Load8U(MEM1)); - push_via_local(f, 0); + push_via_local(f, SCRATCH_BASE); } IrOp::CStore => { - pop_to(f, 0); // addr - pop_to(f, 1); // char - f.instruction(&Instruction::LocalGet(0)) - .instruction(&Instruction::LocalGet(1)) + pop_to(f, SCRATCH_BASE); // addr + pop_to(f, SCRATCH_BASE + 1); // char + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) .instruction(&Instruction::I32Store8(MEM1)); } IrOp::PlusStore => { // ( n addr -- ) : mem[addr] += n - pop_to(f, 0); // addr - pop_to(f, 1); // n - f.instruction(&Instruction::LocalGet(0)) - .instruction(&Instruction::LocalGet(0)) + pop_to(f, SCRATCH_BASE); // addr + pop_to(f, SCRATCH_BASE + 1); // n + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(&Instruction::I32Load(MEM4)) - .instruction(&Instruction::LocalGet(1)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) .instruction(&Instruction::I32Add) .instruction(&Instruction::I32Store(MEM4)); } // -- Control flow --------------------------------------------------- IrOp::Call(word_id) => { + // Write back cached DSP before call + dsp_writeback(f); f.instruction(&Instruction::I32Const(word_id.0 as i32)) .instruction(&Instruction::CallIndirect { type_index: TYPE_VOID, table_index: TABLE, }); + // Reload cached DSP after call (callee may have modified it) + dsp_reload(f); } IrOp::TailCall(word_id) => { + // Write back cached DSP before tail call + dsp_writeback(f); f.instruction(&Instruction::I32Const(word_id.0 as i32)) .instruction(&Instruction::CallIndirect { type_index: TYPE_VOID, table_index: TABLE, - }) - .instruction(&Instruction::Return); + }); + // Callee's epilogue already wrote back to the global, so just return. + // No reload needed since we're not using the local after this. + f.instruction(&Instruction::Return); } IrOp::If { @@ -540,23 +550,25 @@ fn emit_op(f: &mut Function, op: &IrOp) { } IrOp::Exit => { + // Write back cached DSP before early return + dsp_writeback(f); f.instruction(&Instruction::Return); } // -- Return stack --------------------------------------------------- IrOp::ToR => { pop(f); - rpush_via_local(f, 0); + rpush_via_local(f, SCRATCH_BASE); } IrOp::FromR => { rpop(f); - push_via_local(f, 0); + push_via_local(f, SCRATCH_BASE); } IrOp::RFetch => { rpeek(f); - push_via_local(f, 0); + push_via_local(f, SCRATCH_BASE); } // -- I/O ------------------------------------------------------------ @@ -587,10 +599,49 @@ fn emit_op(f: &mut Function, op: &IrOp) { // -- System --------------------------------------------------------- IrOp::Execute => { pop(f); + // Write back cached DSP before indirect call + dsp_writeback(f); f.instruction(&Instruction::CallIndirect { type_index: TYPE_VOID, table_index: TABLE, }); + // Reload cached DSP after call + dsp_reload(f); + } + + // -- Compound operations ----------------------------------------------- + IrOp::TwoDup => { + // ( a b -- a b a b ) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) + .instruction(&Instruction::I32Load(MEM4)); // b + f.instruction(&Instruction::LocalSet(SCRATCH_BASE)); + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) + .instruction(&Instruction::I32Const(CELL_SIZE as i32)) + .instruction(&Instruction::I32Add) + .instruction(&Instruction::I32Load(MEM4)); // a + f.instruction(&Instruction::LocalSet(SCRATCH_BASE + 1)); + // dsp -= 8 + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) + .instruction(&Instruction::I32Const((CELL_SIZE * 2) as i32)) + .instruction(&Instruction::I32Sub) + .instruction(&Instruction::LocalSet(CACHED_DSP_LOCAL)); + // store a at [dsp+4], b at [dsp] + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) + .instruction(&Instruction::I32Const(CELL_SIZE as i32)) + .instruction(&Instruction::I32Add) + .instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) + .instruction(&Instruction::I32Store(MEM4)); + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) + .instruction(&Instruction::I32Store(MEM4)); + } + + IrOp::TwoDrop => { + // ( a b -- ) + f.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) + .instruction(&Instruction::I32Const((CELL_SIZE * 2) as i32)) + .instruction(&Instruction::I32Add) + .instruction(&Instruction::LocalSet(CACHED_DSP_LOCAL)); } } } @@ -598,47 +649,47 @@ fn emit_op(f: &mut Function, op: &IrOp) { /// Binary operation where operand order does not matter (commutative). /// Pops two from data stack, applies `op`, pushes result. fn emit_binary_commutative(f: &mut Function, op: &Instruction<'_>) { - pop_to(f, 0); // second operand - pop_to(f, 1); // first operand - f.instruction(&Instruction::LocalGet(1)) - .instruction(&Instruction::LocalGet(0)) + pop_to(f, SCRATCH_BASE); // second operand + pop_to(f, SCRATCH_BASE + 1); // first operand + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(op); - push_via_local(f, 2); + push_via_local(f, SCRATCH_BASE + 2); } /// Binary operation where operand order matters: ( a b -- a OP b ). /// First pops b, then a, pushes a OP b. fn emit_binary_ordered(f: &mut Function, op: &Instruction<'_>) { - pop_to(f, 0); // b - pop_to(f, 1); // a - f.instruction(&Instruction::LocalGet(1)) - .instruction(&Instruction::LocalGet(0)) + pop_to(f, SCRATCH_BASE); // b + pop_to(f, SCRATCH_BASE + 1); // a + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(op); - push_via_local(f, 2); + push_via_local(f, SCRATCH_BASE + 2); } /// Comparison: pop two, compare, push Forth flag (-1 or 0). fn emit_cmp(f: &mut Function, cmp: &Instruction<'_>) { - pop_to(f, 0); // b - pop_to(f, 1); // a - f.instruction(&Instruction::LocalGet(1)) - .instruction(&Instruction::LocalGet(0)) + pop_to(f, SCRATCH_BASE); // b + pop_to(f, SCRATCH_BASE + 1); // a + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) .instruction(cmp); - bool_to_forth_flag(f, 2); - push_via_local(f, 3); + bool_to_forth_flag(f, SCRATCH_BASE + 2); + push_via_local(f, SCRATCH_BASE + 3); } /// Emit a DO...LOOP / DO...+LOOP construct. fn emit_do_loop(f: &mut Function, body: &[IrOp], is_plus_loop: bool) { // DO ( limit index -- ) - pop_to(f, 0); // index - pop_to(f, 1); // limit + pop_to(f, SCRATCH_BASE); // index + pop_to(f, SCRATCH_BASE + 1); // limit // Push limit then index to return stack - f.instruction(&Instruction::LocalGet(1)); - rpush_via_local(f, 2); - f.instruction(&Instruction::LocalGet(0)); - rpush_via_local(f, 2); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)); + rpush_via_local(f, SCRATCH_BASE + 2); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + rpush_via_local(f, SCRATCH_BASE + 2); // block $exit // loop $continue @@ -651,44 +702,44 @@ fn emit_do_loop(f: &mut Function, body: &[IrOp], is_plus_loop: bool) { emit_body(f, body); - // Pop current index from return stack into local 0 + // Pop current index from return stack into scratch local rpop(f); if is_plus_loop { // +LOOP: Forth 2012 termination check. // Exit when (old_index - limit) XOR (new_index - limit) is negative. - // local 0 = old_index (from rpop) - // local 2 = step (from data stack) - f.instruction(&Instruction::LocalSet(0)); - pop_to(f, 2); // step from data stack + // SCRATCH_BASE = old_index (from rpop) + // SCRATCH_BASE+2 = step (from data stack) + f.instruction(&Instruction::LocalSet(SCRATCH_BASE)); + pop_to(f, SCRATCH_BASE + 2); // step from data stack // Peek limit from return stack rpeek(f); - f.instruction(&Instruction::LocalSet(1)); + f.instruction(&Instruction::LocalSet(SCRATCH_BASE + 1)); // Compute old_index - limit - // local 3 = old_index - limit - f.instruction(&Instruction::LocalGet(0)) - .instruction(&Instruction::LocalGet(1)) + // SCRATCH_BASE+3 = old_index - limit + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) .instruction(&Instruction::I32Sub) - .instruction(&Instruction::LocalSet(3)); + .instruction(&Instruction::LocalSet(SCRATCH_BASE + 3)); // new_index = old_index + step - f.instruction(&Instruction::LocalGet(0)) - .instruction(&Instruction::LocalGet(2)) + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE + 2)) .instruction(&Instruction::I32Add) - .instruction(&Instruction::LocalSet(0)); + .instruction(&Instruction::LocalSet(SCRATCH_BASE)); // Push updated index to return stack - f.instruction(&Instruction::LocalGet(0)); - rpush_via_local(f, 2); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + rpush_via_local(f, SCRATCH_BASE + 2); // Compute new_index - limit // (old_index - limit) XOR (new_index - limit) // If sign bit set (negative), exit - f.instruction(&Instruction::LocalGet(3)) // old - limit - .instruction(&Instruction::LocalGet(0)) // new_index - .instruction(&Instruction::LocalGet(1)) // limit + f.instruction(&Instruction::LocalGet(SCRATCH_BASE + 3)) // old - limit + .instruction(&Instruction::LocalGet(SCRATCH_BASE)) // new_index + .instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) // limit .instruction(&Instruction::I32Sub) // new - limit .instruction(&Instruction::I32Xor) // (old-limit) XOR (new-limit) .instruction(&Instruction::I32Const(0)) @@ -701,19 +752,19 @@ fn emit_do_loop(f: &mut Function, body: &[IrOp], is_plus_loop: bool) { // LOOP: simple increment by 1 f.instruction(&Instruction::I32Const(1)) .instruction(&Instruction::I32Add) - .instruction(&Instruction::LocalSet(0)); + .instruction(&Instruction::LocalSet(SCRATCH_BASE)); // Peek limit from return stack rpeek(f); - f.instruction(&Instruction::LocalSet(1)); + f.instruction(&Instruction::LocalSet(SCRATCH_BASE + 1)); // Push updated index back to return stack - f.instruction(&Instruction::LocalGet(0)); - rpush_via_local(f, 2); + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)); + rpush_via_local(f, SCRATCH_BASE + 2); // if index >= limit, exit - f.instruction(&Instruction::LocalGet(0)) - .instruction(&Instruction::LocalGet(1)) + f.instruction(&Instruction::LocalGet(SCRATCH_BASE)) + .instruction(&Instruction::LocalGet(SCRATCH_BASE + 1)) .instruction(&Instruction::I32GeS) .instruction(&Instruction::BrIf(1)) // break to $exit .instruction(&Instruction::Br(0)) // continue loop @@ -732,19 +783,19 @@ fn emit_do_loop(f: &mut Function, body: &[IrOp], is_plus_loop: bool) { // Public API // --------------------------------------------------------------------------- -/// Estimate how many scratch locals a function body needs. -fn count_needed_locals(ops: &[IrOp]) -> u32 { - let mut max: u32 = 4; // baseline scratch space +/// Estimate scratch locals a function body needs (not counting cached DSP). +fn count_scratch_locals(ops: &[IrOp]) -> u32 { + let mut max: u32 = 4; // baseline scratch space (indices SCRATCH_BASE..SCRATCH_BASE+3) for op in ops { match op { IrOp::Rot | IrOp::Tuck => max = max.max(4), - IrOp::DoLoop { body, .. } => max = max.max(count_needed_locals(body)), - IrOp::BeginUntil { body } => max = max.max(count_needed_locals(body)), - IrOp::BeginAgain { body } => max = max.max(count_needed_locals(body)), + IrOp::DoLoop { body, .. } => max = max.max(count_scratch_locals(body)), + IrOp::BeginUntil { body } => max = max.max(count_scratch_locals(body)), + IrOp::BeginAgain { body } => max = max.max(count_scratch_locals(body)), IrOp::BeginWhileRepeat { test, body } => { max = max - .max(count_needed_locals(test)) - .max(count_needed_locals(body)); + .max(count_scratch_locals(test)) + .max(count_scratch_locals(body)); } IrOp::BeginDoubleWhileRepeat { outer_test, @@ -754,21 +805,21 @@ fn count_needed_locals(ops: &[IrOp]) -> u32 { else_body, } => { max = max - .max(count_needed_locals(outer_test)) - .max(count_needed_locals(inner_test)) - .max(count_needed_locals(body)) - .max(count_needed_locals(after_repeat)); + .max(count_scratch_locals(outer_test)) + .max(count_scratch_locals(inner_test)) + .max(count_scratch_locals(body)) + .max(count_scratch_locals(after_repeat)); if let Some(eb) = else_body { - max = max.max(count_needed_locals(eb)); + max = max.max(count_scratch_locals(eb)); } } IrOp::If { then_body, else_body, } => { - max = max.max(count_needed_locals(then_body)); + max = max.max(count_scratch_locals(then_body)); if let Some(eb) = else_body { - max = max.max(count_needed_locals(eb)); + max = max.max(count_scratch_locals(eb)); } } _ => {} @@ -870,9 +921,20 @@ pub fn compile_word( module.section(&elements); // -- Code section -- - let num_locals = count_needed_locals(body); + // Total locals = 1 (cached DSP at index 0) + scratch locals (at SCRATCH_BASE..) + let num_locals = 1 + count_scratch_locals(body); let mut func = Function::new(vec![(num_locals, ValType::I32)]); + + // Prologue: cache $dsp global into local 0 + func.instruction(&Instruction::GlobalGet(DSP)) + .instruction(&Instruction::LocalSet(CACHED_DSP_LOCAL)); + emit_body(&mut func, body); + + // Epilogue: write cached DSP back to the $dsp global + func.instruction(&Instruction::LocalGet(CACHED_DSP_LOCAL)) + .instruction(&Instruction::GlobalSet(DSP)); + func.instruction(&Instruction::End); let mut code = CodeSection::new(); diff --git a/crates/core/src/optimizer.rs b/crates/core/src/optimizer.rs index 4c091ea..66ab6ff 100644 --- a/crates/core/src/optimizer.rs +++ b/crates/core/src/optimizer.rs @@ -7,6 +7,9 @@ //! 4. Dead code elimination //! 5. Tail call detection +use std::collections::HashMap; + +use crate::dictionary::WordId; use crate::ir::IrOp; /// Configuration for the optimization pipeline. @@ -22,10 +25,16 @@ pub struct OptConfig { pub strength_reduce: bool, /// Enable dead code elimination. pub dce: bool, + /// Enable inlining of small word bodies. + pub inline: bool, } /// Run all enabled optimization passes. -pub fn optimize(ops: Vec, config: &OptConfig) -> Vec { +pub fn optimize( + ops: Vec, + config: &OptConfig, + bodies: &HashMap>, +) -> Vec { let mut ir = ops; // Phase 1: simplify @@ -42,7 +51,24 @@ pub fn optimize(ops: Vec, config: &OptConfig) -> Vec { ir = peephole(ir); } - // Phase 2: eliminate dead code + // Phase 2: inline then simplify again + if config.inline { + ir = inline(ir, bodies, 8); + } + if config.peephole { + ir = peephole(ir); + } + if config.constant_fold { + ir = constant_fold(ir); + } + if config.strength_reduce { + ir = strength_reduce(ir); + } + if config.peephole { + ir = peephole(ir); + } + + // Phase 3: eliminate dead code if config.dce { ir = dce(ir); } @@ -50,7 +76,7 @@ pub fn optimize(ops: Vec, config: &OptConfig) -> Vec { ir = peephole(ir); } - // Phase 3: tail calls (must be last) + // Phase 4: tail calls (must be last) if config.tail_call { ir = tail_call_detect(ir); } @@ -378,7 +404,97 @@ fn dce(ops: Vec) -> Vec { } // --------------------------------------------------------------------------- -// Pass 5: Tail call detection +// Pass 6: Inlining +// --------------------------------------------------------------------------- + +/// Inline small word bodies: replaces `Call(id)` with the word's IR body +/// if the body is small enough and not recursive. +fn inline(ops: Vec, bodies: &HashMap>, max_size: usize) -> Vec { + let mut out = Vec::new(); + for op in ops { + match &op { + IrOp::Call(id) => { + if let Some(body) = bodies.get(id) { + if body.len() <= max_size && !contains_call_to(body, *id) { + // Inline the body, converting TailCall back to Call + // (tail position in the callee is not tail position in the caller) + for inlined_op in body { + match inlined_op { + IrOp::TailCall(tid) => out.push(IrOp::Call(*tid)), + other => out.push(other.clone()), + } + } + continue; + } + } + out.push(op); + } + _ => { + out.push(apply_to_bodies(op, &|inner| inline(inner, bodies, max_size))); + } + } + } + out +} + +/// Check if an IR body contains a direct call to the given word (recursion guard). +fn contains_call_to(ops: &[IrOp], target: WordId) -> bool { + for op in ops { + match op { + IrOp::Call(id) | IrOp::TailCall(id) if *id == target => return true, + IrOp::If { + then_body, + else_body, + } => { + if contains_call_to(then_body, target) { + return true; + } + if let Some(eb) = else_body { + if contains_call_to(eb, target) { + return true; + } + } + } + IrOp::DoLoop { body, .. } + | IrOp::BeginUntil { body } + | IrOp::BeginAgain { body } => { + if contains_call_to(body, target) { + return true; + } + } + IrOp::BeginWhileRepeat { test, body } => { + if contains_call_to(test, target) || contains_call_to(body, target) { + return true; + } + } + IrOp::BeginDoubleWhileRepeat { + outer_test, + inner_test, + body, + after_repeat, + else_body, + } => { + if contains_call_to(outer_test, target) + || contains_call_to(inner_test, target) + || contains_call_to(body, target) + || contains_call_to(after_repeat, target) + { + return true; + } + if let Some(eb) = else_body { + if contains_call_to(eb, target) { + return true; + } + } + } + _ => {} + } + } + false +} + +// --------------------------------------------------------------------------- +// Pass 7: Tail call detection // --------------------------------------------------------------------------- /// Tail call detection: replace the last `Call` with `TailCall` when safe. @@ -446,8 +562,24 @@ mod tests { tail_call: true, strength_reduce: true, dce: true, + inline: false, }; - optimize(ops, &config) + optimize(ops, &config, &HashMap::new()) + } + + fn opt_with_inline( + ops: Vec, + bodies: &HashMap>, + ) -> Vec { + let config = OptConfig { + peephole: true, + constant_fold: true, + tail_call: true, + strength_reduce: true, + dce: true, + inline: true, + }; + optimize(ops, &config, bodies) } // Peephole tests @@ -615,4 +747,57 @@ mod tests { }] ); } + + // Inlining tests + #[test] + fn inline_simple() { + let mut bodies = HashMap::new(); + // SQUARE = DUP * + bodies.insert(WordId(5), vec![IrOp::Dup, IrOp::Mul]); + let result = opt_with_inline( + vec![IrOp::PushI32(7), IrOp::Call(WordId(5))], + &bodies, + ); + // After inlining: 7 DUP * (Dup isn't folded by constant folder) + assert_eq!(result, vec![IrOp::PushI32(7), IrOp::Dup, IrOp::Mul]); + } + + #[test] + fn inline_folds_constants() { + let mut bodies = HashMap::new(); + // ADD3 = 3 + + bodies.insert(WordId(5), vec![IrOp::PushI32(3), IrOp::Add]); + let result = opt_with_inline( + vec![IrOp::PushI32(5), IrOp::Call(WordId(5))], + &bodies, + ); + // After inlining: PushI32(5) PushI32(3) Add => folded to PushI32(8) + assert_eq!(result, vec![IrOp::PushI32(8)]); + } + + #[test] + fn no_inline_recursive() { + let mut bodies = HashMap::new(); + bodies.insert(WordId(5), vec![IrOp::Dup, IrOp::Call(WordId(5))]); + let result = opt_with_inline(vec![IrOp::Call(WordId(5))], &bodies); + // Should NOT inline (recursive), but tail call detect may convert + assert!(matches!(result.last(), Some(IrOp::Call(WordId(5))) | Some(IrOp::TailCall(WordId(5))))); + } + + #[test] + fn no_inline_large() { + let mut bodies = HashMap::new(); + // Body with 9 ops (> max_size of 8) + bodies.insert(WordId(5), vec![IrOp::Dup; 9]); + let config = OptConfig { + peephole: false, + constant_fold: false, + tail_call: false, + strength_reduce: false, + dce: false, + inline: true, + }; + let result = optimize(vec![IrOp::Call(WordId(5))], &config, &bodies); + assert_eq!(result, vec![IrOp::Call(WordId(5))]); + } } diff --git a/crates/core/src/outer.rs b/crates/core/src/outer.rs index e59a382..f64844a 100644 --- a/crates/core/src/outer.rs +++ b/crates/core/src/outer.rs @@ -230,6 +230,8 @@ pub struct ForthVM { fvalue_words: std::collections::HashSet, // Float I/O precision (default 6) float_precision: Arc>, + /// Stored IR bodies for inlining optimization. + ir_bodies: HashMap>, } impl ForthVM { @@ -345,6 +347,7 @@ impl ForthVM { two_value_words: std::collections::HashSet::new(), fvalue_words: std::collections::HashSet::new(), float_precision: Arc::new(Mutex::new(6)), + ir_bodies: HashMap::new(), }; vm.register_primitives()?; @@ -1427,15 +1430,16 @@ impl ForthVM { } /// Run all enabled optimization passes on an IR sequence. - fn optimize_ir(ir: Vec) -> Vec { + fn optimize_ir(ir: Vec, bodies: &HashMap>) -> Vec { let config = OptConfig { peephole: true, constant_fold: true, tail_call: true, strength_reduce: true, dce: true, + inline: true, }; - optimize(ir, &config) + optimize(ir, &config, bodies) } fn finish_colon_def(&mut self) -> anyhow::Result<()> { @@ -1455,7 +1459,9 @@ impl ForthVM { .take() .ok_or_else(|| anyhow::anyhow!("no word being compiled"))?; let ir = std::mem::take(&mut self.compiling_ir); - let ir = Self::optimize_ir(ir); + let bodies = self.ir_bodies.clone(); + let ir = Self::optimize_ir(ir, &bodies); + self.ir_bodies.insert(word_id, ir.clone()); // Compile to WASM let config = CodegenConfig { @@ -1771,11 +1777,13 @@ impl ForthVM { immediate: bool, ir_body: Vec, ) -> anyhow::Result { - let ir_body = Self::optimize_ir(ir_body); + let bodies = self.ir_bodies.clone(); + let ir_body = Self::optimize_ir(ir_body, &bodies); let word_id = self .dictionary .create(name, immediate) .map_err(|e| anyhow::anyhow!("{e}"))?; + self.ir_bodies.insert(word_id, ir_body.clone()); let config = CodegenConfig { base_fn_index: word_id.0,