commit 002fa36ba1f7361ccabee83a3f1ec313ee33e9e2 Author: nick Date: Fri Aug 16 10:22:44 2024 -0400 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d42b544 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/target +lex.yy.* +grammar.tab.* +main.o +distr \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..2f24718 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "distr" +version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..678aa94 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "distr" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +crate-type = ["staticlib"] \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3f8c5d2 --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +CC := gcc +LIBRARIES := -lfl -lm -lreadline + +calc: main.o grammar.tab.o lex.yy.o target/release/libdistr.a + $(CC) -o distr $^ $(LIBRARIES) + +main.o: main.c grammar.tab.h + $(CC) -c $^ + +lex.yy.o: lex.yy.c + $(CC) -c $^ + +lex.yy.c: tokens.lex grammar.tab.h + flex -o lex.yy.c $^ + +grammar.tab.o: grammar.tab.c + $(CC) -c $^ + +grammar.tab.c: grammar.y rust.h + bison -d $< + +grammar.tab.h: grammar.y rust.h + bison -d $< + +rust.h: rust + +.PHONY: rust +rust: + cargo build -r + # cbindgen -o rust.h -l c + +target/release/libdistr.a: rust + +.PHONY: clean +clean: + rm -f lex.yy.* + rm -f grammar.tab.* + rm -f main.o distr + +.PHONY: sr +sr: grammar.y + bison -d -Wcounterexamples $< + +.PHONY: lint +lint: + cargo clippy --no-deps \ No newline at end of file diff --git a/grammar.y b/grammar.y new file mode 100644 index 0000000..5d1b677 --- /dev/null +++ b/grammar.y @@ -0,0 +1,111 @@ +%{ +#include +#include "rust.h" + +#define YYSTYPE BST + +int yyerror(char*); + +extern Variables global_variables; + +%} + +%token NUMBER IDENT ASSIGN +%token PLUS MINUS MUL DIV +%token LESSER GREATER LEFT RIGHT LCURL RCURL COMMA COLON LSQUARE RSQUARE +%token END SYNTAXERROR + +%left PLUS MINUS +%left MUL DIV +%left LESSER GREATER +%left NEG POS SQRT + +%start Input +%% + +Input: + /* empty input */ + | Input Line +; + +Line: + END + // echo expressions to stdout + | Expression END { + print(bst_eval($1, &global_variables)); + } + // assignment + | IDENT ASSIGN Expression END { + insert_variable(&global_variables, $1, bst_eval($3, &global_variables)); + } + | SYNTAXERROR END { + fprintf(stderr, "syntax error\n"); + } +; + +// expressions are anything that can be evaluated +Expression: + // terminal nodes: a literal number and a variable + NUMBER { $$ = yylval; } + | IDENT { $$ = yylval; } + + // term separators + | Expression PLUS Expression { $$ = add($1, $3); } + | Expression MINUS Expression { $$ = sub($1, $3); } + + // mul/div + | Expression MUL Expression { $$ = mul($1, $3); } + | LEFT Expression RIGHT LEFT Expression RIGHT { $$ = mul($2, $5); } + | Expression DIV Expression { $$ = divide($1, $3); } + + // unary + and - + | MINUS Expression %prec NEG { $$ = neg($2); } + | PLUS Expression %prec POS { $$ = $2; } + | SQRT Expression { $$ = sqrt($2); } + + // parenthesis + | LEFT Expression RIGHT { $$ = $2; } + + // there are a few ways to make a distribution + // 1. standard deviation and mean + // { u, o } + | LCURL Expression COMMA Expression RCURL { + $$ = two_var_distr($2, $4); + } + // 2. association + // [ a: b, c: d, e: f ] + | LSQUARE AssoList RSQUARE { + $$ = $2; + } + + | Expression LESSER Expression { + $$ = less($1, $3); + } + | Expression GREATER Expression { + $$ = more($1, $3); + } +; + +AssoList: + /* empty list */ { + $$ = empty_list(); + } + | NonEmpyList { + $$ = $1; + } +; + +NonEmpyList: + Expression COLON Expression { + $$ = np_pair($1, $3); + } + | NonEmpyList COMMA Expression COLON Expression { + $$ = np_pair_push($1, $3, $5); + } +; + +%% + +int yyerror(char* msg) { + return printf("%s\n", msg); +} diff --git a/main.c b/main.c new file mode 100644 index 0000000..58a4321 --- /dev/null +++ b/main.c @@ -0,0 +1,14 @@ +#include + +#include "grammar.tab.h" +#include "rust.h" + +Variables global_variables; + +int main() { + global_variables = new_variables(); + + yyparse(); + + drop_variables(global_variables); +} \ No newline at end of file diff --git a/rust.h b/rust.h new file mode 100644 index 0000000..97cd109 --- /dev/null +++ b/rust.h @@ -0,0 +1,195 @@ +#include +#include +#include +#include + +typedef enum BinaryOperation { + Add, + Sub, + Mul, + Div, + Less, + More, +} BinaryOperation; + +typedef struct EvDistr EvDistr; + +typedef struct HashMap_String__Evaluation HashMap_String__Evaluation; + +typedef struct String String; + +typedef struct Vec_NPPair Vec_NPPair; + +typedef enum Evaluation_Tag { + Number, + MeanDev, +} Evaluation_Tag; + +typedef struct Evaluation { + Evaluation_Tag tag; + union { + struct { + double number; + }; + struct { + double mean; + double dev; + Vec_NPPair* list; + } mean_dev; + }; +} Evaluation; + +typedef enum MathResult_Evaluation_Tag { + /** + * Ok(T) is the Ok variation. Analagous to + * std::result::Result::Ok. + */ + Ok_Evaluation, + /** + * One of the referenced variables in an expression has not been + * assigned yet. + */ + VariableNotFound_Evaluation, + /** + * This is no an error per-se, but rather the result of return with + * no expression to return. This is early-returned to simulate the + * fact that any expression involving null would also return null. + */ + Void_Evaluation, + /** + * The denominator of a division was 0. + */ + ZeroDivision_Evaluation, + /** + * Bad type for operands + */ + TypeError_Evaluation, + /** + * a probability was not in the range [0, 1] + */ + OutOfRange_Evaluation, + /** + * probabilities don't sum to 1 + */ + InsufficientSum_Evaluation, + /** + * operation not supported + */ + NotImplemented_Evaluation, +} MathResult_Evaluation_Tag; + +typedef struct MathResult_Evaluation { + MathResult_Evaluation_Tag tag; + union { + struct { + struct Evaluation ok; + }; + }; +} MathResult_Evaluation; + +/** + * This is all the things a bst node can be. + * a lot of them have 2 levels of heap indirection, + * because for ffi reasons, anything on the heap + * has to be representable as a single pointer. + * only box can do that, and only when given a + * Sized type, so `Box`, not `Box` or + * just `String`. + */ +typedef enum BST_Tag { + Literal, + Variable, + Negate, + Root, + Binary, + ListDistr, + TwoVar, +} BST_Tag; + +typedef struct Negate_Body { + struct BST *operand; +} Negate_Body; + +typedef struct Binary_Body { + enum BinaryOperation op; + struct BST *left; + struct BST *right; +} Binary_Body; + +typedef struct TwoVar_Body { + struct BST *_0; + struct BST *_1; +} TwoVar_Body; + +typedef struct BST { + BST_Tag tag; + union { + struct { + double literal; + }; + struct { + struct String *variable; + }; + Negate_Body negate; + struct { + struct BST *root; + }; + Binary_Body binary; + struct { + struct Vec_NPPair *list_distr; + }; + TwoVar_Body two_var; + }; +} BST; + +typedef struct Variables { + struct HashMap_String__Evaluation *_0; +} Variables; + +struct MathResult_Evaluation bst_eval(struct BST root, const struct Variables *variables); + +struct BST add(struct BST left, struct BST right); + +struct BST sub(struct BST left, struct BST right); + +struct BST mul(struct BST left, struct BST right); + +struct BST divide(struct BST left, struct BST right); + +struct BST less(struct BST left, struct BST right); + +struct BST more(struct BST left, struct BST right); + +struct BST neg(struct BST operand); + +struct BST sqrt(struct BST operand); + +void print(struct MathResult_Evaluation p); + +/** + * # Safety + * you must be able to form a slice from `name` and `len` + */ +struct BST number(const uint8_t *name, uintptr_t len); + +/** + * # Safety + * you must be able to form a slice from `name` and `len` + */ +struct BST variable(const uint8_t *name, uintptr_t len); + +struct BST two_var_distr(struct BST mu, struct BST sig); + +struct BST empty_list(void); + +struct BST np_pair(struct BST left, struct BST right); + +struct BST np_pair_push(struct BST list, struct BST left, struct BST right); + +struct Variables new_variables(void); + +void insert_variable(struct Variables *vars, struct BST name, struct MathResult_Evaluation value); + +void drop_variables(struct Variables v); + +void debug(const struct MathResult_Evaluation *v); diff --git a/src/dbg.rs b/src/dbg.rs new file mode 100644 index 0000000..652bd0c --- /dev/null +++ b/src/dbg.rs @@ -0,0 +1,15 @@ +use crate::{result::MathResult, eval::Evaluation}; + +#[macro_export] +macro_rules! mega_dbg { + ($target:expr, $target_type:tt) => { + eprintln!("[{}:{}] \x1b[31m{:#?}\x1b[0m @ {:?}", file!(), line!(), &$target, &$target as *const $target_type) + }; +} + +#[allow(unused_parens)] +#[no_mangle] +pub extern "C" +fn debug(v: &MathResult) { + mega_dbg!(v, (&MathResult)); +} \ No newline at end of file diff --git a/src/eval/distribution.rs b/src/eval/distribution.rs new file mode 100644 index 0000000..bf5b0b3 --- /dev/null +++ b/src/eval/distribution.rs @@ -0,0 +1,77 @@ +use std::collections::HashMap; + +#[allow(clippy::box_collection)] +#[repr(C)] +#[derive(Debug, Clone)] +pub struct EvDistr { + pub mean: f64, + pub dev: f64, + pub list: Option>>, +} +impl EvDistr { + pub fn modify(self, n: f64, mut f: MeanFn, d: DevFn) -> Self + where + MeanFn: FnMut(f64, f64) -> f64, + DevFn: FnOnce(f64, f64) -> f64, + { + let list = if let Some(mut boxed_list) = self.list { + // using map in place to avoid dereferencing then boxing a value + // repeatedly. it's already inefficient to use the heap + // but necessary for ffi reasons. the least we can do is try. + map_in_place(boxed_list.as_mut(), |&(x, p)| (f(x, n), p)); + Some(boxed_list) + } else { + None + }; + + Self { + mean: f(self.mean, n), + dev: d(self.dev, n), + list, + } + } + + pub fn combine(self, rhs: Self, mut f: MeanFn) -> Self + where + MeanFn: FnMut(f64, f64) -> f64, + { + let mean = f(self.mean, rhs.mean); + let dev = (self.dev*self.dev + rhs.dev*rhs.dev).powf(0.5); + let list = match (self.list, rhs.list) { + (None, None) => None, + (Some(l), None) | (None, Some(l)) => Some(l), + (Some(left), Some(right)) => Some(Box::new(combine_lists(*left, *right, f))), + }; + + Self { + mean, dev, list + } + } +} + +#[inline] +fn map_in_place(l: &mut [T], mut f: F) +where + F: FnMut(&T) -> T +{ + l.iter_mut().for_each(|refm| *refm = f(refm)) +} + +fn combine_lists(left: Vec<(f64, f64)>, right: Vec<(f64, f64)>, mut f: F) -> Vec<(f64, f64)> +where + F: FnMut(f64, f64) -> f64, +{ + let mut map = HashMap::new(); + + for (n1, p1) in left { + for &(n2, p2) in &right { + let key: u64 = f(n1, n2).to_bits(); + let value = p1*p2; + *map.entry(key).or_insert(0.0) += value; + } + } + + map.into_iter() + .map(|(n, p)| (f64::from_bits(n), p)) + .collect() +} \ No newline at end of file diff --git a/src/eval/mod.rs b/src/eval/mod.rs new file mode 100644 index 0000000..ab434b1 --- /dev/null +++ b/src/eval/mod.rs @@ -0,0 +1,221 @@ +use std::fmt::Display; +use std::ops::{Add, Sub, Mul, Div}; + +pub mod distribution; + +pub use distribution::EvDistr; + +use crate::result::MathResult; +use MathResult::*; + +macro_rules! full_impl { + ($left:expr, $right:expr, $mean:expr) => {{ + use Evaluation as E; + match ($left, $right) { + (E::Number(n1), E::Number(n2)) => Ok(E::Number($mean(n1, n2))), + (E::Number(n),E::MeanDev(d)) => { + Ok(E::MeanDev(d.modify(n, $mean, |left, _right| left))) + }, + (E::MeanDev(d),E::Number(n)) => { + Ok(E::MeanDev(d.modify(n, $mean, |left, _right| left))) + }, + (E::MeanDev(d1), E::MeanDev(d2)) => Ok(E::MeanDev(d1.combine(d2, $mean))) + } + }}; +} + +macro_rules! partial_impl { + ($left:expr, $right:expr, $function:expr) => {{ + use Evaluation as E; + match ($left, $right) { + (Self::Number(n1),Self::Number(n2)) => Ok(Self::Number($function(n1, n2))), + (Self::Number(n),Self::MeanDev(d)) => { + Ok(E::MeanDev(d.modify(n, $function, $function))) + }, + (Self::MeanDev(d),Self::Number(n)) => { + Ok(E::MeanDev(d.modify(n, $function, $function))) + }, + (Self::MeanDev(..),Self::MeanDev(..)) => MathResult::NotImplemented, + } + }}; +} + +#[repr(C)] +#[derive(Debug, Clone)] +pub enum Evaluation { + Number(f64), + MeanDev(EvDistr), +} +impl Evaluation { + pub fn two_var(mean: f64, dev: f64) -> Self { + Self::MeanDev(EvDistr { mean, dev, list: None }) + } + + pub fn from_list(l: Vec<(Evaluation, Evaluation)>) -> MathResult { + let mut mean = 0.0; + let list = l.into_iter() + .map( + |pair| + match pair { + (Evaluation::Number(n), Evaluation::Number(p)) => { + if (0.0..=1.0).contains(&p) { + Ok((n, p)) + } else { + MathResult::OutOfRange + } + }, + _ => MathResult::TypeError, + } + ).collect::< MathResult< Vec<(f64, f64)> > >()?; + + let mut p_sum = 0.0; + for &(n, p) in &list { + mean += n*p; + p_sum += p; + } + if p_sum != 1.0 { + return MathResult::InsufficientSum; + } + + // sqrt(sum( p*(x_i - u)^2 )) + let dev = list.iter() + .map(|&(n, p)| (n - mean).powi(2) * p) + .sum::() + .powf(0.5); + + Ok(Evaluation::MeanDev(EvDistr { mean, dev, list: Some(Box::new(list)) })) + } + + pub fn add(self, rhs: Self) -> MathResult { + full_impl!(self, rhs, f64::add) + } + + pub fn sub(self, rhs: Self) -> MathResult { + full_impl!(self, rhs, f64::sub) + } + + pub fn mul(self, rhs: Self) -> MathResult { + partial_impl!(self, rhs, f64::mul) + } + + pub fn div(self, rhs: Self) -> MathResult { + partial_impl!(self, rhs, f64::div) + } + + // always returns the Number variant + pub fn p_less(self, rhs: Self) -> MathResult { + // p_unchecked is made by the imperfect approximation + // of the integral. it is not guaranteed to be in the + // range of [0, 1] + let p_unchecked = match (self, rhs) { + (Self::Number(n1), Self::Number(n2)) => { + if n1 < n2 { 1.0 } + else { 0.0 } + }, + (Self::Number(n), Self::MeanDev(d)) => { + more(d, n) + }, + (Self::MeanDev(d), Self::Number(n)) => { + less(d, n) + }, + (this @ Self::MeanDev(..), other @ Self::MeanDev(..)) => { + // P(self < rhs) = P(self-rhs < 0) + return this.sub(other)?.p_less(Evaluation::Number(0.0)); + } + }; + + // p clamps p_unchecked to [0, 1] + // this happens when the approximation gets too inaccurate, and + // reads off massive floating point values. + // for example {3, 5} < 100 can return numbers with dozens of digits + let p = if p_unchecked > 1.0 { + 1.0 + } else if p_unchecked < 0.0 { + 0.0 + } else { + p_unchecked + }; + + Ok(Self::Number(p)) + } + + // always returns the Number variant + pub fn p_more(self, rhs: Self) -> MathResult { + rhs.p_less(self) + } +} +impl Display for Evaluation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Number(n) => write!(f, "{n}"), + Self::MeanDev(d) => write!(f, "μ = {}, σ = {:.3}", d.mean, d.dev), + } + } +} + +fn less(d: EvDistr, n: f64) -> f64 { + match d.list { + None => less_continuous(d.mean, d.dev, n), + Some(l) => { + l.into_iter() + .filter(|pair| pair.0 < n) + .map(|pair| pair.1) + .sum() + } + } +} + +fn more(d: EvDistr, n: f64) -> f64 { + match d.list { + None => 1.0 - less_continuous(d.mean, d.dev, n), + Some(l) => { + l.into_iter() + .filter(|pair| pair.0 > n) + .map(|pair| pair.1) + .sum() + } + } +} + +/// Return the probability that a distribution with mean +/// `mean` and standard deviation `stddev` is less than `n` +fn less_continuous(mean: f64, stddev: f64, x: f64) -> f64 { + /* + The equation for the normal distribution is + + exp( -(x-u)^2 / 2o^2 ) / o / sqrt(2 * pi) + + This can't actually be integrated. + + However! You can approximate e^x with a taylor series + + so when you divide again by the o * sqrt(2pi) term, you can approximate + integrals! + */ + + const TERMS: usize = 50; + const FRAC_1_SQRT_2_PI: f64 = std::f64::consts::FRAC_1_SQRT_PI * std::f64::consts::FRAC_1_SQRT_2; + let coefficient = FRAC_1_SQRT_2_PI/stddev; + + let mut result = 0.0; + + // diff is the difference, diff_squared is the square of the difference + // done by expanding (x-mean)^2 + let diff = x-mean; + let diff_squared = x*x - 2.0*mean*x + mean*mean; + let twice_variance = 2.0 * stddev * stddev; + let term_coeff = diff_squared/twice_variance; + + for n in 0..=TERMS { + // this loop ends adding a ton of very small numbers. minimizing + // floating point errors is a must + let mut term = diff / (2*n + 1) as f64; + for fact in 1..=n { term *= term_coeff/fact as f64 } + if n%2 == 1 { term *= -1.0 } + result += term; + } + + // the improper integral subtracts the value of negative infinity, + // which is -0.5 by definition + coefficient * result + 0.5 +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..dbaa1f5 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,240 @@ +#![feature(try_trait_v2, more_float_constants)] + +mod result; +mod eval; +mod variables; +mod dbg; + +use crate::result::{MathResult, IntoMath}; +use MathResult::*; + +/// tuples are frustrating over ffi. +/// use a struct instead +#[repr(C)] +#[derive(Debug, Clone)] +pub struct NPPair { + n: BST, + p: BST, +} + +use eval::Evaluation; +use variables::Variables; + +/// This is all the things a bst node can be. +/// a lot of them have 2 levels of heap indirection, +/// because for ffi reasons, anything on the heap +/// has to be representable as a single pointer. +/// only box can do that, and only when given a +/// Sized type, so `Box`, not `Box` or +/// just `String`. +#[repr(C)] +#[derive(Debug, Clone)] +pub enum BST { + // valid/passthrough nodes + Literal(f64), + Variable(Box), + + // operations + Negate { + operand: Box, + }, + Root(Box), + + Binary { + op: BinaryOperation, + left: Box, + right: Box, + }, + + // values passed from bison that must + // be converted + ListDistr(Box>), + TwoVar(Box, Box), +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub enum BinaryOperation { + Add, + Sub, + Mul, + Div, + Less, + More, +} +impl BinaryOperation { + fn eval(self, left: Evaluation, right: Evaluation) -> MathResult { + match self { + Self::Add => left.add(right), + Self::Sub => left.sub(right), + Self::Mul => left.mul(right), + Self::Div => left.div(right), + Self::Less => left.p_less(right), + Self::More => left.p_more(right), + } + } +} + +#[no_mangle] +pub extern "C" +fn bst_eval(root: BST, variables: &Variables) -> MathResult { + match root { + BST::Literal(n) => Ok(Evaluation::Number(n)), + + BST::Binary { op, left, right } => { + let left = bst_eval(*left, variables)?; + let right = bst_eval(*right, variables)?; + op.eval(left, right) + }, + BST::Negate { operand } => { + let operand = bst_eval(*operand, variables)?; + let operation = BinaryOperation::Mul; + let factor = Evaluation::Number(-1.0); + operation.eval(factor, operand) + }, + BST::Root(operand) => { + let intermediate = bst_eval(*operand, variables)?; + match intermediate { + Evaluation::Number(n) => MathResult::Ok(Evaluation::Number(n.sqrt())), + Evaluation::MeanDev(..) => MathResult::NotImplemented, + } + } + + BST::Variable(name) => variables.0 + .get(Box::as_ref(&name)) + .cloned() + .into_math(MathResult::VariableNotFound), + + BST::TwoVar(mean, dev) => { + let Evaluation::Number(u) = bst_eval(*mean, variables)? + else { return MathResult::TypeError }; + let Evaluation::Number(o) = bst_eval(*dev, variables)? + else { return MathResult::TypeError }; + Ok(Evaluation::two_var(u, o)) + }, + BST::ListDistr(list) => { + let list = (*list).into_iter() + .map( + |pair| + Ok((bst_eval(pair.n, variables)?, bst_eval(pair.p, variables)?)) + ) + .collect::< MathResult> >()?; + Evaluation::from_list(list) + }, + } +} + +macro_rules! binary { + ($variant:ident, $left:ident, $right:ident) => { + BST::Binary { + op: BinaryOperation::$variant, + left: Box::new($left), + right: Box::new($right), + } + }; +} + +#[no_mangle] +pub extern "C" +fn add(left: BST, right: BST) -> BST { + binary!(Add, left, right) +} + +#[no_mangle] +pub extern "C" +fn sub(left: BST, right: BST) -> BST { + binary!(Sub, left, right) +} + +#[no_mangle] +pub extern "C" +fn mul(left: BST, right: BST) -> BST { + binary!(Mul, left, right) +} + +#[no_mangle] +pub extern "C" +fn divide(left: BST, right: BST) -> BST { + binary!(Div, left, right) +} + +#[no_mangle] +pub extern "C" +fn less(left: BST, right: BST) -> BST { + binary!(Less, left, right) +} + +#[no_mangle] +pub extern "C" +fn more(left: BST, right: BST) -> BST { + binary!(More, left, right) +} + +#[no_mangle] +pub extern "C" +fn neg(operand: BST) -> BST { + BST::Negate { operand: Box::new(operand) } +} + +#[no_mangle] +pub extern "C" +fn sqrt(operand: BST) -> BST { + BST::Root(Box::new(operand)) +} + +#[no_mangle] +pub extern "C" +fn print(p: MathResult) { + match p { + Ok(n) => println!(">>> {n}"), + other => println!("{other}"), + } +} + +/// # Safety +/// you must be able to form a slice from `name` and `len` +#[no_mangle] +pub unsafe extern "C" +fn number(name: *const u8, len: usize) -> BST { + let slice = unsafe { std::slice::from_raw_parts(name, len) }; + let str = unsafe { std::str::from_utf8_unchecked(slice) }; + BST::Literal(unsafe { str.parse().unwrap_unchecked() }) +} + +/// # Safety +/// you must be able to form a slice from `name` and `len` +#[no_mangle] +pub unsafe extern "C" +fn variable(name: *const u8, len: usize) -> BST { + let slice = unsafe { std::slice::from_raw_parts(name, len) }; + let str = unsafe { std::str::from_utf8_unchecked(slice) }; + let string = str.to_owned(); + BST::Variable(Box::new(string)) +} + +#[no_mangle] +pub extern "C" +fn two_var_distr(mu: BST, sig: BST) -> BST { + BST::TwoVar(Box::new(mu), Box::new(sig)) +} + +#[no_mangle] +pub extern "C" +fn empty_list() -> BST { + BST::ListDistr(Box::default()) +} + +#[no_mangle] +pub extern "C" +fn np_pair(left: BST, right: BST) -> BST { + BST::ListDistr(Box::new(vec![ NPPair { n: left, p: right } ])) +} + +#[no_mangle] +pub extern "C" +fn np_pair_push(list: BST, left: BST, right: BST) -> BST { + let BST::ListDistr(mut l) = list else { unsafe { std::hint::unreachable_unchecked() } }; + l.push( NPPair { n: left, p: right } ); + BST::ListDistr(l) +} + diff --git a/src/result.rs b/src/result.rs new file mode 100644 index 0000000..db20db0 --- /dev/null +++ b/src/result.rs @@ -0,0 +1,193 @@ +use std::convert::Infallible; +use std::hint::unreachable_unchecked; +/// Define the FFI-safe MathResult type +/// this would be a regular type alias, but +/// Result isn't #[repr(C)], and +/// therefore not FFI-safe + +use std::ops::{FromResidual, Try, ControlFlow}; +use std::fmt::Display; + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub enum MathResult { + /// Ok(T) is the Ok variation. Analagous to + /// std::result::Result::Ok. + Ok(T), + /// One of the referenced variables in an expression has not been + /// assigned yet. + VariableNotFound, + /// This is no an error per-se, but rather the result of return with + /// no expression to return. This is early-returned to simulate the + /// fact that any expression involving null would also return null. + Void, + /// The denominator of a division was 0. + ZeroDivision, + /// Bad type for operands + TypeError, + /// a probability was not in the range [0, 1] + OutOfRange, + /// probabilities don't sum to 1 + InsufficientSum, + /// operation not supported + NotImplemented, +} + +impl FromResidual> for MathResult { + /// This implementation of residual is a bit weird. The residual + /// is of type `MathResult` because the error variants are effectively + /// untyped, in that they do not have any generic or concrete types. + /// # Undefined behaviour + /// This function assumes that it will only be called by Rust to + /// reconstitute a `MathResult` from one of its error variants, and + /// as such it is **Undefined Behaviour** to call this on a `MathResult::Ok` + fn from_residual(this: MathResult) -> Self { + match this { + MathResult::Ok(_) => unsafe { unreachable_unchecked() }, + MathResult::VariableNotFound => Self::VariableNotFound, + MathResult::Void => Self::Void, + MathResult::ZeroDivision => Self::ZeroDivision, + MathResult::TypeError => Self::TypeError, + MathResult::InsufficientSum => Self::InsufficientSum, + MathResult::OutOfRange => Self::OutOfRange, + MathResult::NotImplemented => Self::NotImplemented, + } + } +} + +impl Try for MathResult { + type Output = T; + type Residual = Self; + + fn from_output(data: T) -> Self { + Self::Ok(data) + } + + fn branch(self) -> ControlFlow { + // this is a weird match, but essentialy the Ok data can + // get `Continue`d, while others get `Break`ed. + match self { + MathResult::Ok(data) => ControlFlow::Continue(data), + other => ControlFlow::Break(other), + } + } +} + +/// For a given Result>, convert into a +/// MathResult by turning Ok(t) into MathResult::Ok(t) and Err(n) +/// into Self::n. E in the Result must be of type MathResult +/// to avoid the possibility of Err(Ok(_)). +impl From< Result< OkType, MathResult > > for MathResult { + fn from(value: Result< OkType, MathResult >) -> MathResult { + match value { + Ok(data) => Self::Ok(data), + // this is unreachable because this is Err(Ok(Infalliable)) + Err(MathResult::Ok(_)) => unsafe { unreachable_unchecked() }, + Err(MathResult::VariableNotFound) => Self::VariableNotFound, + Err(MathResult::ZeroDivision) => Self::ZeroDivision, + Err(MathResult::Void) => Self::Void, + Err(MathResult::TypeError) => Self::TypeError, + Err(MathResult::OutOfRange) => Self::OutOfRange, + Err(MathResult::InsufficientSum) => Self::InsufficientSum, + Err(MathResult::NotImplemented) => Self::NotImplemented, + } + } +} + +/// For a given MathResult, turn it into a +/// Result by mapping Self::Ok(data) to +/// Ok(data), and any error value to Err(value) +impl From< MathResult > for Result> { + fn from(val: MathResult) -> Self { + match val { + MathResult::Ok(data) => Ok(data), + MathResult::VariableNotFound => Err(MathResult::VariableNotFound), + MathResult::ZeroDivision => Err(MathResult::ZeroDivision), + MathResult::Void => Err(MathResult::Void), + MathResult::TypeError => Err(MathResult::TypeError), + MathResult::OutOfRange => Err(MathResult::OutOfRange) , + MathResult::InsufficientSum => Err(MathResult::InsufficientSum), + MathResult::NotImplemented => Err(MathResult::NotImplemented), + } + } +} + +impl +FromIterator< MathResult > for MathResult +where + Collection: FromIterator +{ + fn from_iter(iter: I) -> MathResult + where + I: IntoIterator> + { + let result_iterator = iter.into_iter().map( + // this hellish generic selects the Into implementation that converts + // a MathResult into a Result>, so that I + // can use Result's implementation of FromIterator later. + as Into< Result> >>::into + ); + // Use Result's implementation of FromIterator to build up the + // collection as a Result. + let collected_result: Result> = result_iterator.collect(); + // convert the result into a MathResult. The Err variant is + // of type MathResult so this is a flawless conversion. + collected_result.into() + } +} + + +impl Display for MathResult +where T: Display { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // pretty self-explanatory. write the right message for each variant + match self { + Self::Ok(data) => write!(f, "{data}"), + Self::VariableNotFound => write!(f, "variable not found"), + Self::ZeroDivision => write!(f, "division by zero"), + Self::Void => write!(f, "(void)"), + Self::TypeError => write!(f, "bad type"), + Self::InsufficientSum => write!(f, "probabilities must sum to 1"), + Self::OutOfRange => write!(f, "probabilities must be in the range [0, 1]"), + Self::NotImplemented => write!(f, "operation not supported"), + } + } +} + +mod seal_mod { + /// This is just a sanity check to make sure I never try to ipmlement + /// IntoMath elsewhere in the codebase. + pub trait Sealed {} +} +/// Ugly hack to simplify moving between rust's error-handling types `Option` +/// and `Result` and my `MathResult`. +/// This is `Sealed` for sanity reasons. +pub trait IntoMath: seal_mod::Sealed { + type Target; + + fn into_math(self, fallback: MathResult) -> MathResult; +} + +impl seal_mod::Sealed for Option {} +impl IntoMath for Option { + type Target = T; + + fn into_math(self, fallback: MathResult) -> MathResult { + match self { + Some(data) => MathResult::Ok(data), + None => fallback, + } + } +} + +impl seal_mod::Sealed for Result {} +impl IntoMath for Result { + type Target = T; + + fn into_math(self, fallback: MathResult) -> MathResult { + match self { + Ok(data) => MathResult::Ok(data), + Err(_) => fallback, + } + } +} \ No newline at end of file diff --git a/src/variables.rs b/src/variables.rs new file mode 100644 index 0000000..2458028 --- /dev/null +++ b/src/variables.rs @@ -0,0 +1,33 @@ +use std::collections::HashMap; +use std::hint::unreachable_unchecked; + +use crate::{eval::Evaluation, BST}; +use crate::result::MathResult; + +#[allow(clippy::box_collection)] +#[repr(C)] +#[derive(Debug, Clone, Default)] +pub struct Variables(pub Box>); + +#[no_mangle] +pub extern "C" +fn new_variables() -> Variables { + Variables::default() +} + +#[no_mangle] +pub extern "C" +fn insert_variable(vars: &mut Variables, name: BST, value: MathResult) { + let MathResult::Ok(v) = value else { + eprintln!("{value}"); + return + }; + let BST::Variable(k) = name else {unsafe { unreachable_unchecked() }}; + vars.0.insert(*k, v); +} + +#[no_mangle] +pub extern "C" +fn drop_variables(v: Variables) { + drop(v) +} \ No newline at end of file diff --git a/tokens.lex b/tokens.lex new file mode 100644 index 0000000..6d88202 --- /dev/null +++ b/tokens.lex @@ -0,0 +1,80 @@ +%{ +#include "rust.h" +#define YYSTYPE BST +#include "grammar.tab.h" // PLUS MINUS etc. + +extern YYSTYPE yylval; + +#include +#include +#define YY_INPUT(buf,result,max_size) result = mygetinput(buf, max_size); + +static int mygetinput(char *buf, int size) { + char *line; + + if (feof(yyin)) + return YY_NULL; + + line = readline("> "); + + if(line == NULL) + return YY_NULL; + + size_t len = strnlen(line, size); + if (len > size-2) + return YY_NULL; + + snprintf(buf, size, "%s\n", line); + + add_history(line); + free(line); + + return strlen(buf); +} + +%} + +%option noyywrap + +white [ \t]+ +sign (\+?|-) +digits [0-9]+ +number {digits}(\.{digits}|e{sign}{digits})? +ident [a-zA-Z][a-zA-Z0-9_]* + +%% + +{white} { ; } + +{number} { + yylval = number(yytext, yyleng); + return NUMBER; +} + +{ident} { + yylval = variable(yytext, yyleng); + return IDENT; +} + +"+" return PLUS; +"-" return MINUS; +"*" return MUL; +"/" return DIV; +"=" return ASSIGN; +"(" return LEFT; +")" return RIGHT; +"{" return LCURL; +"}" return RCURL; +"[" return LSQUARE; +"]" return RSQUARE; +"," return COMMA; +":" return COLON; +"<" return LESSER; +">" return GREATER; +"~" return SQRT; + +"\n" return END; + +. return SYNTAXERROR; + +%%