From daef175af758e568ab39694168bdc5c4c5443133 Mon Sep 17 00:00:00 2001 From: stringhandler Date: Thu, 28 May 2026 14:12:53 +0200 Subject: [PATCH 1/2] feat: add enum declarations and multi-arm enum match expressions Introduces `enum` with explicit u8 discriminants and N-arm `match` over enum variants, desugared into `jet::eq_8` comparison chains at the AST level. Missing witness values are zero-filled before Simplicity witness population; a post-prune check errors if any zero-filled witness appears on a surviving (non-pruned) branch. --- examples/last_will.inherit.wit | 10 +- examples/last_will.simf | 22 +- src/ast.rs | 587 ++++++++++++++++++++++++++++++++- src/lexer.rs | 42 ++- src/lib.rs | 188 ++++++++++- src/named.rs | 62 ++++ src/parse.rs | 443 +++++++++++++++++++++---- src/value.rs | 28 ++ src/witness.rs | 60 +++- test-data/last_will.json | 2 +- 10 files changed, 1352 insertions(+), 92 deletions(-) diff --git a/examples/last_will.inherit.wit b/examples/last_will.inherit.wit index 16752030..88b46ab7 100644 --- a/examples/last_will.inherit.wit +++ b/examples/last_will.inherit.wit @@ -1,6 +1,10 @@ { - "INHERIT_OR_NOT": { - "value": "Left(0x755201bb62b0a8b8d18fd12fc02951ea3998ba42bfc6664daaf8a0d2298cad43cdc21358c7c82f37654275dc2fea8c858adbe97bac92828b498a5a237004db6f)", - "type": "Either>" + "ACTION": { + "value": "1", + "type": "u8" + }, + "INHERITOR_SIG": { + "value": "0x755201bb62b0a8b8d18fd12fc02951ea3998ba42bfc6664daaf8a0d2298cad43cdc21358c7c82f37654275dc2fea8c858adbe97bac92828b498a5a237004db6f", + "type": "Signature" } } diff --git a/examples/last_will.simf b/examples/last_will.simf index 9790a1cf..aab2bf73 100644 --- a/examples/last_will.simf +++ b/examples/last_will.simf @@ -40,12 +40,22 @@ fn refresh_spend(hot_sig: Signature) { recursive_covenant(); } +enum Action { + Inherit=1, + ColdSpend =2, + HotSpend =3, +} + fn main() { - match witness::INHERIT_OR_NOT { - Left(inheritor_sig: Signature) => inherit_spend(inheritor_sig), - Right(cold_or_hot: Either) => match cold_or_hot { - Left(cold_sig: Signature) => cold_spend(cold_sig), - Right(hot_sig: Signature) => refresh_spend(hot_sig), - }, + match witness::ACTION { + Action::Inherit => { + let inheritor_sig: Signature = witness::INHERITOR_SIG; + inherit_spend(inheritor_sig)} , + Action::ColdSpend => { + let cold_sig: Signature = witness::COLD_SIG; + cold_spend(cold_sig) }, + Action::HotSpend => { + let hot_sig: Signature = witness::HOT_SIG; + refresh_spend(hot_sig) }, } } diff --git a/src/ast.rs b/src/ast.rs index f7e64549..5ef4c72a 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,5 +1,5 @@ use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::num::NonZeroUsize; use std::sync::Arc; @@ -18,7 +18,7 @@ use crate::str::{AliasName, FunctionName, Identifier, ModuleName, SymbolName, Wi use crate::types::{ AliasedType, ResolvedType, StructuralType, TypeConstructible, TypeDeconstructible, UIntType, }; -use crate::value::{UIntValue, Value}; +use crate::value::{UIntValue, Value, ValueConstructible}; use crate::witness::{Parameters, WitnessTypes}; use crate::{impl_eq_hash, parse}; @@ -558,6 +558,62 @@ struct ModuleScope { submodules: HashMap, } +/// A single enum variant after analysis: its name and u8 discriminant, without source span. +#[derive(Clone, Debug, Eq, PartialEq)] +struct ResolvedEnumVariant { + name: Identifier, + discriminant: u8, +} + +/// The resolved definition of an enum as stored in [`Scope`]: +/// a list of [`ResolvedEnumVariant`]s in declaration order. +#[derive(Clone, Debug, Eq, PartialEq)] +struct EnumBinding { + variants: Arc<[ResolvedEnumVariant]>, +} + +impl EnumBinding { + fn new(variants: Arc<[ResolvedEnumVariant]>) -> Self { + Self { variants } + } + + fn variants(&self) -> &[ResolvedEnumVariant] { + &self.variants + } + + fn contains_variant(&self, name: &Identifier) -> bool { + self.variants.iter().any(|v| &v.name == name) + } +} + +/// A single enum variant after analysis: its name and u8 discriminant, without source span. +#[derive(Clone, Debug, Eq, PartialEq)] +struct ResolvedEnumVariant { + name: Identifier, + discriminant: u8, +} + +/// The resolved definition of an enum as stored in [`Scope`]: +/// a list of [`ResolvedEnumVariant`]s in declaration order. +#[derive(Clone, Debug, Eq, PartialEq)] +struct EnumBinding { + variants: Arc<[ResolvedEnumVariant]>, +} + +impl EnumBinding { + fn new(variants: Arc<[ResolvedEnumVariant]>) -> Self { + Self { variants } + } + + fn variants(&self) -> &[ResolvedEnumVariant] { + &self.variants + } + + fn contains_variant(&self, name: &Identifier) -> bool { + self.variants.iter().any(|v| &v.name == name) + } +} + /// Scope for generating the abstract syntax tree. /// /// The scope is used for: @@ -575,6 +631,7 @@ struct Scope { /// Block-level variable scopes. Push on block enter, pop on block exit. variables: Vec>, + enums: HashMap, EnumBinding>, parameters: HashMap, witnesses: HashMap, is_main: bool, @@ -597,6 +654,7 @@ impl Scope { module_path: Vec::new(), root: ModuleScope::default(), variables: Vec::new(), + enums: HashMap::new(), parameters: HashMap::new(), witnesses: HashMap::new(), is_main: false, @@ -935,6 +993,24 @@ impl Scope { Ok(()) } + pub fn insert_enum( + &mut self, + name: AliasName, + variants: Arc<[ResolvedEnumVariant]>, + ) -> Result<(), Error> { + let plug = (name.clone(), self.file_id); + if self.enums.contains_key(&plug) { + return Err(Error::RedefinedAlias { name }); + } + self.enums.insert(plug, EnumBinding::new(variants)); + Ok(()) + } + + pub fn get_enum(&self, name: &AliasName) -> Option<&EnumBinding> { + let plug = (name.clone(), self.file_id); + self.enums.get(&plug) + } + /// Insert a parameter into the global map. /// /// ## Errors @@ -1134,7 +1210,61 @@ impl AbstractSyntaxTree for Item { Ok(Self::Module(analyzed_children)) } parse::Item::Ignored => Ok(Self::Ignored), - } + parse::Item::EnumDeclaration(decl) => { + scope.file_id = decl.file_id(); + let n = decl.variants().len(); + if n < 2 { + return Err(Error::Grammar { + msg: format!("enum '{}' must have at least 2 variants", decl.name()), + }) + .with_span(decl); + } + let mut sorted: Vec<&parse::EnumVariant> = decl.variants().iter().collect(); + sorted.sort_by_key(|v| v.discriminant()); + for w in sorted.windows(2) { + if w[0].discriminant() == w[1].discriminant() { + return Err(Error::Grammar { + msg: format!( + "enum '{}' has duplicate discriminant {}", + decl.name(), + w[0].discriminant() + ), + }) + .with_span(decl); + } + } + let mut seen_names = HashSet::new(); + for v in decl.variants() { + if !seen_names.insert(v.name()) { + return Err(Error::Grammar { + msg: format!( + "enum '{}' has duplicate variant name '{}'", + decl.name(), + v.name() + ), + }) + .with_span(decl); + } + } + let variants: Arc<[ResolvedEnumVariant]> = sorted + .iter() + .map(|v| ResolvedEnumVariant { + name: v.name().clone(), + discriminant: v.discriminant(), + }) + .collect(); + scope + .insert_alias(decl.name().clone(), AliasedType::from(UIntType::U8)) + .with_span(decl)?; + scope + .insert_enum(decl.name().clone(), variants) + .with_span(decl)?; + Ok(Self::TypeAlias) + } + }; + + scope.file_id = previous_file_id; + res } } @@ -1447,6 +1577,75 @@ impl AbstractSyntaxTree for SingleExpression { parse::SingleExpressionInner::Match(match_) => { Match::analyze(match_, ty, scope).map(SingleExpressionInner::Match)? } + parse::SingleExpressionInner::EnumMatch(enum_match) => { + let arms = enum_match.arms(); + let span = *enum_match.span(); + if arms.is_empty() { + return Err(Error::Grammar { + msg: "enum match has no arms".to_string(), + }) + .with_span(span); + } + let enum_name = match arms[0].pattern() { + MatchPattern::EnumVariant(name, _) => name.clone(), + _ => unreachable!("EnumMatch arms have EnumVariant patterns"), + }; + let binding = scope + .get_enum(&enum_name) + .ok_or_else(|| Error::UndefinedAlias { + name: enum_name.clone(), + }) + .with_span(span)?; + let mut arm_map: HashMap<&Identifier, &parse::Expression> = HashMap::new(); + for arm in arms { + let MatchPattern::EnumVariant(arm_enum_name, variant) = arm.pattern() else { + unreachable!("EnumMatch arms have EnumVariant patterns") + }; + if arm_enum_name != &enum_name { + return Err(Error::Grammar { + msg: format!( + "all match arms must use the same enum; expected '{}', found '{}'", + enum_name, arm_enum_name + ), + }) + .with_span(span); + } + if !binding.contains_variant(variant) { + return Err(Error::Grammar { + msg: format!( + "variant '{}' is not defined in enum '{}'", + variant, enum_name + ), + }) + .with_span(span); + } + if arm_map.insert(variant, arm.expression()).is_some() { + return Err(Error::Grammar { + msg: format!("duplicate arm for variant '{}'", variant), + }) + .with_span(span); + } + } + if arm_map.len() != binding.variants().len() { + return Err(Error::Grammar { + msg: format!( + "enum match on '{}' must cover all {} variants", + enum_name, + binding.variants().len() + ), + }) + .with_span(span); + } + let ordered_arms: Vec<(&parse::Expression, u8)> = binding + .variants() + .iter() + .map(|v| (arm_map[&v.name], v.discriminant)) + .collect(); + let u8_ty = ResolvedType::from(UIntType::U8); + let scrutinee = + Expression::analyze(enum_match.scrutinee(), &u8_ty, scope).map(Arc::new)?; + desugar_enum_arms_u8(&ordered_arms, scrutinee, ty, scope, span)? + } }; Ok(Self { @@ -1808,6 +2007,152 @@ impl AbstractSyntaxTree for Match { } } +/// Desugar an N-arm enum match (u8 discriminant) into a `jet::eq_8` comparison chain. +fn desugar_enum_arms_u8( + arms: &[(&parse::Expression, u8)], + scrutinee: Arc, + expected_ty: &ResolvedType, + scope: &mut Scope, + span: Span, +) -> Result { + debug_assert!(arms.len() >= 2); + + let u8_ty = ResolvedType::from(UIntType::U8); + + // Bind the scrutinee to a fresh variable to avoid witness-reuse errors. + let disc_ident = Identifier::from_str_unchecked("__disc_"); + + scope.push_scope(); + scope.insert_variable(disc_ident.clone(), u8_ty.clone()); + + let analyzed_arms: Vec<(Arc, u8)> = arms + .iter() + .map(|(e, disc)| { + scope.push_scope(); + let result = + Expression::analyze(e, expected_ty, scope).map(|expr| (Arc::new(expr), *disc)); + scope.pop_scope(); + result + }) + .collect::, _>>()?; + + let chain = build_u8_chain(&disc_ident, &analyzed_arms, expected_ty, &u8_ty, span); + scope.pop_scope(); + + // Wrap in block: { let __disc_N: u8 = scrutinee; } + let chain_expr = Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: chain, + ty: expected_ty.clone(), + span, + }), + ty: expected_ty.clone(), + span, + }); + let assign_stmt = Statement::Assignment(Assignment { + pattern: Pattern::Identifier(disc_ident), + expression: (*scrutinee).clone(), + span, + }); + Ok(SingleExpressionInner::Expression(Arc::new(Expression { + inner: ExpressionInner::Block(Arc::from([assign_stmt]), Some(chain_expr)), + ty: expected_ty.clone(), + span, + }))) +} + +/// Build a nested bool-`Match` chain for u8 discriminant dispatch. +/// +/// Every variant, including the last, is guarded by an `eq8` comparison. +/// A `panic!()` on the final false branch ensures that any undeclared +/// discriminant value causes the script to fail rather than silently +/// executing the last arm. +/// +/// `if eq8(disc, d[0]) { arms[0] } else if eq8(disc, d[1]) { arms[1] } ... else if eq8(disc, d[N-1]) { arms[N-1] } else { panic!() }` +fn build_u8_chain( + disc_ident: &Identifier, + arms: &[(Arc, u8)], + expected_ty: &ResolvedType, + u8_ty: &ResolvedType, + span: Span, +) -> SingleExpressionInner { + debug_assert!(!arms.is_empty()); + + let (arm_expr, discriminant) = &arms[0]; + let disc_var = Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Variable(disc_ident.clone()), + ty: u8_ty.clone(), + span, + }), + ty: u8_ty.clone(), + span, + }); + let const_expr = Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Constant(Value::u8(*discriminant)), + ty: u8_ty.clone(), + span, + }), + ty: u8_ty.clone(), + span, + }); + let eq8_expr = Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Call(Call { + name: CallName::Jet(Box::new(Elements::Eq8)), + args: Arc::from([(*disc_var).clone(), (*const_expr).clone()]), + span, + }), + ty: ResolvedType::boolean(), + span, + }), + ty: ResolvedType::boolean(), + span, + }); + + let false_branch = if arms.len() == 1 { + // Last arm: an undeclared discriminant must not silently execute any arm. + Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Call(Call { + name: CallName::Panic, + args: Arc::from([]), + span, + }), + ty: expected_ty.clone(), + span, + }), + ty: expected_ty.clone(), + span, + }) + } else { + let rest_inner = build_u8_chain(disc_ident, &arms[1..], expected_ty, u8_ty, span); + Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: rest_inner, + ty: expected_ty.clone(), + span, + }), + ty: expected_ty.clone(), + span, + }) + }; + + SingleExpressionInner::Match(Match { + scrutinee: eq8_expr, + left: MatchArm { + pattern: MatchPattern::False, + expression: false_branch, + }, + right: MatchArm { + pattern: MatchPattern::True, + expression: arm_expr.clone(), + }, + span, + }) +} + impl AsRef for Assignment { fn as_ref(&self) -> &Span { &self.span @@ -1838,6 +2183,242 @@ impl AsRef for Match { } } +#[cfg(test)] +mod enum_tests { + use super::{ElementsJetHinter, Program}; + use crate::driver::tests::setup_graph; + use crate::error::ErrorCollector; + + fn analyze(src: &str) -> Result<(), String> { + let (graph, _ids, _dir) = setup_graph(vec![("main.simf", src)]); + let mut handler = ErrorCollector::new(); + let driver_prog = graph + .linearize_and_build(&mut handler) + .unwrap() + .expect("driver build should succeed"); + Program::analyze(&driver_prog, Box::new(ElementsJetHinter::new())) + .map(|_| ()) + .map_err(|e| e.to_string()) + } + + #[test] + fn enum_declaration_registers_type_alias() { + let result = analyze( + "enum Color { Red = 1, Green = 2 } + fn main() { let _x: Color = witness::C; }", + ); + assert!( + result.is_ok(), + "enum should register as type alias: {result:?}" + ); + } + + #[test] + fn enum_match_on_function_return() { + let result = analyze( + "enum Dir { Left = 1, Right = 2 } + fn wrap(d: Dir) -> Dir { d } + fn main() { + match wrap(witness::D) { + Dir::Left => assert!(jet::eq_32(0, 0)), + Dir::Right => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "enum match on function return should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_2_variants_desugars() { + let result = analyze( + "enum Coin { Heads = 1, Tails = 2 } + fn main() { + match witness::C { + Coin::Heads => assert!(jet::eq_32(0, 0)), + Coin::Tails => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "2-variant enum match should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_3_variants_desugars() { + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "3-variant enum match should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_arms_sorted_by_discriminant() { + // Arms in reverse discriminant order should still compile correctly. + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::C => assert!(jet::eq_32(0, 0)), + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "arms in any order should compile: {result:?}" + ); + } + + #[test] + fn enum_too_few_variants_is_error() { + let result = analyze("enum Bad { Only = 1 } fn main() {}"); + assert!(result.is_err(), "single-variant enum should error"); + assert!( + result.unwrap_err().contains("at least 2 variants"), + "expected 'at least 2 variants' in error" + ); + } + + #[test] + fn enum_duplicate_discriminant_is_error() { + let result = analyze("enum Bad { A = 1, B = 1 } fn main() {}"); + assert!(result.is_err(), "duplicate discriminant should error"); + assert!( + result.unwrap_err().contains("duplicate discriminant"), + "expected 'duplicate discriminant' in error" + ); + } + + #[test] + fn enum_duplicate_variant_name_is_error() { + let result = analyze("enum Bad { A = 1, A = 2 } fn main() {}"); + assert!(result.is_err(), "duplicate variant name should error"); + assert!( + result.unwrap_err().contains("duplicate variant name"), + "expected 'duplicate variant name' in error" + ); + } + + #[test] + fn enum_duplicate_name_is_error() { + use crate::error::ErrorCollector; + let (graph, _ids, _dir) = setup_graph(vec![( + "main.simf", + "enum Color { Red = 1, Green = 2 } + enum Color { Blue = 1, Yellow = 2 } + fn main() {}", + )]); + let mut handler = ErrorCollector::new(); + let program_option = graph.linearize_and_build(&mut handler).unwrap(); + assert!( + program_option.is_none(), + "duplicate enum name should cause build failure" + ); + } + + #[test] + fn enum_match_missing_arm_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "missing arm should error"); + assert!( + result.unwrap_err().contains("must cover all"), + "expected 'must cover all' in error" + ); + } + + #[test] + fn enum_match_unknown_variant_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::X => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "unknown variant should error"); + assert!( + result.unwrap_err().contains("not defined in enum"), + "expected 'not defined in enum' in error" + ); + } + + #[test] + fn enum_match_duplicate_arm_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::A => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "duplicate arm should error"); + assert!( + result.unwrap_err().contains("duplicate arm"), + "expected 'duplicate arm' in error" + ); + } + + #[test] + fn enum_match_mixed_enum_names_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + enum Other { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Other::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "mixed enum names should error"); + assert!( + result.unwrap_err().contains("same enum"), + "expected 'same enum' in error" + ); + } + + #[test] + fn enum_match_undefined_enum_is_error() { + let result = analyze( + "fn main() { + match witness::P { + Unknown::A => assert!(jet::eq_32(0, 0)), + Unknown::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "undefined enum should error"); + } +} + #[cfg(test)] mod scope_resolution_tests { use super::{ElementsJetHinter, Program}; diff --git a/src/lexer.rs b/src/lexer.rs index 06d63adc..fb1b729a 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -21,6 +21,7 @@ pub enum Token<'src> { Mod, Const, Match, + Enum, Crate, // Control symbols @@ -80,6 +81,7 @@ impl<'src> fmt::Display for Token<'src> { Token::Mod => write!(f, "mod"), Token::Const => write!(f, "const"), Token::Match => write!(f, "match"), + Token::Enum => write!(f, "enum"), Token::Crate => write!(f, "{}", CRATE_STR), Token::Arrow => write!(f, "->"), @@ -156,6 +158,7 @@ pub fn lexer<'src>( "mod" => Token::Mod, "const" => Token::Const, "match" => Token::Match, + "enum" => Token::Enum, CRATE_STR => Token::Crate, "true" => Token::Bool(true), "false" => Token::Bool(false), @@ -259,7 +262,8 @@ pub fn lex<'src>(input: &'src str) -> (Option>, Vec assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + } + } + "#; + // Select variant A via its u8 discriminant. + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("PATH"), Value::u8(1)); + TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .assert_run_success(); + } + + #[test] + fn enum_match_3_variants() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" + enum Path { A = 0, B = 2, C = 5 } + fn main() { + match witness::PATH { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + } + "#; + // Select variant C via its u8 discriminant. + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("PATH"), Value::u8(5)); + TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .assert_run_success(); + } + + #[test] + fn enum_match_function_return() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" + enum Dir { Left = 1, Right = 2 } + fn wrap(d: Dir) -> Dir { d } + fn main() { + match wrap(witness::D) { + Dir::Left => assert!(jet::eq_32(0, 0)), + Dir::Right => assert!(jet::eq_32(0, 0)), + } + } + "#; + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("D"), Value::u8(1)); + TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .assert_run_success(); + } + + #[test] + fn enum_match_invalid_discriminant_fails() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" + enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::PATH { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + } + "#; + // Discriminant 0 is not declared in the enum; the script must fail. + for bad in [0u8, 4, 99, 255] { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("PATH"), Value::u8(bad)); + let result = TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .run(); + assert!( + result.is_err(), + "discriminant {bad} is not declared; execution should fail but succeeded" + ); + } + } + + #[test] + fn missing_witness_on_live_branch_errors() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" +enum Branch { A = 1, B = 2 } +fn main() { + match witness::SELECTOR { + Branch::A => assert!(jet::is_zero_32(witness::A)), + Branch::B => assert!(jet::is_zero_32(witness::B)), + } +} +"#; + let env = crate::dummy_env::dummy(); + + // SELECTOR = 1 (Branch::A) → branch A taken; B is missing but pruned → satisfy OK + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(1)); + map.insert(WitnessName::from_str_unchecked("A"), Value::u32(0)); + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + compiled + .satisfy_with_env(WitnessValues::from(map), Some(&env)) + .expect("B is on a pruned branch; satisfy should succeed"); + } + + // SELECTOR = 2 (Branch::B) → branch B taken; A is missing but pruned → satisfy OK + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(2)); + map.insert(WitnessName::from_str_unchecked("B"), Value::u32(0)); + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + compiled + .satisfy_with_env(WitnessValues::from(map), Some(&env)) + .expect("A is on a pruned branch; satisfy should succeed"); + } + + // SELECTOR = 2 (Branch::B) → branch B taken; B is missing and live → satisfy errors + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(2)); + // B is intentionally not provided + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + let err = compiled + .satisfy_with_env(WitnessValues::from(map), Some(&env)) + .expect_err("B is on the executed branch and missing; satisfy should fail"); + assert!( + err.contains('B'), + "error message should mention witness B, got: {err}" + ); + } + } + #[test] #[cfg(feature = "serde")] fn hodl_vault() { diff --git a/src/named.rs b/src/named.rs index 9de1b6e7..c4ad36d3 100644 --- a/src/named.rs +++ b/src/named.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::sync::Arc; use simplicity::dag::{InternalSharing, PostOrderIterItem}; @@ -243,6 +244,67 @@ pub fn populate_witnesses( node.convert::(&mut populator) } +/// Walk the `commit` tree and the `pruned` redeem tree in parallel, checking that +/// no zero-filled witness (tracked in `zero_filled`) appears on a non-pruned branch. +/// +/// Pruned branches are indicated by `Fail` nodes in the pruned tree. When a `Case` +/// node is pruned to `AssertL` or `AssertR`, only the surviving child is recursed into. +pub fn check_surviving_witnesses( + commit: &CommitNode, + pruned: &Arc, + zero_filled: &HashSet, +) -> Result<(), String> { + match (commit.inner(), pruned.inner()) { + // Pruned branch or unreachable fail node — no witnesses to check + (_, Inner::Fail(_)) | (Inner::Fail(_), _) => Ok(()), + // Witness node on a live branch — error if it was zero-filled + (Inner::Witness(name), Inner::Witness(_)) => { + if zero_filled.contains(name) { + Err(format!( + "Witness `{name}` is used on the executed branch but has no assigned value" + )) + } else { + Ok(()) + } + } + // Leaf nodes with no witness children + (Inner::Iden, _) | (Inner::Unit, _) | (Inner::Jet(_), _) | (Inner::Word(_), _) => Ok(()), + // Single-child nodes — recurse into the child + (Inner::InjL(cc), Inner::InjL(cp)) + | (Inner::InjR(cc), Inner::InjR(cp)) + | (Inner::Take(cc), Inner::Take(cp)) + | (Inner::Drop(cc), Inner::Drop(cp)) => check_surviving_witnesses(cc, cp, zero_filled), + // Assert nodes — one live child, one CMR; recurse into the live child + (Inner::AssertL(cc, _), Inner::AssertL(cp, _)) + | (Inner::AssertR(_, cc), Inner::AssertR(_, cp)) => { + check_surviving_witnesses(cc, cp, zero_filled) + } + // Two-child nodes — recurse into both + (Inner::Comp(cl, cr), Inner::Comp(pl, pr)) | (Inner::Pair(cl, cr), Inner::Pair(pl, pr)) => { + check_surviving_witnesses(cl, pl, zero_filled)?; + check_surviving_witnesses(cr, pr, zero_filled) + } + // Case: both branches live + (Inner::Case(cl, cr), Inner::Case(pl, pr)) => { + check_surviving_witnesses(cl, pl, zero_filled)?; + check_surviving_witnesses(cr, pr, zero_filled) + } + // Case pruned to AssertL: only left branch survived + (Inner::Case(cl, _), Inner::AssertL(pl, _)) => { + check_surviving_witnesses(cl, pl, zero_filled) + } + // Case pruned to AssertR: only right branch survived + (Inner::Case(_, cr), Inner::AssertR(_, pr)) => { + check_surviving_witnesses(cr, pr, zero_filled) + } + // Disconnect — not used in SimplicityHL; handle defensively + (Inner::Disconnect(cc, _), Inner::Disconnect(cp, _)) => { + check_surviving_witnesses(cc, cp, zero_filled) + } + _ => unreachable!("unexpected structural mismatch between commit and pruned trees"), + } +} + // This awkward construction is required by rust-simplicity to implement WitnessConstructible // for Node>. See // https://docs.rs/simplicity-lang/latest/simplicity/node/trait.WitnessConstructible.html#foreign-impls diff --git a/src/parse.rs b/src/parse.rs index 0f4a035b..6d1b265c 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -65,6 +65,8 @@ pub enum Item { /// An import declaration (e.g., `use math::add`) that brings another /// [`Item`] into the current scope. Use(UseDecl), + /// An enum declaration. + EnumDeclaration(EnumDeclaration), /// A module containing a collection of nested [`Item`]. Module(Module), /// A placeholder used exclusively for error recovery during parsing. @@ -435,6 +437,91 @@ impl TypeAlias { impl_eq_hash!(TypeAlias; name, ty); +/// A single variant in an enum declaration. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct EnumVariant { + name: Identifier, + discriminant: u8, + span: Span, +} + +impl EnumVariant { + pub fn name(&self) -> &Identifier { + &self.name + } + + pub fn discriminant(&self) -> u8 { + self.discriminant + } +} + +impl AsRef for EnumVariant { + fn as_ref(&self) -> &Span { + &self.span + } +} + +/// An enum declaration. +#[derive(Clone, Debug)] +pub struct EnumDeclaration { + file_id: usize, + visibility: Visibility, + name: AliasName, + variants: Arc<[EnumVariant]>, + span: Span, +} + +impl EnumDeclaration { + pub fn file_id(&self) -> usize { + self.file_id + } + + pub fn set_file_id(&mut self, file_id: usize) { + self.file_id = file_id; + } + + pub fn visibility(&self) -> &Visibility { + &self.visibility + } + + pub fn name(&self) -> &AliasName { + &self.name + } + + pub fn variants(&self) -> &[EnumVariant] { + &self.variants + } +} + +impl_eq_hash!(EnumDeclaration; name, variants); + +impl AsRef for EnumDeclaration { + fn as_ref(&self) -> &Span { + &self.span + } +} + +#[cfg(feature = "arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for EnumDeclaration { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let file_id = u.int_in_range(0..=3)?; + let visibility = Visibility::arbitrary(u)?; + let name = AliasName::arbitrary(u)?; + let len = u.int_in_range(2..=8)?; + let variants = (0..len) + .map(|_| EnumVariant::arbitrary(u)) + .collect::>>()?; + Ok(Self { + file_id, + visibility, + name, + variants, + span: Span::DUMMY, + }) + } +} + /// An expression is something that returns a value. #[derive(Clone, Debug)] pub struct Expression { @@ -537,6 +624,8 @@ pub enum SingleExpressionInner { Expression(Arc), /// Match expression over a sum type Match(Match), + /// Match expression over a named enum type + EnumMatch(EnumMatch), /// Tuple wrapper expression Tuple(Arc<[Expression]>), /// Array wrapper expression @@ -592,6 +681,30 @@ impl Match { impl_eq_hash!(Match; scrutinee, left, right); +/// Match expression over a named enum type (N arms, N ≥ 2). +#[derive(Clone, Debug)] +pub struct EnumMatch { + scrutinee: Arc, + arms: Arc<[MatchArm]>, + span: Span, +} + +impl EnumMatch { + pub fn scrutinee(&self) -> &Expression { + &self.scrutinee + } + + pub fn arms(&self) -> &[MatchArm] { + &self.arms + } + + pub fn span(&self) -> &Span { + &self.span + } +} + +impl_eq_hash!(EnumMatch; scrutinee, arms); + /// Arm of a match expression. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct MatchArm { @@ -627,6 +740,8 @@ pub enum MatchPattern { False, /// Match true value (no binding). True, + /// Match a named enum variant (no payload binding). + EnumVariant(AliasName, Identifier), } impl MatchPattern { @@ -636,7 +751,10 @@ impl MatchPattern { MatchPattern::Left(i, _) | MatchPattern::Right(i, _) | MatchPattern::Some(i, _) => { Some(i) } - MatchPattern::None | MatchPattern::False | MatchPattern::True => None, + MatchPattern::None + | MatchPattern::False + | MatchPattern::True + | MatchPattern::EnumVariant(..) => None, } } @@ -646,7 +764,10 @@ impl MatchPattern { MatchPattern::Left(i, ty) | MatchPattern::Right(i, ty) | MatchPattern::Some(i, ty) => { Some((i, ty)) } - MatchPattern::None | MatchPattern::False | MatchPattern::True => None, + MatchPattern::None + | MatchPattern::False + | MatchPattern::True + | MatchPattern::EnumVariant(..) => None, } } } @@ -712,6 +833,7 @@ impl fmt::Display for Item { Self::TypeAlias(alias) => write!(f, "{alias}"), Self::Function(function) => write!(f, "{function}"), Self::Use(use_declaration) => write!(f, "{use_declaration}"), + Self::EnumDeclaration(decl) => write!(f, "{decl}"), Self::Module(module) => write!(f, "{module}"), Self::Ignored => Ok(()), } @@ -836,6 +958,7 @@ pub enum ExprTree<'a> { Single(&'a SingleExpression), Call(&'a Call), Match(&'a Match), + EnumMatch(&'a EnumMatch), } impl TreeLike for ExprTree<'_> { @@ -876,6 +999,7 @@ impl TreeLike for ExprTree<'_> { | S::Expression(l) => Tree::Unary(Self::Expression(l)), S::Call(call) => Tree::Unary(Self::Call(call)), S::Match(match_) => Tree::Unary(Self::Match(match_)), + S::EnumMatch(enum_match) => Tree::Unary(Self::EnumMatch(enum_match)), S::Tuple(elements) | S::Array(elements) | S::List(elements) => { Tree::Nary(elements.iter().map(Self::Expression).collect()) } @@ -886,6 +1010,16 @@ impl TreeLike for ExprTree<'_> { Self::Expression(match_.left().expression()), Self::Expression(match_.right().expression()), ])), + Self::EnumMatch(enum_match) => Tree::Nary( + std::iter::once(Self::Expression(enum_match.scrutinee())) + .chain( + enum_match + .arms() + .iter() + .map(|arm| Self::Expression(arm.expression())), + ) + .collect(), + ), } } } @@ -951,7 +1085,7 @@ impl fmt::Display for ExprTree<'_> { write!(f, ")")?; } }, - S::Call(..) | S::Match(..) => {} + S::Call(..) | S::Match(..) | S::EnumMatch(..) => {} S::Tuple(tuple) => { if data.n_children_yielded == 0 { write!(f, "(")?; @@ -1002,6 +1136,18 @@ impl fmt::Display for ExprTree<'_> { write!(f, ",\n}}")?; } }, + Self::EnumMatch(enum_match) => { + let n = data.n_children_yielded; + if n == 0 { + write!(f, "match ")?; + } else if n == 1 { + write!(f, "{{\n{} => ", enum_match.arms()[0].pattern())?; + } else if n <= enum_match.arms().len() { + write!(f, ",\n{} => ", enum_match.arms()[n - 1].pattern())?; + } else { + write!(f, ",\n}}")?; + } + } } } @@ -1074,7 +1220,24 @@ impl fmt::Display for MatchPattern { MatchPattern::Some(i, ty) => write!(f, "Some({i}: {ty})"), MatchPattern::False => write!(f, "false"), MatchPattern::True => write!(f, "true"), + MatchPattern::EnumVariant(enum_name, variant) => write!(f, "{enum_name}::{variant}"), + } + } +} + +impl fmt::Display for EnumDeclaration { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}enum {} {{", self.visibility(), self.name())?; + for variant in self.variants() { + write!(f, " {} = {},", variant.name(), variant.discriminant())?; } + write!(f, " }}") + } +} + +impl fmt::Display for EnumMatch { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", ExprTree::EnumMatch(self)) } } @@ -1419,9 +1582,16 @@ impl ChumskyParse for Item { let use_parser = UseDecl::parser().map(Item::Use); // Lazy item here + let enum_parser = EnumDeclaration::parser().map(Item::EnumDeclaration); let mod_parser = Module::parser_with_items(item).map(Item::Module); - choice((func_parser, use_parser, type_parser, mod_parser)) + choice(( + func_parser, + use_parser, + type_parser, + enum_parser, + mod_parser, + )) }) } } @@ -1838,6 +2008,62 @@ impl ChumskyParse for TypeAlias { } } +impl ChumskyParse for EnumDeclaration { + fn parser<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone + where + I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, + { + let visibility = just(Token::Pub) + .to(Visibility::Public) + .or_not() + .map(Option::unwrap_or_default); + + let discriminant = just(Token::Eq) + .ignore_then(select! { Token::DecLiteral(d) => d }) + .try_map(|d, span| { + d.as_inner().parse::().map_err(|_| { + RichError::new( + Error::Grammar { + msg: format!( + "enum discriminant '{}' is out of range (must be 0-255)", + d.as_inner() + ), + }, + span, + ) + }) + }); + + let variant = + Identifier::parser() + .then(discriminant) + .map_with(|(name, discriminant), e| EnumVariant { + name, + discriminant, + span: e.span(), + }); + + let variants = variant + .separated_by(just(Token::Comma)) + .allow_trailing() + .collect::>() + .delimited_by(just(Token::LBrace), just(Token::RBrace)) + .map(Arc::from); + + visibility + .then_ignore(just(Token::Enum)) + .then(AliasName::parser()) + .then(variants) + .map_with(|((visibility, name), variants), e| Self { + file_id: MAIN_MODULE, + visibility, + name, + variants, + span: e.span(), + }) + } +} + impl ChumskyParse for Expression { fn parser<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone where @@ -1959,7 +2185,7 @@ impl SingleExpression { let call = Call::parser(expr.clone()).map(SingleExpressionInner::Call); - let match_expr = Match::parser(expr.clone()).map(SingleExpressionInner::Match); + let match_expr = match_expr_parser(expr.clone()); let variable = Identifier::parser().map(SingleExpressionInner::Variable); @@ -2014,75 +2240,60 @@ impl ChumskyParse for MatchPattern { } } -impl MatchArm { - fn parser<'tokens, 'src: 'tokens, I, E>( - expr: E, - ) -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone - where - I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, - E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, - { - MatchPattern::parser() - .then_ignore(just(Token::FatArrow)) - .then(expr.map(Arc::new)) - .then(just(Token::Comma).or_not()) - .validate(|((pattern, expression), comma), e, emitter| { - let is_block = matches!(expression.as_ref().inner, ExpressionInner::Block(_, _)); - - if !is_block && comma.is_none() { - emitter.emit( - Error::Grammar { - msg: "Missing ',' after a match arm that isn't block expression" - .to_string(), - } - .with_span(e.span()), - ); - } - - Self { - pattern, - expression, - } - }) - } -} - -impl Match { - fn parser<'tokens, 'src: 'tokens, I, E>( - expr: E, - ) -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone - where - I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, - E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, - { - let scrutinee = expr.clone().map(Arc::new); - - let arm_recovery = any() - .filter(|t| !matches!(t, Token::Comma | Token::RBrace)) - .ignored() - .or(nested_delimiters( - Token::LBrace, - Token::RBrace, - [ - (Token::LParen, Token::RParen), - (Token::LBracket, Token::RBracket), - ], - |_| (), - ) - .ignored()) - .repeated() - .map_with(|(), _| None); +/// Parser for `match` expressions. +/// +/// Handles both binary match (exactly 2 arms: Left/Right, None/Some, false/true) and enum match +/// (2+ arms using `EnumName::Variant` patterns). Dispatches to [`Match`] or [`EnumMatch`] based +/// on the patterns found. +fn match_expr_parser<'tokens, 'src: 'tokens, I, E>( + expr: E, +) -> impl Parser<'tokens, I, SingleExpressionInner, ParseError<'src>> + Clone +where + I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, + E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, +{ + let scrutinee = expr.clone().map(Arc::new); - let arm_parser = MatchArm::parser(expr.clone()) - .map(Some) - .recover_with(via_parser(arm_recovery.clone())); + // Enum variant pattern: `EnumName::VariantName`. + // Binary keywords are excluded so choice() works without backtracking: + // when the ident is Left/Right/Some/None the select! guard fails without consuming the token. + let enum_variant_pattern = + select! { Token::Ident(name) if name != "Left" && name != "Right" && name != "Some" && name != "None" => AliasName::from_str_unchecked(name) } + .then_ignore(just(Token::DoubleColon)) + .then(select! { Token::Ident(v) => Identifier::from_str_unchecked(v) }) + .map(|(enum_name, variant)| MatchPattern::EnumVariant(enum_name, variant)); + + let combined_pattern = choice((enum_variant_pattern, MatchPattern::parser())); + + // No recover_with here: repeated() stops naturally when arm_parser fails. + // Outer delimited_with_recovery handles the block-level recovery. + let arm_parser = combined_pattern + .then_ignore(just(Token::FatArrow)) + .then(expr.clone().map(Arc::new)) + .then(just(Token::Comma).or_not()) + .validate(|((pattern, expression), comma), e, emitter| { + let is_block = matches!(expression.as_ref().inner, ExpressionInner::Block(_, _)); + if !is_block && comma.is_none() { + emitter.emit( + Error::Grammar { + msg: "Missing ',' after a match arm that isn't block expression" + .to_string(), + } + .with_span(e.span()), + ); + } + MatchArm { + pattern, + expression, + } + }); - let arms = delimited_with_recovery( - arm_parser.clone().then(arm_parser.clone()), - Token::LBrace, - Token::RBrace, - |_| (None, None), - ); + let arms = delimited_with_recovery( + arm_parser.repeated().collect::>(), + Token::LBrace, + Token::RBrace, + |_| vec![], + ); just(Token::Match) .ignore_then(scrutinee) @@ -2616,4 +2827,92 @@ mod test { assert_eq!(program.to_string(), format!("{input}\n")); } } + + fn parse_item(input: &str) -> Item { + let program = parse::Program::parse_from_str(input).expect("parsing should succeed"); + program.items().first().expect("expected one item").clone() + } + + #[test] + fn test_enum_declaration_basic() { + let item = parse_item("enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }"); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration, got {item:?}"); + }; + assert_eq!(decl.name().as_inner(), "Path"); + assert_eq!(decl.variants().len(), 3); + assert_eq!(decl.variants()[0].name().as_inner(), "Inherit"); + assert_eq!(decl.variants()[0].discriminant(), 1); + assert_eq!(decl.variants()[1].name().as_inner(), "ColdSpend"); + assert_eq!(decl.variants()[1].discriminant(), 2); + assert_eq!(decl.variants()[2].name().as_inner(), "RefreshSpend"); + assert_eq!(decl.variants()[2].discriminant(), 3); + } + + #[test] + fn test_enum_declaration_pub() { + let item = parse_item("pub enum Color { Red = 0, Green = 1, Blue = 2, }"); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration"); + }; + assert_eq!(decl.visibility(), &Visibility::Public); + assert_eq!(decl.name().as_inner(), "Color"); + } + + #[test] + fn test_enum_declaration_display_round_trip() { + let input = "enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }"; + let item = parse_item(input); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration"); + }; + assert_eq!( + decl.to_string(), + "enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }" + ); + } + + #[test] + fn test_enum_match_parses() { + let input = "fn main() { match witness::PATH { Path::Inherit => 0, Path::ColdSpend => 1, Path::RefreshSpend => 2, } }"; + let source = SourceFile::anonymous(Arc::from(input)); + let mut errors = ErrorCollector::new(); + let program = Program::parse_from_str_with_errors(source, &mut errors); + assert!(program.is_some(), "should parse without errors"); + assert!( + !errors.has_errors(), + "unexpected errors: {}", + ErrorCollector::to_string(&errors) + ); + } + + #[test] + fn test_enum_match_produces_enum_match_node() { + let input = + "fn main() { match witness::PATH { Path::Inherit => 0, Path::ColdSpend => 1, } }"; + let program = parse::Program::parse_from_str(input).expect("parsing should succeed"); + // Walk the tree looking for an EnumMatch node + let has_enum_match = program.items().iter().any(|item| { + if let Item::Function(f) = item { + format!("{f}").contains("Path::Inherit") + } else { + false + } + }); + assert!(has_enum_match, "expected EnumMatch in the parse tree"); + } + + #[test] + fn test_binary_match_still_works_after_enum_parser_change() { + let input = "fn main() { let x: bool = true; match x { true => 1, false => 0, } }"; + let source = SourceFile::anonymous(Arc::from(input)); + let mut errors = ErrorCollector::new(); + let program = Program::parse_from_str_with_errors(source, &mut errors); + assert!(program.is_some(), "binary match should still parse"); + assert!( + !errors.has_errors(), + "unexpected errors: {}", + ErrorCollector::to_string(&errors) + ); + } } diff --git a/src/value.rs b/src/value.rs index 1ccb38bd..47df6a4a 100644 --- a/src/value.rs +++ b/src/value.rs @@ -648,6 +648,34 @@ impl Value { }; Ok(ret) } + + /// Create a zero value of the given type. + /// + /// For integers, this is 0. For sum types, this is `Left(zero)`. For options, this is `None`. + /// For tuples and arrays, each element is zero. For lists, this is the empty list. + pub fn zero(ty: &ResolvedType) -> Self { + match ty.as_inner() { + TypeInner::Boolean => Self::from(false), + TypeInner::UInt(uint_ty) => match uint_ty { + UIntType::U1 => Self::u1(0), + UIntType::U2 => Self::u2(0), + UIntType::U4 => Self::u4(0), + UIntType::U8 => Self::u8(0), + UIntType::U16 => Self::u16(0), + UIntType::U32 => Self::u32(0), + UIntType::U64 => Self::u64(0), + UIntType::U128 => Self::u128(0), + UIntType::U256 => Self::u256(U256::from_byte_array([0u8; 32])), + }, + TypeInner::Either(left, right) => Self::left(Self::zero(left), (**right).clone()), + TypeInner::Option(inner) => Self::none((**inner).clone()), + TypeInner::Tuple(elements) => Self::tuple(elements.iter().map(|e| Self::zero(e))), + TypeInner::Array(el_ty, size) => { + Self::array((0..*size).map(|_| Self::zero(el_ty)), (**el_ty).clone()) + } + TypeInner::List(el_ty, bound) => Self::list([], (**el_ty).clone(), *bound), + } + } } impl Value { diff --git a/src/witness.rs b/src/witness.rs index ae8b0581..3d0178ba 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::sync::Arc; @@ -128,6 +128,23 @@ impl WitnessValues { Ok(()) } + + /// Return a copy of these witness values with zero values inserted for any witness declared + /// in `types` that has no assigned value. Witnesses already present are unchanged. + /// + /// This is used before populating Simplicity witness nodes: all nodes must be filled, even + /// those on branches that will be pruned away and never executed. + pub fn fill_missing(&self, types: &WitnessTypes) -> (Self, HashSet) { + let mut map: HashMap = (*self.0).clone(); + let mut zero_filled = HashSet::new(); + for (name, ty) in types.iter() { + if !map.contains_key(name) { + map.insert(name.shallow_clone(), Value::zero(ty)); + zero_filled.insert(name.shallow_clone()); + } + } + (Self::from(map), zero_filled) + } } impl ParseFromStr for ResolvedType { @@ -216,7 +233,7 @@ mod tests { use crate::ast::ElementsJetHinter; use crate::parse::ParseFromStr; use crate::value::ValueConstructible; - use crate::{ast, parse, CompiledProgram, SatisfiedProgram}; + use crate::{ast, parse, CompiledProgram, ResolvedType, SatisfiedProgram}; #[test] fn witness_reuse() { @@ -282,6 +299,45 @@ fn main() { } } + #[test] + fn fill_missing_zero_fills_and_tracks_missing_witnesses() { + let ty = ResolvedType::parse_from_str("u32").unwrap(); + let witness_types = WitnessTypes::from(HashMap::from([ + (WitnessName::from_str_unchecked("A"), ty.clone()), + (WitnessName::from_str_unchecked("B"), ty.clone()), + (WitnessName::from_str_unchecked("C"), ty.clone()), + ])); + + // A is explicitly provided with value zero (same value fill_missing would insert). + // B and C are not provided at all. + let provided = WitnessValues::from(HashMap::from([( + WitnessName::from_str_unchecked("A"), + Value::u32(0), + )])); + + let (filled, zero_filled) = provided.fill_missing(&witness_types); + + // Explicitly-provided witnesses must NOT be tracked as zero-filled, + // even when their value happens to be zero. + assert!( + !zero_filled.contains(&WitnessName::from_str_unchecked("A")), + "A was explicitly provided; must not appear in zero_filled" + ); + // Missing witnesses must be tracked so check_surviving_witnesses can error. + assert!( + zero_filled.contains(&WitnessName::from_str_unchecked("B")), + "B was not provided; must appear in zero_filled" + ); + assert!( + zero_filled.contains(&WitnessName::from_str_unchecked("C")), + "C was not provided; must appear in zero_filled" + ); + // All three must now have values in the filled map. + assert!(filled.get(&WitnessName::from_str_unchecked("A")).is_some()); + assert!(filled.get(&WitnessName::from_str_unchecked("B")).is_some()); + assert!(filled.get(&WitnessName::from_str_unchecked("C")).is_some()); + } + #[test] fn witness_to_string() { let witness = WitnessValues::from(HashMap::from([ diff --git a/test-data/last_will.json b/test-data/last_will.json index 3dace1a8..e54d5914 100644 --- a/test-data/last_will.json +++ b/test-data/last_will.json @@ -1,4 +1,4 @@ { - "program": "5wnQKEGJsWVABAmKSEGCrynMGLpUF69BbvwQFoAuY+y1ngQJfqSPabfWRZ9K3F2jdRYYBitLzfMz987l3WKtAxSudDhYOBTf5tlucUbKz5QK2LfAvMA1kChBh+DHCpJAk4cziqISK6EzABFXCwYvClhPYGFQusJfripGQssOAVt34AhgGJAoSQbgJxuBig/FJwqFobGHNddy8HoTqejIHGcv8bcleUZT57KmW1Vp7LXaMUR4qMQ4YBiE3n41BAOBgcOJFAGQOJwuLAGkHHAHHpBiBQbkHacYEf5RB7X1tMEVAbpXAfNhcd45LjO88p6usCblccJ7lByDCchhcRcOA4GJxgBwcGIGlafkwigGSWMMQSTRPhfidUim1MchFg2+ZsIYB8RO84Db5ByMCcj0kCT4YnM/BVazZBMsdgY/lS0WYcYfsNRJVmhtHQmf/PVrNEOe4wYBisgAAAAIGkKhambcTmCIv9QHGkTTAYXN78l9PRKwkaP7L+QgG2hgCZadW734oMAxC4AwcwLvfahR91ofRxdIEhoraXiTCljMruIwlAG9G26fy7ABhgGL4cEObxOhhI4BnviN4uejwVYdGCizvg8pDe+f7r9U2pQklHLAwDhITlckiyAAAAAAUKvhqObQQ6CxT1VyVCCLZUfrJhqhp/qbNkpewATHlLgTDwBZJgwDELSQkUM6IZzkLPP/t8aZ/NfTm5pw5IQ0J/duRiaFMIlp35sJi5uVAYBgObvJWcC0jz5LzKXn0/Nn/OQnPezJTiq+w46I+xAB5sdEwwDWQKEkH1m5aCgWDNug5tpVnanSODCL5dAHG0YZYXL7/GhLKIDD2ZYRxW4obTlfDAMQtRvoQT9zeJ/JSg1zVnOQ+dSDqBncb+M9zPuT55FUoWyEHDxeBVkAAAAAg4AFwB8NhzSoNzoMc1Ae/7Z55CGHj0gOG/ZVBjb5kbDqY2kAJh07WGAYhcGISKEGB4nAwHMDLqyIAl/t8mPC4rRGEfM1lVH7CAxNfptKDrOKJGwAHAYBrC4iNL6lcZ9HTU3CRlRyCoGVZzuEYuSi/jp604WRZrD9L3pYeQ4FCcTgcuhcLAcvAcWAchQuRgDmABzNhOZwDmDFyvAcwoOVQHLcHleBy6B5fgc34A==", + "program": "56XQKEGJsAEECwZtoIQKD6U2AECBYANKBQfQmwAwQLABwFAoSoAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABDoFCDE2n5MIoBkljDEEk0T4X4nVIptTHIRYNvmbCGAfETvOA2+QQJigw+gUJIEnDmcVRCRXQmYAIq4WDF4UsJ7AwqF1hL9cVIyFlhwCtu/AEMAxIFCSD7zcBFB+AigWTYw5rruXg9CdT0ZA4zl/jbkryjKfPZUy2qtPZa7RiiPFRiHDAMQm0/IQEA4CBwwkgSfDE5n4KrWbIJljsDH8qWizDjD9hqJKs0No6Ez/56tZohz3GDAMVkAAAABA0hULUzbicwRF/qA40iaYDC5vfkvp6JWEjR/ZfyEA20MATLTq3e/FBgGIXAGDmBd77UKPutD6OLpAkNFbS8SYUsZldxGEoA3o23T+XYAMMAxfDghzeJ0MJHAM98RvFz0eCrDowUWd8HlIb3z/dfqm1KEko5YGAcJCcpEkWQAAAAAChV8NRzaCHQWKequSoQRbKj9ZMNUNP9TZslL2ACY8pcCYeALJMGAYhaSEihnRDOchZ5/9vjTP5r6c3NOHJCGhP7tyMTQphEtO/NhMXNyMDAMBzd5KzgWkefJeZS8+n5s/5yE572ZKcVX2HHRH2IAPNjomGAayBQkg+s3JgUCwZt0HNtKs7U6RwYRfLoA42jDLC5ff40JZRAYezLCOK3FDacr4YBiFqN9CCfubxP5KUGuas5yHzqQdQM7jfxnuZ9yfPIqlC2Qg4eLwKsgAAAAEHAAuAPhsOaVBudBjmoD3/bPPIQw8ekBw37KoMbfMjYdTG0gBMOnawwDELgxCRQgwPE4GA5gZdWRAEv9vkx4XFaIwj5msqo/YQGJr9NpQdZxRI2AA4DANYXERpfUrjPo6am4SMqOQVAyrOdwjFyUX8dPWnCyLNYfpe9LDyHAoTicDlCLhYDlGDiwDkKFyMAcpwcwATmBA5VC5UAOVgOYoDmPA5kgeZYDmaOJzOm5lrTjAj/KIPa+tpgioDdK4D5sLjvHJcZ3nlPV1gTcrjhPcoOZYJzMi5a8yAHLkFAzA0g7AOa84nNkbmjsWVABzRhOaZJCDBV5TmDF0qC9egt34IC0AXMfZazwIEv1JHtNvrIs+lbi7RuosMAxWl5vmZ++dy7rFWgYpXOhwsHApv82y3OKNlZ8oFbFvgXmAayBQbmvPzbHCoDmVSKAOZoGkLcA5nQcKA4cBxADxGBztgc8w=", "witness": null } From 2fa2122414241c56ff46312d7db9627d653857b3 Mon Sep 17 00:00:00 2001 From: stringhandler Date: Mon, 29 Jun 2026 15:00:48 +0200 Subject: [PATCH 2/2] refactor: move enum storage from scope to module scope Move enum definitions from the global `Scope.enums` map to per-module storage in `ModuleScope.enums`. This aligns enum scoping with other module-level declarations and simplifies the scope hierarchy. Also refactor `insert_enum` to handle type alias registration directly, eliminating the need for separate `insert_alias` calls during enum declaration analysis. --- src/ast.rs | 82 ++++++++++------------------- src/driver/mod.rs | 5 +- src/driver/resolve_order.rs | 5 ++ src/parse.rs | 100 +++++++++++++++++++----------------- 4 files changed, 89 insertions(+), 103 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 5ef4c72a..3765246c 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -553,6 +553,8 @@ impl_jet_hinter!(CoreJetHinter, Core); #[derive(Clone, Debug, Eq, PartialEq, Default)] struct ModuleScope { aliases: HashMap, + /// Enum definitions declared in this module, keyed by enum name. + enums: HashMap, functions: HashMap, /// Nested inling `mod` blocks, each becoming a child scope. submodules: HashMap, @@ -586,34 +588,6 @@ impl EnumBinding { } } -/// A single enum variant after analysis: its name and u8 discriminant, without source span. -#[derive(Clone, Debug, Eq, PartialEq)] -struct ResolvedEnumVariant { - name: Identifier, - discriminant: u8, -} - -/// The resolved definition of an enum as stored in [`Scope`]: -/// a list of [`ResolvedEnumVariant`]s in declaration order. -#[derive(Clone, Debug, Eq, PartialEq)] -struct EnumBinding { - variants: Arc<[ResolvedEnumVariant]>, -} - -impl EnumBinding { - fn new(variants: Arc<[ResolvedEnumVariant]>) -> Self { - Self { variants } - } - - fn variants(&self) -> &[ResolvedEnumVariant] { - &self.variants - } - - fn contains_variant(&self, name: &Identifier) -> bool { - self.variants.iter().any(|v| &v.name == name) - } -} - /// Scope for generating the abstract syntax tree. /// /// The scope is used for: @@ -631,7 +605,6 @@ struct Scope { /// Block-level variable scopes. Push on block enter, pop on block exit. variables: Vec>, - enums: HashMap, EnumBinding>, parameters: HashMap, witnesses: HashMap, is_main: bool, @@ -654,7 +627,6 @@ impl Scope { module_path: Vec::new(), root: ModuleScope::default(), variables: Vec::new(), - enums: HashMap::new(), parameters: HashMap::new(), witnesses: HashMap::new(), is_main: false, @@ -996,19 +968,27 @@ impl Scope { pub fn insert_enum( &mut self, name: AliasName, + visibility: Visibility, variants: Arc<[ResolvedEnumVariant]>, ) -> Result<(), Error> { - let plug = (name.clone(), self.file_id); - if self.enums.contains_key(&plug) { + if self.current_module().enums.contains_key(&name) + || self.current_module().aliases.contains_key(&name) + { return Err(Error::RedefinedAlias { name }); } - self.enums.insert(plug, EnumBinding::new(variants)); + // An enum is also a `u8` type alias, so its name resolves as a type. + let resolved = self.resolve(&AliasedType::from(UIntType::U8))?; + self.current_module_mut() + .aliases + .insert(name.clone(), (resolved, visibility)); + self.current_module_mut() + .enums + .insert(name, EnumBinding::new(variants)); Ok(()) } pub fn get_enum(&self, name: &AliasName) -> Option<&EnumBinding> { - let plug = (name.clone(), self.file_id); - self.enums.get(&plug) + self.current_module().enums.get(name) } /// Insert a parameter into the global map. @@ -1211,7 +1191,6 @@ impl AbstractSyntaxTree for Item { } parse::Item::Ignored => Ok(Self::Ignored), parse::Item::EnumDeclaration(decl) => { - scope.file_id = decl.file_id(); let n = decl.variants().len(); if n < 2 { return Err(Error::Grammar { @@ -1254,17 +1233,11 @@ impl AbstractSyntaxTree for Item { }) .collect(); scope - .insert_alias(decl.name().clone(), AliasedType::from(UIntType::U8)) - .with_span(decl)?; - scope - .insert_enum(decl.name().clone(), variants) + .insert_enum(decl.name().clone(), decl.visibility().clone(), variants) .with_span(decl)?; Ok(Self::TypeAlias) } - }; - - scope.file_id = previous_file_id; - res + } } } @@ -2022,22 +1995,22 @@ fn desugar_enum_arms_u8( // Bind the scrutinee to a fresh variable to avoid witness-reuse errors. let disc_ident = Identifier::from_str_unchecked("__disc_"); - scope.push_scope(); + scope.enter_block(); scope.insert_variable(disc_ident.clone(), u8_ty.clone()); let analyzed_arms: Vec<(Arc, u8)> = arms .iter() .map(|(e, disc)| { - scope.push_scope(); + scope.enter_block(); let result = Expression::analyze(e, expected_ty, scope).map(|expr| (Arc::new(expr), *disc)); - scope.pop_scope(); + scope.exit_block(); result }) .collect::, _>>()?; let chain = build_u8_chain(&disc_ident, &analyzed_arms, expected_ty, &u8_ty, span); - scope.pop_scope(); + scope.exit_block(); // Wrap in block: { let __disc_N: u8 = scrutinee; } let chain_expr = Arc::new(Expression { @@ -2317,17 +2290,16 @@ mod enum_tests { #[test] fn enum_duplicate_name_is_error() { - use crate::error::ErrorCollector; - let (graph, _ids, _dir) = setup_graph(vec![( - "main.simf", + // Duplicate detection happens during semantic analysis (`Program::analyze`), + // not during flattening, so go through the `analyze` helper like the sibling + // duplicate-variant/discriminant tests. + let result = analyze( "enum Color { Red = 1, Green = 2 } enum Color { Blue = 1, Yellow = 2 } fn main() {}", - )]); - let mut handler = ErrorCollector::new(); - let program_option = graph.linearize_and_build(&mut handler).unwrap(); + ); assert!( - program_option.is_none(), + result.is_err(), "duplicate enum name should cause build failure" ); } diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 9fc3abcd..2983209c 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -410,7 +410,10 @@ impl<'a> ImportContext<'a> { } // These items carry no import information at this stage and can be safely skipped. - parse::Item::TypeAlias(_) | parse::Item::Function(_) | parse::Item::Ignored => {} + parse::Item::TypeAlias(_) + | parse::Item::Function(_) + | parse::Item::EnumDeclaration(_) + | parse::Item::Ignored => {} } } } diff --git a/src/driver/resolve_order.rs b/src/driver/resolve_order.rs index 27c7965c..83d8107e 100644 --- a/src/driver/resolve_order.rs +++ b/src/driver/resolve_order.rs @@ -76,6 +76,11 @@ impl DependencyGraph { function.set_file_id(source_id); Some(parse::Item::Function(function)) } + parse::Item::EnumDeclaration(decl) => { + let mut decl = decl.clone(); + decl.set_file_id(source_id); + Some(parse::Item::EnumDeclaration(decl)) + } parse::Item::Use(use_decl) => Some(self.rewrite_use(source_id, use_decl)), parse::Item::Module(module) => { let items: Vec = module diff --git a/src/parse.rs b/src/parse.rs index 6d1b265c..8c0b2336 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -2295,59 +2295,65 @@ where |_| vec![], ); - just(Token::Match) - .ignore_then(scrutinee) - .then(arms) - .validate(|(scrutinee, arms), e, emit| match arms { - (Some(first), Some(second)) => { - let (left, right) = match (&first.pattern, &second.pattern) { - (MatchPattern::Left(..), MatchPattern::Right(..)) => (first, second), - (MatchPattern::Right(..), MatchPattern::Left(..)) => (second, first), - - (MatchPattern::None, MatchPattern::Some(..)) => (first, second), - (MatchPattern::Some(..), MatchPattern::None) => (second, first), - - (MatchPattern::False, MatchPattern::True) => (first, second), - (MatchPattern::True, MatchPattern::False) => (second, first), - - (p1, p2) => { - emit.emit( - Error::IncompatibleMatchArms { - first: p1.clone(), - second: p2.clone(), - } - .with_span(e.span()), - ); - (first, second) - } - }; + just(Token::Match) + .ignore_then(scrutinee) + .then(arms) + .validate(|(scrutinee, arms), e, emit| { + let all_enum = arms + .iter() + .all(|a| matches!(a.pattern, MatchPattern::EnumVariant(..))); + + if all_enum && arms.len() >= 2 { + return SingleExpressionInner::EnumMatch(EnumMatch { + scrutinee, + arms: Arc::from(arms), + span: e.span(), + }); + } - Self { - scrutinee, - left, - right, - span: e.span(), + // Binary match: exactly 2 non-enum arms. + let fallback_arm = MatchArm { + expression: Arc::new(Expression::empty(Span::new(0, 0))), + pattern: MatchPattern::False, + }; + let (first, second) = if arms.len() == 2 { + let mut it = arms.into_iter(); + (it.next().unwrap(), it.next().unwrap()) + } else { + emit.emit( + Error::Grammar { + msg: "binary match requires exactly 2 arms".to_string(), } - } - _ => { - let match_arm_fallback = MatchArm { - expression: Arc::new(Expression::empty(Span::new(0, 0))), - pattern: MatchPattern::False, - }; + .with_span(e.span()), + ); + (fallback_arm.clone(), fallback_arm) + }; - let (left, right) = ( - arms.0.unwrap_or(match_arm_fallback.clone()), - arms.1.unwrap_or(match_arm_fallback.clone()), + let (left, right) = match (&first.pattern, &second.pattern) { + (MatchPattern::Left(..), MatchPattern::Right(..)) => (first, second), + (MatchPattern::Right(..), MatchPattern::Left(..)) => (second, first), + (MatchPattern::None, MatchPattern::Some(..)) => (first, second), + (MatchPattern::Some(..), MatchPattern::None) => (second, first), + (MatchPattern::False, MatchPattern::True) => (first, second), + (MatchPattern::True, MatchPattern::False) => (second, first), + (p1, p2) => { + emit.emit( + Error::IncompatibleMatchArms { + first: p1.clone(), + second: p2.clone(), + } + .with_span(e.span()), ); - Self { - scrutinee, - left, - right, - span: e.span(), - } + (first, second) } + }; + SingleExpressionInner::Match(Match { + scrutinee, + left, + right, + span: e.span(), }) - } + }) } impl Module {