Skip to content

Commit 19f1c77

Browse files
embed .py files in the binary (#214)
* move rec_aggregation .py files into zkdsl_implem/
* embed zk_dsl files in `rec_aggregation` crate into the binary. Alternative to #212

Co-authored-by: Kevaundray Wedderburn <kevtheappdev@gmail.com>

---------

Co-authored-by: Tom Wambsgans <TomWambsgans@users.noreply.github.com>
Co-authored-by: Kevaundray Wedderburn <kevtheappdev@gmail.com>
1 parent ade3a0c commit 19f1c77

18 files changed

Lines changed: 115 additions & 38 deletions

File tree

Cargo.lock

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ tracing-subscriber = { version = "0.3.23", features = ["std", "env-filter"] }
7979
tracing-forest = { version = "0.3.0", features = ["ansi", "smallvec"] }
8080
postcard = { version = "1.1.3", features = ["alloc"] }
8181
lz4_flex = "0.13.0"
82+
include_dir = "0.7"
8283

8384
[features]
8485
prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"]

crates/lean_compiler/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ xmss.workspace = true
1414
rand.workspace = true
1515

1616
tracing.workspace = true
17+
include_dir.workspace = true
1718
sub_protocols.workspace = true
1819
lean_vm.workspace = true
1920
backend.workspace = true

crates/lean_compiler/src/lib.rs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -91,26 +91,30 @@ impl From<RunnerError> for Error {
9191
pub enum ProgramSource {
9292
Raw(String),
9393
Filepath(String),
94+
Embedded {
95+
entry: String,
96+
dir: &'static include_dir::Dir<'static>,
97+
},
9498
}
9599

96100
impl ProgramSource {
97101
pub fn get_content(&self, flags: &CompilationFlags) -> Result<String, String> {
98-
match self {
99-
ProgramSource::Raw(src) => {
100-
let mut result = src.clone();
101-
for (key, value) in flags.replacements.iter() {
102-
result = result.replace(key, value);
103-
}
104-
Ok(result)
105-
}
102+
let raw = match self {
103+
ProgramSource::Raw(src) => src.clone(),
106104
ProgramSource::Filepath(fp) => {
107-
let mut result = std::fs::read_to_string(fp).map_err(|e| format!("Failed to read file {fp}: {e}"))?;
108-
for (key, value) in flags.replacements.iter() {
109-
result = result.replace(key, value);
110-
}
111-
Ok(result)
105+
std::fs::read_to_string(fp).map_err(|e| format!("Failed to read file {fp}: {e}"))?
112106
}
107+
ProgramSource::Embedded { entry, dir } => dir
108+
.get_file(entry)
109+
.and_then(|f| f.contents_utf8())
110+
.ok_or_else(|| format!("Embedded entry '{entry}' not found or not valid UTF-8"))?
111+
.to_string(),
112+
};
113+
let mut result = raw;
114+
for (key, value) in flags.replacements.iter() {
115+
result = result.replace(key, value);
113116
}
117+
Ok(result)
114118
}
115119
}
116120

crates/lean_compiler/src/parser/parsers/mod.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,24 @@ pub struct ParseContext {
100100
pub next_file_id: usize,
101101
/// Compilation flags
102102
pub flags: CompilationFlags,
103+
/// If `Some`, imports resolve against this embedded directory instead of the filesystem.
104+
pub embedded_dir: Option<&'static include_dir::Dir<'static>>,
103105
}
104106

105107
impl ParseContext {
106108
pub fn new(input: &ProgramSource, flags: CompilationFlags) -> Result<Self, SemanticError> {
107-
let current_source_code = input.get_content(&flags).unwrap();
108-
let (current_filepath, imported_filepaths) = match input {
109-
ProgramSource::Raw(_) => ("<raw_input>".to_string(), BTreeSet::new()),
109+
let current_source_code = input.get_content(&flags).map_err(SemanticError::new)?;
110+
let (current_filepath, imported_filepaths, embedded_dir) = match input {
111+
ProgramSource::Raw(_) => ("<raw_input>".to_string(), BTreeSet::new(), None),
110112
ProgramSource::Filepath(fp) => {
111113
let canonical = std::fs::canonicalize(fp)
112114
.map_err(|e| SemanticError::new(format!("Cannot resolve filepath '{}': {}", fp, e)))?
113115
.to_string_lossy()
114116
.to_string();
115-
(canonical.clone(), [canonical].into_iter().collect())
117+
(canonical.clone(), [canonical].into_iter().collect(), None)
118+
}
119+
ProgramSource::Embedded { entry, dir } => {
120+
(entry.clone(), [entry.clone()].into_iter().collect(), Some(*dir))
116121
}
117122
};
118123
let import_stack = vec![current_filepath.clone()];
@@ -132,6 +137,7 @@ impl ParseContext {
132137
current_source_code,
133138
next_file_id: 1,
134139
flags,
140+
embedded_dir,
135141
})
136142
}
137143

crates/lean_compiler/src/parser/parsers/program.rs

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,30 @@ impl Parse<Program> for ProgramParser {
5656
ctx.import_root.clone()
5757
};
5858
let raw_path = Path::new(&base_dir).join(&relative_path);
59-
let filepath = raw_path
60-
.canonicalize()
61-
.map_err(|e| {
62-
SemanticError::new(format!(
63-
"Cannot resolve import '{}' (resolved to '{}'): {}",
64-
relative_path,
65-
raw_path.display(),
66-
e
59+
let filepath = if let Some(dir) = ctx.embedded_dir {
60+
let key = lexical_normalize(&raw_path);
61+
if dir.get_file(Path::new(&key)).is_none() {
62+
return Err(SemanticError::new(format!(
63+
"Cannot resolve embedded import '{}' (resolved to '{}')",
64+
relative_path, key
6765
))
68-
})?
69-
.to_string_lossy()
70-
.to_string();
66+
.into());
67+
}
68+
key
69+
} else {
70+
raw_path
71+
.canonicalize()
72+
.map_err(|e| {
73+
SemanticError::new(format!(
74+
"Cannot resolve import '{}' (resolved to '{}'): {}",
75+
relative_path,
76+
raw_path.display(),
77+
e
78+
))
79+
})?
80+
.to_string_lossy()
81+
.to_string()
82+
};
7183

7284
// Check for circular imports
7385
if ctx.import_stack.contains(&filepath) {
@@ -100,7 +112,15 @@ impl Parse<Program> for ProgramParser {
100112
}
101113
ctx.imported_filepaths.insert(filepath.clone());
102114
ctx.import_stack.push(filepath.clone());
103-
ctx.current_source_code = ProgramSource::Filepath(filepath).get_content(&ctx.flags)?;
115+
let import_source = if let Some(dir) = ctx.embedded_dir {
116+
ProgramSource::Embedded {
117+
entry: filepath.clone(),
118+
dir,
119+
}
120+
} else {
121+
ProgramSource::Filepath(filepath.clone())
122+
};
123+
ctx.current_source_code = import_source.get_content(&ctx.flags)?;
104124
let subprogram = parse_program_helper(ctx)?;
105125
ctx.import_stack.pop();
106126
functions.extend(subprogram.functions);
@@ -143,6 +163,29 @@ impl Parse<Program> for ProgramParser {
143163
}
144164
}
145165

166+
/// Lexically normalize a path for embedded-source lookups: collapse `.` and
167+
/// `..` components and join with `/` regardless of host OS, so the same key
168+
/// works on every platform.
169+
fn lexical_normalize(path: &Path) -> String {
170+
use std::path::Component;
171+
let mut parts: Vec<String> = Vec::new();
172+
for c in path.components() {
173+
match c {
174+
Component::CurDir => {}
175+
Component::ParentDir => {
176+
if matches!(parts.last().map(String::as_str), Some("..") | None) {
177+
parts.push("..".to_string());
178+
} else {
179+
parts.pop();
180+
}
181+
}
182+
Component::Normal(s) => parts.push(s.to_string_lossy().into_owned()),
183+
Component::RootDir | Component::Prefix(_) => {}
184+
}
185+
}
186+
parts.join("/")
187+
}
188+
146189
/// Parser for import statements.
147190
pub struct ImportStatementParser;
148191

crates/lean_compiler/zkDSL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ dot_product_ee(x, y, z) # z = x * y
506506
507507
# Copy extension element (multiply by [1,0,0,0,0]).
508508
# `ONE_EF_PTR` is a guest-program constant that the program must materialize
509-
# in its preamble memory at startup; see `crates/rec_aggregation/utils.py`
509+
# in its preamble memory at startup; see `crates/rec_aggregation/zkdsl_implem/utils.py`
510510
# for an example (`build_preamble_memory`).
511511
dot_product_ee(src, ONE_EF_PTR, dst)
512512

crates/rec_aggregation/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ xmss.workspace = true
1616
rand.workspace = true
1717

1818
tracing.workspace = true
19+
include_dir.workspace = true
1920
sub_protocols.workspace = true
2021
lean_vm.workspace = true
2122
lean_compiler.workspace = true

crates/rec_aggregation/src/compilation.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use lean_prover::{
66
};
77
use lean_vm::*;
88
use std::collections::{BTreeMap, HashMap};
9-
use std::path::Path;
109
use std::sync::OnceLock;
1110
use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, min_stacked_n_vars, total_whir_statements};
1211
use tracing::instrument;
@@ -38,6 +37,8 @@ pub fn init_aggregation_bytecode() {
3837
BYTECODE.get_or_init(compile_main_program_self_referential);
3938
}
4039

40+
static EMBEDDED_ZK_DSL: include_dir::Dir<'_> = include_dir::include_dir!("$CARGO_MANIFEST_DIR/zkdsl_implem");
41+
4142
pub const MAX_RECURSIONS: usize = 16;
4243
pub const MAX_XMSS_AGGREGATED: usize = 1 << 15; // TODO increase (we would need a bigger minimal memory size, totally doable)
4344
pub const MAX_XMSS_DUPLICATES: usize = 1 << 15; // ...same
@@ -69,12 +70,11 @@ pub(crate) fn type1_input_data_size_padded(program_log_size: usize) -> usize {
6970
fn compile_main_program(program_log_size: usize, bytecode_zero_eval: F) -> Bytecode {
7071
let replacements = build_replacements(program_log_size, bytecode_zero_eval);
7172

72-
let filepath = Path::new(env!("CARGO_MANIFEST_DIR"))
73-
.join("main.py")
74-
.to_str()
75-
.unwrap()
76-
.to_string();
77-
compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements })
73+
let source = ProgramSource::Embedded {
74+
entry: "main.py".to_string(),
75+
dir: &EMBEDDED_ZK_DSL,
76+
};
77+
compile_program_with_flags(&source, CompilationFlags { replacements })
7878
}
7979

8080
#[instrument(skip_all)]

crates/rec_aggregation/tests/test_log2_ceil.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from snark_lib import *
2-
from ..utils import *
2+
from ..zkdsl_implem.utils import *
33

44

55
def main():

0 commit comments

Comments
 (0)