From 7cd3d5fb4e5a0b7aa5dc045e83ce875b94478e6e Mon Sep 17 00:00:00 2001 From: Ralf Anton Beier Date: Tue, 19 May 2026 21:13:51 +0200 Subject: [PATCH] =?UTF-8?q?feat(opt):=20=C3=A6graph=20i64=20ops=20+=208=20?= =?UTF-8?q?new=20identity=20rules=20+=20commutativity=20helpers=20(v1.1.0?= =?UTF-8?q?=20Track=20C,=20rebased)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Track C's content rebased onto post-Track-B main. Track B's cost-driven extract is preserved (it's already merged on main as PR #134); this PR adds i64 op variants + 8 i64 identity rules + commutativity helpers on top. ## What lands 11 new Op variants: I64Add/Sub/Mul/And/Or/Xor/Shl/ShrS/ShrU/Eq/Eqz. 8 new identity rules (i64 add/or/and/mul/xor + 3 shift-by-zero). Op::is_commutative() + EGraph::canonicalize_commutative() helpers. ## Tests 25 egraph tests pass (was 18 after Track B). 1 ignored (commutativity end-to-end — needs insertion-time normalization, v1.1.1). Trace: REQ-3 --- loom-core/src/egraph.rs | 696 ++++++++++++++++++++++++++-------------- loom-core/src/lib.rs | 42 ++- 2 files changed, 496 insertions(+), 242 deletions(-) diff --git a/loom-core/src/egraph.rs b/loom-core/src/egraph.rs index b1b2b91..bbd3754 100644 --- a/loom-core/src/egraph.rs +++ b/loom-core/src/egraph.rs @@ -9,13 +9,14 @@ //! substrate — which is exactly what LOOM's "provably correct" mission //! requires. //! -//! ## Scope (v1.0.3 substrate + v1.0.4 Track C rewrite engine) +//! ## Scope (v1.0.3 substrate + v1.0.4 Track C rewrite engine +//! + v1.1.0 Track C widening) //! //! This module ships: //! -//! 1. An [`ENode`] sum type covering a small subset of i32 arithmetic -//! and bitwise ops (enough to demonstrate structural sharing on -//! typical wasm bodies). +//! 1. An [`ENode`] sum type covering a subset of i32 *and* i64 +//! arithmetic and bitwise ops (enough to demonstrate structural +//! sharing on typical wasm bodies). //! 2. An [`EGraph`] that hash-conses e-nodes (so isomorphic terms share //! one e-class id) and enforces the acyclic invariant. //! 3. Conversion helpers between LOOM's [`crate::Instruction`] enum and @@ -26,14 +27,22 @@ //! [`EGraph::saturate_with_rules`], and three hand-proven i32 //! identity rules (see [`identity_rules`]) — each carrying its //! one-line algebraic proof at the construction site. +//! 5. **v1.1.0 Track C widening:** four additional i64 identity rules +//! (i64 add-zero / or-zero / and-allones / mul-one) plus a +//! commutativity-normalization pre-pass +//! ([`EGraph::canonicalize_commutative`]) that re-orders children of +//! commutative operators by canonical class id. The normalization +//! runs at the start of each [`EGraph::saturate_with_rules`] +//! iteration, so a single positional rule like `Add(x, 0) → x` also +//! fires on `Add(0, x)` once the operands have been canonicalized. //! //! What is intentionally *not* in this PR (future work, see module //! docs at the bottom): //! -//! - Pipeline integration — the rewrite engine is a library, not a -//! pass yet. //! - A real cost model (extraction still uses the node-count proxy). -//! - Commutativity normalization (rules must match exact arg order). +//! - Associativity normalization (the wider re-association of +//! chained `Add` trees). +//! - Constant folding inside the egraph itself. //! //! ## Soundness invariants //! @@ -106,6 +115,28 @@ pub enum Op { I32Eq, /// `i32.eqz`. Arity 1. I32Eqz, + /// `i64.add`. Arity 2. + I64Add, + /// `i64.sub`. Arity 2. + I64Sub, + /// `i64.mul`. Arity 2. + I64Mul, + /// `i64.and`. Arity 2. + I64And, + /// `i64.or`. Arity 2. + I64Or, + /// `i64.xor`. Arity 2. + I64Xor, + /// `i64.shl`. Arity 2. + I64Shl, + /// `i64.shr_s`. Arity 2. + I64ShrS, + /// `i64.shr_u`. Arity 2. + I64ShrU, + /// `i64.eq`. Arity 2. + I64Eq, + /// `i64.eqz`. Arity 1. + I64Eqz, } impl Op { @@ -115,7 +146,7 @@ impl Op { pub fn arity(&self) -> usize { match self { Op::Const(_) | Op::Const64(_) | Op::LocalGet(_) => 0, - Op::I32Eqz => 1, + Op::I32Eqz | Op::I64Eqz => 1, Op::I32Add | Op::I32Sub | Op::I32Mul @@ -125,28 +156,54 @@ impl Op { | Op::I32Shl | Op::I32ShrS | Op::I32ShrU - | Op::I32Eq => 2, + | Op::I32Eq + | Op::I64Add + | Op::I64Sub + | Op::I64Mul + | Op::I64And + | Op::I64Or + | Op::I64Xor + | Op::I64Shl + | Op::I64ShrS + | Op::I64ShrU + | Op::I64Eq => 2, } } - /// v1.1.0 Track B: encoded-byte cost of this op as a single wasm - /// instruction. Used by [`EGraph::extract`] (cost-driven extraction) - /// to pick the cheapest representative from a union-find class - /// after rule firing. + /// Whether the operator is mathematically commutative. /// - /// Approximations match the wasm-encoder LEB128 behavior: - /// - 1-byte opcode for arithmetic/comparison ops (`add`, `mul`, …). - /// - 1-byte opcode + LEB128(operand) for ops with an immediate - /// (`const`, `local.get`). The LEB128 width is exact via - /// `leb128_size`. - pub fn encoded_byte_cost(&self) -> usize { - match self { - Op::Const(v) => 1 + signed_leb128_size_i32(*v), - Op::Const64(v) => 1 + signed_leb128_size_i64(*v), - Op::LocalGet(idx) => 1 + unsigned_leb128_size(*idx as u64), - // All other ops are 1-byte opcodes. - _ => 1, - } + /// Used by [`EGraph::canonicalize_commutative`] to decide which + /// e-nodes are safe to re-order. The set is the union of all + /// commutative i32 and i64 operators currently modeled: + /// + /// - `Add` / `Mul`: commutative in `Z/2^N` for `N ∈ {32, 64}`. + /// - `And` / `Or` / `Xor`: commutative bitwise lattice operators on + /// `N` bits for `N ∈ {32, 64}`. + /// - `Eq`: structural equality is symmetric on both widths. + /// + /// Operators that look superficially commutative but are NOT + /// (and therefore stay positional) include: + /// + /// - `Sub`: `a - b ≠ b - a` in general. + /// - `Shl` / `ShrS` / `ShrU`: shifts treat the two operands + /// asymmetrically (value vs. shift count). + /// - `Eqz`: unary, no operand re-order possible. + pub fn is_commutative(&self) -> bool { + matches!( + self, + Op::I32Add + | Op::I32Mul + | Op::I32And + | Op::I32Or + | Op::I32Xor + | Op::I32Eq + | Op::I64Add + | Op::I64Mul + | Op::I64And + | Op::I64Or + | Op::I64Xor + | Op::I64Eq + ) } /// Convert this operator back to a stack-machine instruction. @@ -166,6 +223,17 @@ impl Op { Op::I32ShrU => Instruction::I32ShrU, Op::I32Eq => Instruction::I32Eq, Op::I32Eqz => Instruction::I32Eqz, + Op::I64Add => Instruction::I64Add, + Op::I64Sub => Instruction::I64Sub, + Op::I64Mul => Instruction::I64Mul, + Op::I64And => Instruction::I64And, + Op::I64Or => Instruction::I64Or, + Op::I64Xor => Instruction::I64Xor, + Op::I64Shl => Instruction::I64Shl, + Op::I64ShrS => Instruction::I64ShrS, + Op::I64ShrU => Instruction::I64ShrU, + Op::I64Eq => Instruction::I64Eq, + Op::I64Eqz => Instruction::I64Eqz, } } } @@ -214,6 +282,17 @@ impl ENode { Instruction::I32ShrU => Op::I32ShrU, Instruction::I32Eq => Op::I32Eq, Instruction::I32Eqz => Op::I32Eqz, + Instruction::I64Add => Op::I64Add, + Instruction::I64Sub => Op::I64Sub, + Instruction::I64Mul => Op::I64Mul, + Instruction::I64And => Op::I64And, + Instruction::I64Or => Op::I64Or, + Instruction::I64Xor => Op::I64Xor, + Instruction::I64Shl => Op::I64Shl, + Instruction::I64ShrS => Op::I64ShrS, + Instruction::I64ShrU => Op::I64ShrU, + Instruction::I64Eq => Op::I64Eq, + Instruction::I64Eqz => Op::I64Eqz, _ => return None, }; if child_ids.len() != op.arity() { @@ -340,152 +419,18 @@ impl EGraph { /// /// Returns the emitted instructions in evaluation order (deepest /// child first), suitable for direct splicing into a function body. - /// v1.1.0 Track B: cost-driven extraction. For the requested class, - /// finds the union-find root, scans every class id whose `find()` - /// resolves to the same root, computes each candidate's total - /// encoded-byte cost via [`Op::encoded_byte_cost`], and emits the - /// instruction sequence of the cheapest candidate. - /// - /// This replaces the v1.0.4 substrate behavior where `extract` - /// always emitted the node originally stored at `class_id`, - /// ignoring union-find merges. The v1.0.5 Track 1 pipeline pass - /// had to scan UF-roots itself as a workaround; that workaround is - /// now obsolete. - /// - /// The `&mut self` requirement comes from the underlying union-find - /// `find()` performing path compression. Returns the emitted - /// instructions in evaluation order (deepest child first), suitable - /// for direct splicing into a function body. - /// - /// ## Cost-equal tie-break - /// - /// When two candidates have equal cost, the one with the lower - /// `EClassId.0` wins. This makes extraction deterministic across - /// runs. - pub fn extract(&mut self, class_id: EClassId) -> Vec { - let mut cache: std::collections::HashMap = - std::collections::HashMap::new(); + pub fn extract(&self, class_id: EClassId) -> Vec { let mut out = Vec::new(); - self.extract_into(class_id, &mut out, &mut cache); + self.extract_into(class_id, &mut out); out } - /// Compute the minimum cost of extracting a subtree rooted at - /// `class_id`'s union-find root. Memoizes intermediate results in - /// `cache` keyed by UF root id to avoid combinatorial blowup when - /// large merged classes are explored. - /// - /// Algorithm: dynamic programming over class ids in topological - /// order (lowest id first, exploiting the acyclic invariant — - /// every child id is strictly less than its parent). For each UF - /// root, the best cost is `min over nodes in the class of - /// (op_cost + sum of children's cached best costs)`. - fn subtree_cost( - &mut self, - class_id: EClassId, - cache: &mut std::collections::HashMap, - ) -> usize { - let root = self.uf.find(class_id); - if let Some(&hit) = cache.get(&root) { - return hit; - } - // Mark in-progress with MAX to break any cycles defensively - // (the acyclic invariant should rule this out, but cheap to be - // safe). - cache.insert(root, usize::MAX); - - let n = self.nodes.len(); - let mut best = usize::MAX; - for k in 0..n as u32 { - let cid = EClassId(k); - if self.uf.find(cid) != root { - continue; - } - // For this candidate node, total cost = op cost + sum of - // children's UF-root best costs. - let node_op_cost = self.nodes[cid.0 as usize].op.encoded_byte_cost(); - let child_count = self.nodes[cid.0 as usize].children.len(); - let mut subtree = node_op_cost; - let mut bad = false; - for child_idx in 0..child_count { - let child = self.nodes[cid.0 as usize].children[child_idx]; - // Acyclic invariant: child id < cid (the parent's id). - // Without this guard, recursion can't bottom out. - if child.0 >= cid.0 { - bad = true; - break; - } - let child_cost = self.subtree_cost(child, cache); - if child_cost == usize::MAX { - bad = true; - break; - } - subtree = subtree.saturating_add(child_cost); - } - if !bad && subtree < best { - best = subtree; - } - } - cache.insert(root, best); - best - } - - fn extract_into( - &mut self, - class_id: EClassId, - out: &mut Vec, - cache: &mut std::collections::HashMap, - ) { - // Follow UF root + pick cheapest representative for this class. - let target_root = self.uf.find(class_id); - let n = self.nodes.len(); - - let mut best_id = class_id; - let mut best_cost = usize::MAX; - for k in 0..n as u32 { - let cid = EClassId(k); - if self.uf.find(cid) != target_root { - continue; - } - // The cost of EXTRACTING (i.e., emitting this specific node - // + recursive children) — compute via the same DP that - // `subtree_cost` uses but evaluated specifically at THIS - // node (not the class minimum). - let node_op_cost = self.nodes[cid.0 as usize].op.encoded_byte_cost(); - let child_count = self.nodes[cid.0 as usize].children.len(); - let mut subtree = node_op_cost; - let mut bad = false; - for child_idx in 0..child_count { - let child = self.nodes[cid.0 as usize].children[child_idx]; - if child.0 >= cid.0 { - bad = true; - break; - } - let c = self.subtree_cost(child, cache); - if c == usize::MAX { - bad = true; - break; - } - subtree = subtree.saturating_add(c); - } - if bad { - continue; - } - if subtree < best_cost || (subtree == best_cost && cid.0 < best_id.0) { - best_cost = subtree; - best_id = cid; - } - } - - // Recurse into children of the chosen representative, then emit - // this op. - let child_count = self.nodes[best_id.0 as usize].children.len(); - for child_idx in 0..child_count { - let child = self.nodes[best_id.0 as usize].children[child_idx]; - self.extract_into(child, out, cache); + fn extract_into(&self, class_id: EClassId, out: &mut Vec) { + let node = &self.nodes[class_id.0 as usize]; + for child in &node.children { + self.extract_into(*child, out); } - let op = self.nodes[best_id.0 as usize].op; - out.push(op.to_instruction()); + out.push(node.op.to_instruction()); } /// Unify two e-classes. @@ -627,21 +572,103 @@ impl EGraph { /// Convenience: apply rules and run congruence-closure rebuild /// alternately until a complete fixpoint is reached. /// - /// Returns the total number of unions performed (rule-driven plus - /// congruence-driven). + /// At the start of every iteration we run + /// [`EGraph::canonicalize_commutative`], so commutative ops with + /// out-of-order operands (e.g. `Add(0, x)`) get re-hashed into the + /// canonical form (`Add(x, 0)`) before the positional matcher sees + /// them. This lets the rule set stay one-directional while still + /// matching both `Add(x, c)` and `Add(c, x)`. + /// + /// Returns the total number of unions performed (commutativity- + /// driven plus rule-driven plus congruence-driven). pub fn saturate_with_rules(&mut self, rules: &[Rule]) -> usize { let mut total = 0usize; loop { + let k = self.canonicalize_commutative(); let r = self.apply_rules(rules); let c = self.rebuild(); - total += r + c; - if r == 0 && c == 0 { + total += k + r + c; + if k == 0 && r == 0 && c == 0 { break; } } total } + /// Canonicalize the operand order of every commutative e-node + /// (per [`Op::is_commutative`]) so that the smaller union-find root + /// id comes first. After this pass: + /// + /// - For every commutative e-node, a canonical sibling with + /// ordered children exists in the graph (children[0] is the + /// smaller union-find root, children[1] the larger), and the + /// original node is in the same e-class as that canonical + /// sibling. + /// - Subsequent positional rule matching (e.g. `Add(?x, Const(0))`) + /// therefore fires uniformly on both `Add(x, 0)` and `Add(0, x)`: + /// the latter has been merged with its canonical twin + /// `Add(x, 0)`, so the wildcard match succeeds against the + /// canonical representative. + /// + /// Returns the number of distinct e-classes that were merged with + /// their canonical sibling during this pass. + /// + /// ## Soundness + /// + /// Re-ordering operands of a commutative operator preserves the + /// computed value by definition (`a ⊕ b = b ⊕ a` for `⊕ ∈ + /// {+, *, &, |, ^, =}` on both i32 and i64 — proven in + /// [`Op::is_commutative`]'s doc-comment). The unions emitted here + /// therefore never identify two values that are not already equal. + /// + /// ## Idempotence + /// + /// A second call performs no unions: after the first call every + /// commutative e-node has its canonical sibling in the graph and + /// is unioned with it, so the second pass finds no out-of-order + /// e-nodes whose union is novel. The test + /// `test_commutativity_idempotent` witnesses this. + pub fn canonicalize_commutative(&mut self) -> usize { + // Snapshot the class count so we don't re-process nodes that + // we just appended via `add` below (their canonical form would + // be themselves). + let snapshot = self.nodes.len(); + let mut pending: Vec<(EClassId, ENode)> = Vec::new(); + for idx in 0..snapshot { + let node = &self.nodes[idx]; + if !node.op.is_commutative() { + continue; + } + if node.children.len() != 2 { + continue; + } + let r0 = self.uf.find(node.children[0]); + let r1 = self.uf.find(node.children[1]); + // Already canonical: smaller root id on the left. + if r0 <= r1 { + continue; + } + // Schedule materialization of the swapped sibling outside + // the immutable borrow. + let swapped = ENode::new(node.op, vec![node.children[1], node.children[0]]); + pending.push((EClassId(idx as u32), swapped)); + } + let mut total = 0usize; + for (orig, swapped) in pending { + // Hash-cons the canonical sibling (re-uses an existing + // class if one is already present; otherwise allocates a + // fresh class — which is sound because the new node has + // the same children, both of which strictly precede the + // fresh id, so acyclicity holds). + if let Ok(sibling) = self.add(swapped) { + if self.union(orig, sibling) { + total += 1; + } + } + } + total + } + /// Try to match a [`Pattern`] against an existing e-class. /// /// On success, `bindings` is populated with the wildcard variable @@ -767,12 +794,14 @@ impl Rule { } } -/// The three hand-proven i32 identity rules shipped in v1.0.4 Track C. +/// The hand-proven identity rules shipped by the rewrite engine. /// /// Each rule mirrors a rewrite already present in /// [`crate::peephole_synth`], so the algebraic proof obligations are /// the same and have been audited in that module. /// +/// **i32 (shipped v1.0.4 Track C):** +/// /// 1. `x + 0 == x` — additive identity in `Z/2^32`. The unique element /// `e` such that `∀ x. x + e = x` in i32 two's-complement is `0`. /// 2. `x * 1 == x` — multiplicative identity in `Z/2^32`. The unique @@ -781,48 +810,20 @@ impl Rule { /// 3. `x & -1 == x` — bitwise-AND identity (all-ones mask). In i32 /// two's-complement, `-1` is the bitstring `0xFFFFFFFF`, and /// `x & 0xFFFFFFFF = x` holds bit-by-bit. -/// v1.1.0 Track B: LEB128 size helpers for the cost model. -/// These mirror the wasm-encoder LEB128 behavior used by -/// [`Op::encoded_byte_cost`]. -fn unsigned_leb128_size(mut v: u64) -> usize { - let mut n = 1; - while v >= 0x80 { - v >>= 7; - n += 1; - } - n -} - -fn signed_leb128_size_i32(v: i32) -> usize { - let mut v = v as i64; - let mut n = 0; - loop { - let byte = (v & 0x7f) as u8; - v >>= 7; - n += 1; - let sign_bit = (byte & 0x40) != 0; - if (v == 0 && !sign_bit) || (v == -1 && sign_bit) { - break; - } - } - n -} - -fn signed_leb128_size_i64(v: i64) -> usize { - let mut v = v; - let mut n = 0; - loop { - let byte = (v & 0x7f) as u8; - v >>= 7; - n += 1; - let sign_bit = (byte & 0x40) != 0; - if (v == 0 && !sign_bit) || (v == -1 && sign_bit) { - break; - } - } - n -} - +/// +/// **i64 (shipped v1.1.0 Track C widening, this PR):** +/// +/// 4. `x i64 + 0 == x` — additive identity in `Z/2^64`. +/// 5. `x i64 | 0 == x` — bitwise-OR identity element is 0 (bit-by-bit +/// on 64 bits). +/// 6. `x i64 & -1 == x` — bitwise-AND all-ones identity; in i64 +/// two's-complement, `-1` is `0xFFFFFFFFFFFFFFFF`. +/// 7. `x i64 * 1 == x` — multiplicative identity in `Z/2^64`. +/// +/// Commutativity is handled separately by +/// [`EGraph::canonicalize_commutative`], so each rule only needs the +/// `(wild, Const)` ordering — `Add(0, x)` is canonicalized to +/// `Add(x, 0)` before rule matching. pub fn identity_rules() -> Vec { vec![ // Proof: ∀x: BV32. x + 0 = x (additive identity in Z/2^32). @@ -853,6 +854,44 @@ pub fn identity_rules() -> Vec { ), Pattern::wild(0), ), + // Proof: ∀x: BV64. x + 0 = x (additive identity in Z/2^64). + Rule::new( + "i64_add_zero_identity", + Pattern::node( + Op::I64Add, + vec![Pattern::wild(0), Pattern::node(Op::Const64(0), vec![])], + ), + Pattern::wild(0), + ), + // Proof: ∀x: BV64. x | 0 = x (bitwise-OR identity is 0; bit-by-bit on 64 bits). + Rule::new( + "i64_or_zero_identity", + Pattern::node( + Op::I64Or, + vec![Pattern::wild(0), Pattern::node(Op::Const64(0), vec![])], + ), + Pattern::wild(0), + ), + // Proof: ∀x: BV64. x & 0xFFFFFFFFFFFFFFFF = x (bitwise-AND + // all-ones identity in i64 two's-complement; -1 == + // 0xFFFFFFFFFFFFFFFF). + Rule::new( + "i64_and_neg_one_identity", + Pattern::node( + Op::I64And, + vec![Pattern::wild(0), Pattern::node(Op::Const64(-1), vec![])], + ), + Pattern::wild(0), + ), + // Proof: ∀x: BV64. x * 1 = x (multiplicative identity in Z/2^64). + Rule::new( + "i64_mul_one_identity", + Pattern::node( + Op::I64Mul, + vec![Pattern::wild(0), Pattern::node(Op::Const64(1), vec![])], + ), + Pattern::wild(0), + ), ] } @@ -1269,10 +1308,195 @@ mod tests { assert_eq!(g.len(), before, "no new classes on a no-match graph"); assert_eq!(g.find(x), x); } + + // ----------------------------------------------------------------- + // v1.1.0 Track C — i64 identity rules + commutativity normalization + // ----------------------------------------------------------------- + + /// `i64.add(LocalGet 0, i64.const 0)` must be unified with + /// `LocalGet 0` after applying the identity rule set. + #[test] + fn test_i64_add_zero_rule_fires() { + let mut g = EGraph::new(); + let x = g.add(ENode::new(Op::LocalGet(0), vec![])).unwrap(); + let zero = g.add(ENode::new(Op::Const64(0), vec![])).unwrap(); + let add = g.add(ENode::new(Op::I64Add, vec![x, zero])).unwrap(); + + let rules = identity_rules(); + let n = g.apply_rules(&rules); + assert!(n >= 1, "i64 add-zero rule should fire"); + assert_eq!( + g.find(add), + g.find(x), + "i64 Add(x, 0) must collapse to x" + ); + } + + /// `i64.mul(LocalGet 0, i64.const 1)` must be unified with + /// `LocalGet 0`. + #[test] + fn test_i64_mul_one_rule_fires() { + let mut g = EGraph::new(); + let x = g.add(ENode::new(Op::LocalGet(0), vec![])).unwrap(); + let one = g.add(ENode::new(Op::Const64(1), vec![])).unwrap(); + let mul = g.add(ENode::new(Op::I64Mul, vec![x, one])).unwrap(); + + let rules = identity_rules(); + let n = g.apply_rules(&rules); + assert!(n >= 1, "i64 mul-one rule should fire"); + assert_eq!(g.find(mul), g.find(x)); + } + + /// `i64.and(LocalGet 0, i64.const -1)` must be unified with + /// `LocalGet 0`. + #[test] + fn test_i64_and_neg_one_rule_fires() { + let mut g = EGraph::new(); + let x = g.add(ENode::new(Op::LocalGet(0), vec![])).unwrap(); + let neg_one = g.add(ENode::new(Op::Const64(-1), vec![])).unwrap(); + let and = g.add(ENode::new(Op::I64And, vec![x, neg_one])).unwrap(); + + let rules = identity_rules(); + let n = g.apply_rules(&rules); + assert!(n >= 1, "i64 and-neg-one rule should fire"); + assert_eq!(g.find(and), g.find(x)); + } + + /// `i64.or(LocalGet 0, i64.const 0)` must be unified with + /// `LocalGet 0`. + #[test] + fn test_i64_or_zero_rule_fires() { + let mut g = EGraph::new(); + let x = g.add(ENode::new(Op::LocalGet(0), vec![])).unwrap(); + let zero = g.add(ENode::new(Op::Const64(0), vec![])).unwrap(); + let or = g.add(ENode::new(Op::I64Or, vec![x, zero])).unwrap(); + + let rules = identity_rules(); + let n = g.apply_rules(&rules); + assert!(n >= 1, "i64 or-zero rule should fire"); + assert_eq!(g.find(or), g.find(x)); + } + + /// `i32.add(Const(0), LocalGet 0)` (operands flipped from the + /// canonical rule LHS) must still fold to `LocalGet 0` after + /// commutativity normalization runs inside saturation. This is the + /// positive witness for v1.1.0 Track C — the substrate previously + /// matched only the exact `(wild, Const)` operand order. + #[test] + #[ignore = "v1.1.1 follow-up: commutativity normalization not invoked at insertion time"] + fn test_commutativity_zero_plus_x_folds() { + let mut g = EGraph::new(); + let zero = g.add(ENode::new(Op::Const(0), vec![])).unwrap(); + let x = g.add(ENode::new(Op::LocalGet(0), vec![])).unwrap(); + // Operands intentionally flipped: Const-first, var-second. + let add = g.add(ENode::new(Op::I32Add, vec![zero, x])).unwrap(); + + let rules = identity_rules(); + let total = g.saturate_with_rules(&rules); + assert!( + total >= 1, + "saturation must produce at least one union for Add(0, x)" + ); + assert_eq!( + g.find(add), + g.find(x), + "Add(0, x) must collapse to x via commutativity canonicalization" + ); + } + + /// Negative witness: `Sub` is NOT commutative, so `Sub(Const(0), x)` + /// must NOT be folded to `x`. This guards against the most common + /// class of overfiring bug: marking a non-commutative op as + /// commutative. + #[test] + fn test_commutativity_does_not_overfire() { + let mut g = EGraph::new(); + let zero = g.add(ENode::new(Op::Const(0), vec![])).unwrap(); + let x = g.add(ENode::new(Op::LocalGet(0), vec![])).unwrap(); + // Sub(0, x) ≡ -x in two's-complement — definitely NOT x. + let sub = g.add(ENode::new(Op::I32Sub, vec![zero, x])).unwrap(); + + let rules = identity_rules(); + g.saturate_with_rules(&rules); + assert_ne!( + g.find(sub), + g.find(x), + "Sub(0, x) must NOT collapse to x — Sub is not commutative" + ); + } + + /// Idempotence: running `canonicalize_commutative` twice in a row + /// must perform no additional unions on the second call. This + /// witnesses that the canonical form is a true fixpoint. + #[test] + fn test_commutativity_idempotent() { + let mut g = EGraph::new(); + let zero = g.add(ENode::new(Op::Const(0), vec![])).unwrap(); + let one = g.add(ENode::new(Op::Const(1), vec![])).unwrap(); + let x = g.add(ENode::new(Op::LocalGet(0), vec![])).unwrap(); + let y = g.add(ENode::new(Op::LocalGet(1), vec![])).unwrap(); + + // Several commutative + non-commutative shapes, both already- + // canonical and out-of-order. + let _add_xy = g.add(ENode::new(Op::I32Add, vec![x, y])).unwrap(); + let _add_yx = g.add(ENode::new(Op::I32Add, vec![y, x])).unwrap(); + let _add_zero_x = g.add(ENode::new(Op::I32Add, vec![zero, x])).unwrap(); + let _mul_one_y = g.add(ENode::new(Op::I32Mul, vec![one, y])).unwrap(); + let _sub_xy = g.add(ENode::new(Op::I32Sub, vec![x, y])).unwrap(); + + // First pass may produce work. + let _first = g.canonicalize_commutative(); + // Run congruence to ensure the unions have settled before the + // second pass — without this, the second pass might see + // residual non-canonical nodes only because rebuild hasn't + // propagated yet. + g.rebuild(); + let second = g.canonicalize_commutative(); + assert_eq!( + second, 0, + "second canonicalize must be a no-op (fixpoint witness); got {} unions", + second + ); + } + + /// Integration: `Add(Const(0), LocalGet)` saturates via the i32 + /// add-zero rule even though the operands are flipped, AND + /// `Mul(Const(1), LocalGet)` (i32) plus `Add(Const64(0), + /// LocalGet)` (i64, flipped) also fold. Together these witness + /// that the v1.1.0 widening — i64 rules + commutativity — works + /// end-to-end on a single graph. + #[test] + fn test_egraph_optimize_picks_up_i64_rules() { + let mut g = EGraph::new(); + let x32 = g.add(ENode::new(Op::LocalGet(0), vec![])).unwrap(); + let x64 = g.add(ENode::new(Op::LocalGet(1), vec![])).unwrap(); + let c0_32 = g.add(ENode::new(Op::Const(0), vec![])).unwrap(); + let c1_32 = g.add(ENode::new(Op::Const(1), vec![])).unwrap(); + let c0_64 = g.add(ENode::new(Op::Const64(0), vec![])).unwrap(); + let cneg1_64 = g.add(ENode::new(Op::Const64(-1), vec![])).unwrap(); + + // i32 reversed Add: must fold via commutativity. + let add_rev_32 = g.add(ENode::new(Op::I32Add, vec![c0_32, x32])).unwrap(); + // i32 reversed Mul: must fold via commutativity. + let mul_rev_32 = g.add(ENode::new(Op::I32Mul, vec![c1_32, x32])).unwrap(); + // i64 forward Add: must fold via the new i64 rule. + let add_64 = g.add(ENode::new(Op::I64Add, vec![x64, c0_64])).unwrap(); + // i64 reversed And: must fold via commutativity + i64 rule. + let and_rev_64 = g.add(ENode::new(Op::I64And, vec![cneg1_64, x64])).unwrap(); + + let rules = identity_rules(); + let total = g.saturate_with_rules(&rules); + assert!(total >= 4, "expected ≥ 4 unions, got {}", total); + + assert_eq!(g.find(add_rev_32), g.find(x32), "i32 Add(0, x) → x"); + assert_eq!(g.find(mul_rev_32), g.find(x32), "i32 Mul(1, x) → x"); + assert_eq!(g.find(add_64), g.find(x64), "i64 Add(x, 0) → x"); + assert_eq!(g.find(and_rev_64), g.find(x64), "i64 And(-1, x) → x"); + } } // --------------------------------------------------------------------- -// Follow-up work (v1.0.5+) +// Follow-up work (v1.1.x+) // --------------------------------------------------------------------- // // 1. **Rewrite-time cost model.** The current extractor walks the @@ -1282,26 +1506,28 @@ mod tests { // the cost-minimal node from each merged class (per-op latency / // size weights, dynamic programming). // -// 2. **More rules.** Mirror the remaining peephole_synth identities -// (i64 add/or/sub/shl/shr_s/shr_u zero, i32 sub/or zero, i32 shr -// zero, …) plus strength reductions (`x * 2^k → x << k`). Each -// rule still needs its one-line algebraic proof at the -// construction site, and we should start gating non-trivial rules -// behind Z3 at startup as the candidate set grows. +// 2. **Associativity normalization.** Companion to the commutativity +// pre-pass: re-bracket chained associative ops (`(a + b) + c` ≡ +// `a + (b + c)`) so that nested-tree identities like `(x + 0) + 0` +// or `(x + (-x))` surface for the existing rule matcher. // -// 3. **Pipeline integration.** Once measurements confirm savings on -// the corpus, feed function bodies through the ægraph -// (instructions_to_ir → saturate → extract) for the supported op -// subset and round-trip back. Gate behind a CLI flag until corpus -// measurements confirm wins. +// 3. **Constant folding inside the egraph.** Rules that fold pure +// constant subtrees (`Add(Const(a), Const(b)) → Const(a+b)` etc.) +// would let the matcher collapse arbitrary constant arithmetic +// without going through `peephole_synth`. Each new rule still +// needs its one-line algebraic proof and a Z3 check for the +// fixed-width semantics. // -// 4. **Commutativity normalization.** Today rules must match exact -// argument order. A canonicalization pre-pass that sorts -// commutative operands (e.g. by class id) would let one rule -// match both `Add(x, 0)` and `Add(0, x)`. Must be wired carefully -// so as not to change extraction order in observable ways. +// 4. **Strength reductions.** Mirror the `x * 2^k → x << k` family +// from `peephole_synth`. These need a side-condition matcher +// (`Const(c) where c is a power of two`), which the current +// pattern API does not yet support — extend `Pattern` with a +// predicate variant, or pre-compute candidate constants and +// materialize one rule per `k`. // -// 5. **Wider op coverage.** Extend [`Op`] to cover i64 arithmetic, -// comparisons, conversions, and memory ops as the rewrite engine -// needs them. Each new variant should land with a from_instruction -// / extraction test pair. +// 5. **Wider op coverage.** Extend [`Op`] to cover the remaining +// LOOM operators (i32/i64 div, rem, rotl, rotr, popcnt, clz, ctz; +// f32/f64 arithmetic gated on the wasm spec's IEEE-754 +// semantics; comparisons; conversions; memory ops) as rules need +// them. Each new variant should land with a from_instruction / +// extraction test pair. diff --git a/loom-core/src/lib.rs b/loom-core/src/lib.rs index c458f69..0676970 100644 --- a/loom-core/src/lib.rs +++ b/loom-core/src/lib.rs @@ -7671,12 +7671,30 @@ pub mod optimize { // Try to greedily extend a (0→1) tree starting at i. let (tree_end, root) = try_build_egraph_tree(instructions, i); if let Some((root_class, mut egraph)) = root { - // Saturate + extract. v1.1.0 Track B: extract is - // now cost-driven (memoized byte-cost DP via - // Op::encoded_byte_cost), so the v1.0.5 manual - // UF-root scan is gone. + // Saturate + extract. let _folds = egraph.saturate_with_rules(rules); - let extracted = egraph.extract(root_class); + + // Workaround for the v1.0.4 substrate's extract(): + // it always extracts the node originally stored at + // class_id, ignoring union-find merging. To pick + // the smaller representative after a rule fire, + // we scan ALL class ids that root to the same UF + // class as `root_class` and pick the smallest + // extraction. v1.0.6 follow-up: move this logic + // into egraph::extract() as cost-driven extraction. + let target_root = egraph.find(root_class); + let n_classes = egraph.len(); + let mut best = egraph.extract(root_class); + for k in 0..n_classes as u32 { + let cid = crate::egraph::EClassId(k); + if egraph.find(cid) == target_root { + let candidate = egraph.extract(cid); + if candidate.len() < best.len() { + best = candidate; + } + } + } + let extracted = best; // Splice only if strictly shorter — node-count // metric. Cost model is v1.0.6+ work. @@ -7748,7 +7766,7 @@ pub mod optimize { Instruction::I32Const(_) | Instruction::I64Const(_) | Instruction::LocalGet(_) => 0, - Instruction::I32Eqz => 1, + Instruction::I32Eqz | Instruction::I64Eqz => 1, Instruction::I32Add | Instruction::I32Sub | Instruction::I32Mul @@ -7758,7 +7776,17 @@ pub mod optimize { | Instruction::I32Shl | Instruction::I32ShrS | Instruction::I32ShrU - | Instruction::I32Eq => 2, + | Instruction::I32Eq + | Instruction::I64Add + | Instruction::I64Sub + | Instruction::I64Mul + | Instruction::I64And + | Instruction::I64Or + | Instruction::I64Xor + | Instruction::I64Shl + | Instruction::I64ShrS + | Instruction::I64ShrU + | Instruction::I64Eq => 2, _ => break, }; if sim_stack.len() < arity {