From 002fa36ba1f7361ccabee83a3f1ec313ee33e9e2 Mon Sep 17 00:00:00 2001 From: nick Date: Fri, 16 Aug 2024 10:22:44 -0400 Subject: [PATCH] Initial commit --- .gitignore | 5 + Cargo.lock | 7 ++ Cargo.toml | 9 ++ Makefile | 46 ++++++++ grammar.y | 111 ++++++++++++++++++ main.c | 14 +++ rust.h | 195 +++++++++++++++++++++++++++++++ src/dbg.rs | 15 +++ src/eval/distribution.rs | 77 +++++++++++++ src/eval/mod.rs | 221 +++++++++++++++++++++++++++++++++++ src/lib.rs | 240 +++++++++++++++++++++++++++++++++++++++ src/result.rs | 193 +++++++++++++++++++++++++++++++ src/variables.rs | 33 ++++++ tokens.lex | 80 +++++++++++++ 14 files changed, 1246 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 Makefile create mode 100644 grammar.y create mode 100644 main.c create mode 100644 rust.h create mode 100644 src/dbg.rs create mode 100644 src/eval/distribution.rs create mode 100644 src/eval/mod.rs create mode 100644 src/lib.rs create mode 100644 src/result.rs create mode 100644 src/variables.rs create mode 100644 tokens.lex diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d42b544 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/target +lex.yy.* +grammar.tab.* +main.o +distr \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..2f24718 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "distr" +version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..678aa94 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "distr" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +crate-type = ["staticlib"] \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3f8c5d2 --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +CC := gcc +LIBRARIES := -lfl -lm -lreadline + +calc: main.o grammar.tab.o lex.yy.o target/release/libdistr.a + $(CC) -o distr $^ $(LIBRARIES) + +main.o: main.c grammar.tab.h + $(CC) -c $^ + +lex.yy.o: lex.yy.c + $(CC) -c $^ + +lex.yy.c: tokens.lex grammar.tab.h + flex -o lex.yy.c $^ + +grammar.tab.o: grammar.tab.c + $(CC) -c $^ + +grammar.tab.c: grammar.y rust.h + bison -d $< + +grammar.tab.h: grammar.y rust.h + bison -d $< + +rust.h: rust + +.PHONY: rust +rust: + cargo build -r + # cbindgen -o rust.h -l c + +target/release/libdistr.a: rust + +.PHONY: clean +clean: + rm -f lex.yy.* + rm -f grammar.tab.* + rm -f main.o distr + +.PHONY: sr +sr: grammar.y + bison -d -Wcounterexamples $< + +.PHONY: lint +lint: + cargo clippy --no-deps \ No newline at end of file diff --git a/grammar.y b/grammar.y new file mode 100644 index 0000000..5d1b677 --- /dev/null +++ b/grammar.y @@ -0,0 +1,111 @@ +%{ +#include +#include "rust.h" + +#define YYSTYPE BST + +int yyerror(char*); + +extern Variables global_variables; + +%} + +%token NUMBER IDENT ASSIGN +%token PLUS MINUS MUL DIV +%token LESSER GREATER LEFT RIGHT LCURL RCURL COMMA COLON LSQUARE RSQUARE +%token END SYNTAXERROR + +%left PLUS MINUS +%left MUL DIV +%left LESSER GREATER +%left NEG POS SQRT + +%start Input +%% + +Input: + /* empty input */ + | Input Line +; + +Line: + END + // echo expressions to stdout + | Expression END { + print(bst_eval($1, &global_variables)); + } + // assignment + | IDENT ASSIGN Expression END { + insert_variable(&global_variables, $1, bst_eval($3, &global_variables)); + } + | SYNTAXERROR END { + fprintf(stderr, "syntax error\n"); + } +; + +// expressions are anything that can be evaluated +Expression: + // terminal nodes: a literal number and a variable + NUMBER { $$ = yylval; } + | IDENT { $$ = yylval; } + + // term separators + | Expression PLUS Expression { $$ = add($1, $3); } + | Expression MINUS Expression { $$ = sub($1, $3); } + + // mul/div + | Expression MUL Expression { $$ = mul($1, $3); } + | LEFT Expression RIGHT LEFT Expression RIGHT { $$ = mul($2, $5); } + | Expression DIV Expression { $$ = divide($1, $3); } + + // unary + and - + | MINUS Expression %prec NEG { $$ = neg($2); } + | PLUS Expression %prec POS { $$ = $2; } + | SQRT Expression { $$ = sqrt($2); } + + // parenthesis + | LEFT Expression RIGHT { $$ = $2; } + + // there are a few ways to make a distribution + // 1. standard deviation and mean + // { u, o } + | LCURL Expression COMMA Expression RCURL { + $$ = two_var_distr($2, $4); + } + // 2. association + // [ a: b, c: d, e: f ] + | LSQUARE AssoList RSQUARE { + $$ = $2; + } + + | Expression LESSER Expression { + $$ = less($1, $3); + } + | Expression GREATER Expression { + $$ = more($1, $3); + } +; + +AssoList: + /* empty list */ { + $$ = empty_list(); + } + | NonEmpyList { + $$ = $1; + } +; + +NonEmpyList: + Expression COLON Expression { + $$ = np_pair($1, $3); + } + | NonEmpyList COMMA Expression COLON Expression { + $$ = np_pair_push($1, $3, $5); + } +; + +%% + +int yyerror(char* msg) { + return printf("%s\n", msg); +} diff --git a/main.c b/main.c new file mode 100644 index 0000000..58a4321 --- /dev/null +++ b/main.c @@ -0,0 +1,14 @@ +#include + +#include "grammar.tab.h" +#include "rust.h" + +Variables global_variables; + +int main() { + global_variables = new_variables(); + + yyparse(); + + drop_variables(global_variables); +} \ No newline at end of file diff --git a/rust.h b/rust.h new file mode 100644 index 0000000..97cd109 --- /dev/null +++ b/rust.h @@ -0,0 +1,195 @@ +#include +#include +#include +#include + +typedef enum BinaryOperation { + Add, + Sub, + Mul, + Div, + Less, + More, +} BinaryOperation; + +typedef struct EvDistr EvDistr; + +typedef struct HashMap_String__Evaluation HashMap_String__Evaluation; + +typedef struct String String; + +typedef struct Vec_NPPair Vec_NPPair; + +typedef enum Evaluation_Tag { + Number, + MeanDev, +} Evaluation_Tag; + +typedef struct Evaluation { + Evaluation_Tag tag; + union { + struct { + double number; + }; + struct { + double mean; + double dev; + Vec_NPPair* list; + } mean_dev; + }; +} Evaluation; + +typedef enum MathResult_Evaluation_Tag { + /** + * Ok(T) is the Ok variation. Analagous to + * std::result::Result::Ok. + */ + Ok_Evaluation, + /** + * One of the referenced variables in an expression has not been + * assigned yet. + */ + VariableNotFound_Evaluation, + /** + * This is no an error per-se, but rather the result of return with + * no expression to return. This is early-returned to simulate the + * fact that any expression involving null would also return null. + */ + Void_Evaluation, + /** + * The denominator of a division was 0. + */ + ZeroDivision_Evaluation, + /** + * Bad type for operands + */ + TypeError_Evaluation, + /** + * a probability was not in the range [0, 1] + */ + OutOfRange_Evaluation, + /** + * probabilities don't sum to 1 + */ + InsufficientSum_Evaluation, + /** + * operation not supported + */ + NotImplemented_Evaluation, +} MathResult_Evaluation_Tag; + +typedef struct MathResult_Evaluation { + MathResult_Evaluation_Tag tag; + union { + struct { + struct Evaluation ok; + }; + }; +} MathResult_Evaluation; + +/** + * This is all the things a bst node can be. + * a lot of them have 2 levels of heap indirection, + * because for ffi reasons, anything on the heap + * has to be representable as a single pointer. + * only box can do that, and only when given a + * Sized type, so `Box`, not `Box` or + * just `String`. + */ +typedef enum BST_Tag { + Literal, + Variable, + Negate, + Root, + Binary, + ListDistr, + TwoVar, +} BST_Tag; + +typedef struct Negate_Body { + struct BST *operand; +} Negate_Body; + +typedef struct Binary_Body { + enum BinaryOperation op; + struct BST *left; + struct BST *right; +} Binary_Body; + +typedef struct TwoVar_Body { + struct BST *_0; + struct BST *_1; +} TwoVar_Body; + +typedef struct BST { + BST_Tag tag; + union { + struct { + double literal; + }; + struct { + struct String *variable; + }; + Negate_Body negate; + struct { + struct BST *root; + }; + Binary_Body binary; + struct { + struct Vec_NPPair *list_distr; + }; + TwoVar_Body two_var; + }; +} BST; + +typedef struct Variables { + struct HashMap_String__Evaluation *_0; +} Variables; + +struct MathResult_Evaluation bst_eval(struct BST root, const struct Variables *variables); + +struct BST add(struct BST left, struct BST right); + +struct BST sub(struct BST left, struct BST right); + +struct BST mul(struct BST left, struct BST right); + +struct BST divide(struct BST left, struct BST right); + +struct BST less(struct BST left, struct BST right); + +struct BST more(struct BST left, struct BST right); + +struct BST neg(struct BST operand); + +struct BST sqrt(struct BST operand); + +void print(struct MathResult_Evaluation p); + +/** + * # Safety + * you must be able to form a slice from `name` and `len` + */ +struct BST number(const uint8_t *name, uintptr_t len); + +/** + * # Safety + * you must be able to form a slice from `name` and `len` + */ +struct BST variable(const uint8_t *name, uintptr_t len); + +struct BST two_var_distr(struct BST mu, struct BST sig); + +struct BST empty_list(void); + +struct BST np_pair(struct BST left, struct BST right); + +struct BST np_pair_push(struct BST list, struct BST left, struct BST right); + +struct Variables new_variables(void); + +void insert_variable(struct Variables *vars, struct BST name, struct MathResult_Evaluation value); + +void drop_variables(struct Variables v); + +void debug(const struct MathResult_Evaluation *v); diff --git a/src/dbg.rs b/src/dbg.rs new file mode 100644 index 0000000..652bd0c --- /dev/null +++ b/src/dbg.rs @@ -0,0 +1,15 @@ +use crate::{result::MathResult, eval::Evaluation}; + +#[macro_export] +macro_rules! mega_dbg { + ($target:expr, $target_type:tt) => { + eprintln!("[{}:{}] \x1b[31m{:#?}\x1b[0m @ {:?}", file!(), line!(), &$target, &$target as *const $target_type) + }; +} + +#[allow(unused_parens)] +#[no_mangle] +pub extern "C" +fn debug(v: &MathResult) { + mega_dbg!(v, (&MathResult)); +} \ No newline at end of file diff --git a/src/eval/distribution.rs b/src/eval/distribution.rs new file mode 100644 index 0000000..bf5b0b3 --- /dev/null +++ b/src/eval/distribution.rs @@ -0,0 +1,77 @@ +use std::collections::HashMap; + +#[allow(clippy::box_collection)] +#[repr(C)] +#[derive(Debug, Clone)] +pub struct EvDistr { + pub mean: f64, + pub dev: f64, + pub list: Option>>, +} +impl EvDistr { + pub fn modify(self, n: f64, mut f: MeanFn, d: DevFn) -> Self + where + MeanFn: FnMut(f64, f64) -> f64, + DevFn: FnOnce(f64, f64) -> f64, + { + let list = if let Some(mut boxed_list) = self.list { + // using map in place to avoid dereferencing then boxing a value + // repeatedly. it's already inefficient to use the heap + // but necessary for ffi reasons. the least we can do is try. + map_in_place(boxed_list.as_mut(), |&(x, p)| (f(x, n), p)); + Some(boxed_list) + } else { + None + }; + + Self { + mean: f(self.mean, n), + dev: d(self.dev, n), + list, + } + } + + pub fn combine(self, rhs: Self, mut f: MeanFn) -> Self + where + MeanFn: FnMut(f64, f64) -> f64, + { + let mean = f(self.mean, rhs.mean); + let dev = (self.dev*self.dev + rhs.dev*rhs.dev).powf(0.5); + let list = match (self.list, rhs.list) { + (None, None) => None, + (Some(l), None) | (None, Some(l)) => Some(l), + (Some(left), Some(right)) => Some(Box::new(combine_lists(*left, *right, f))), + }; + + Self { + mean, dev, list + } + } +} + +#[inline] +fn map_in_place(l: &mut [T], mut f: F) +where + F: FnMut(&T) -> T +{ + l.iter_mut().for_each(|refm| *refm = f(refm)) +} + +fn combine_lists(left: Vec<(f64, f64)>, right: Vec<(f64, f64)>, mut f: F) -> Vec<(f64, f64)> +where + F: FnMut(f64, f64) -> f64, +{ + let mut map = HashMap::new(); + + for (n1, p1) in left { + for &(n2, p2) in &right { + let key: u64 = f(n1, n2).to_bits(); + let value = p1*p2; + *map.entry(key).or_insert(0.0) += value; + } + } + + map.into_iter() + .map(|(n, p)| (f64::from_bits(n), p)) + .collect() +} \ No newline at end of file diff --git a/src/eval/mod.rs b/src/eval/mod.rs new file mode 100644 index 0000000..ab434b1 --- /dev/null +++ b/src/eval/mod.rs @@ -0,0 +1,221 @@ +use std::fmt::Display; +use std::ops::{Add, Sub, Mul, Div}; + +pub mod distribution; + +pub use distribution::EvDistr; + +use crate::result::MathResult; +use MathResult::*; + +macro_rules! full_impl { + ($left:expr, $right:expr, $mean:expr) => {{ + use Evaluation as E; + match ($left, $right) { + (E::Number(n1), E::Number(n2)) => Ok(E::Number($mean(n1, n2))), + (E::Number(n),E::MeanDev(d)) => { + Ok(E::MeanDev(d.modify(n, $mean, |left, _right| left))) + }, + (E::MeanDev(d),E::Number(n)) => { + Ok(E::MeanDev(d.modify(n, $mean, |left, _right| left))) + }, + (E::MeanDev(d1), E::MeanDev(d2)) => Ok(E::MeanDev(d1.combine(d2, $mean))) + } + }}; +} + +macro_rules! partial_impl { + ($left:expr, $right:expr, $function:expr) => {{ + use Evaluation as E; + match ($left, $right) { + (Self::Number(n1),Self::Number(n2)) => Ok(Self::Number($function(n1, n2))), + (Self::Number(n),Self::MeanDev(d)) => { + Ok(E::MeanDev(d.modify(n, $function, $function))) + }, + (Self::MeanDev(d),Self::Number(n)) => { + Ok(E::MeanDev(d.modify(n, $function, $function))) + }, + (Self::MeanDev(..),Self::MeanDev(..)) => MathResult::NotImplemented, + } + }}; +} + +#[repr(C)] +#[derive(Debug, Clone)] +pub enum Evaluation { + Number(f64), + MeanDev(EvDistr), +} +impl Evaluation { + pub fn two_var(mean: f64, dev: f64) -> Self { + Self::MeanDev(EvDistr { mean, dev, list: None }) + } + + pub fn from_list(l: Vec<(Evaluation, Evaluation)>) -> MathResult { + let mut mean = 0.0; + let list = l.into_iter() + .map( + |pair| + match pair { + (Evaluation::Number(n), Evaluation::Number(p)) => { + if (0.0..=1.0).contains(&p) { + Ok((n, p)) + } else { + MathResult::OutOfRange + } + }, + _ => MathResult::TypeError, + } + ).collect::< MathResult< Vec<(f64, f64)> > >()?; + + let mut p_sum = 0.0; + for &(n, p) in &list { + mean += n*p; + p_sum += p; + } + if p_sum != 1.0 { + return MathResult::InsufficientSum; + } + + // sqrt(sum( p*(x_i - u)^2 )) + let dev = list.iter() + .map(|&(n, p)| (n - mean).powi(2) * p) + .sum::() + .powf(0.5); + + Ok(Evaluation::MeanDev(EvDistr { mean, dev, list: Some(Box::new(list)) })) + } + + pub fn add(self, rhs: Self) -> MathResult { + full_impl!(self, rhs, f64::add) + } + + pub fn sub(self, rhs: Self) -> MathResult { + full_impl!(self, rhs, f64::sub) + } + + pub fn mul(self, rhs: Self) -> MathResult { + partial_impl!(self, rhs, f64::mul) + } + + pub fn div(self, rhs: Self) -> MathResult { + partial_impl!(self, rhs, f64::div) + } + + // always returns the Number variant + pub fn p_less(self, rhs: Self) -> MathResult { + // p_unchecked is made by the imperfect approximation + // of the integral. it is not guaranteed to be in the + // range of [0, 1] + let p_unchecked = match (self, rhs) { + (Self::Number(n1), Self::Number(n2)) => { + if n1 < n2 { 1.0 } + else { 0.0 } + }, + (Self::Number(n), Self::MeanDev(d)) => { + more(d, n) + }, + (Self::MeanDev(d), Self::Number(n)) => { + less(d, n) + }, + (this @ Self::MeanDev(..), other @ Self::MeanDev(..)) => { + // P(self < rhs) = P(self-rhs < 0) + return this.sub(other)?.p_less(Evaluation::Number(0.0)); + } + }; + + // p clamps p_unchecked to [0, 1] + // this happens when the approximation gets too inaccurate, and + // reads off massive floating point values. + // for example {3, 5} < 100 can return numbers with dozens of digits + let p = if p_unchecked > 1.0 { + 1.0 + } else if p_unchecked < 0.0 { + 0.0 + } else { + p_unchecked + }; + + Ok(Self::Number(p)) + } + + // always returns the Number variant + pub fn p_more(self, rhs: Self) -> MathResult { + rhs.p_less(self) + } +} +impl Display for Evaluation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Number(n) => write!(f, "{n}"), + Self::MeanDev(d) => write!(f, "μ = {}, σ = {:.3}", d.mean, d.dev), + } + } +} + +fn less(d: EvDistr, n: f64) -> f64 { + match d.list { + None => less_continuous(d.mean, d.dev, n), + Some(l) => { + l.into_iter() + .filter(|pair| pair.0 < n) + .map(|pair| pair.1) + .sum() + } + } +} + +fn more(d: EvDistr, n: f64) -> f64 { + match d.list { + None => 1.0 - less_continuous(d.mean, d.dev, n), + Some(l) => { + l.into_iter() + .filter(|pair| pair.0 > n) + .map(|pair| pair.1) + .sum() + } + } +} + +/// Return the probability that a distribution with mean +/// `mean` and standard deviation `stddev` is less than `n` +fn less_continuous(mean: f64, stddev: f64, x: f64) -> f64 { + /* + The equation for the normal distribution is + + exp( -(x-u)^2 / 2o^2 ) / o / sqrt(2 * pi) + + This can't actually be integrated. + + However! You can approximate e^x with a taylor series + + so when you divide again by the o * sqrt(2pi) term, you can approximate + integrals! + */ + + const TERMS: usize = 50; + const FRAC_1_SQRT_2_PI: f64 = std::f64::consts::FRAC_1_SQRT_PI * std::f64::consts::FRAC_1_SQRT_2; + let coefficient = FRAC_1_SQRT_2_PI/stddev; + + let mut result = 0.0; + + // diff is the difference, diff_squared is the square of the difference + // done by expanding (x-mean)^2 + let diff = x-mean; + let diff_squared = x*x - 2.0*mean*x + mean*mean; + let twice_variance = 2.0 * stddev * stddev; + let term_coeff = diff_squared/twice_variance; + + for n in 0..=TERMS { + // this loop ends adding a ton of very small numbers. minimizing + // floating point errors is a must + let mut term = diff / (2*n + 1) as f64; + for fact in 1..=n { term *= term_coeff/fact as f64 } + if n%2 == 1 { term *= -1.0 } + result += term; + } + + // the improper integral subtracts the value of negative infinity, + // which is -0.5 by definition + coefficient * result + 0.5 +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..dbaa1f5 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,240 @@ +#![feature(try_trait_v2, more_float_constants)] + +mod result; +mod eval; +mod variables; +mod dbg; + +use crate::result::{MathResult, IntoMath}; +use MathResult::*; + +/// tuples are frustrating over ffi. +/// use a struct instead +#[repr(C)] +#[derive(Debug, Clone)] +pub struct NPPair { + n: BST, + p: BST, +} + +use eval::Evaluation; +use variables::Variables; + +/// This is all the things a bst node can be. +/// a lot of them have 2 levels of heap indirection, +/// because for ffi reasons, anything on the heap +/// has to be representable as a single pointer. +/// only box can do that, and only when given a +/// Sized type, so `Box`, not `Box` or +/// just `String`. +#[repr(C)] +#[derive(Debug, Clone)] +pub enum BST { + // valid/passthrough nodes + Literal(f64), + Variable(Box), + + // operations + Negate { + operand: Box, + }, + Root(Box), + + Binary { + op: BinaryOperation, + left: Box, + right: Box, + }, + + // values passed from bison that must + // be converted + ListDistr(Box>), + TwoVar(Box, Box), +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub enum BinaryOperation { + Add, + Sub, + Mul, + Div, + Less, + More, +} +impl BinaryOperation { + fn eval(self, left: Evaluation, right: Evaluation) -> MathResult { + match self { + Self::Add => left.add(right), + Self::Sub => left.sub(right), + Self::Mul => left.mul(right), + Self::Div => left.div(right), + Self::Less => left.p_less(right), + Self::More => left.p_more(right), + } + } +} + +#[no_mangle] +pub extern "C" +fn bst_eval(root: BST, variables: &Variables) -> MathResult { + match root { + BST::Literal(n) => Ok(Evaluation::Number(n)), + + BST::Binary { op, left, right } => { + let left = bst_eval(*left, variables)?; + let right = bst_eval(*right, variables)?; + op.eval(left, right) + }, + BST::Negate { operand } => { + let operand = bst_eval(*operand, variables)?; + let operation = BinaryOperation::Mul; + let factor = Evaluation::Number(-1.0); + operation.eval(factor, operand) + }, + BST::Root(operand) => { + let intermediate = bst_eval(*operand, variables)?; + match intermediate { + Evaluation::Number(n) => MathResult::Ok(Evaluation::Number(n.sqrt())), + Evaluation::MeanDev(..) => MathResult::NotImplemented, + } + } + + BST::Variable(name) => variables.0 + .get(Box::as_ref(&name)) + .cloned() + .into_math(MathResult::VariableNotFound), + + BST::TwoVar(mean, dev) => { + let Evaluation::Number(u) = bst_eval(*mean, variables)? + else { return MathResult::TypeError }; + let Evaluation::Number(o) = bst_eval(*dev, variables)? + else { return MathResult::TypeError }; + Ok(Evaluation::two_var(u, o)) + }, + BST::ListDistr(list) => { + let list = (*list).into_iter() + .map( + |pair| + Ok((bst_eval(pair.n, variables)?, bst_eval(pair.p, variables)?)) + ) + .collect::< MathResult> >()?; + Evaluation::from_list(list) + }, + } +} + +macro_rules! binary { + ($variant:ident, $left:ident, $right:ident) => { + BST::Binary { + op: BinaryOperation::$variant, + left: Box::new($left), + right: Box::new($right), + } + }; +} + +#[no_mangle] +pub extern "C" +fn add(left: BST, right: BST) -> BST { + binary!(Add, left, right) +} + +#[no_mangle] +pub extern "C" +fn sub(left: BST, right: BST) -> BST { + binary!(Sub, left, right) +} + +#[no_mangle] +pub extern "C" +fn mul(left: BST, right: BST) -> BST { + binary!(Mul, left, right) +} + +#[no_mangle] +pub extern "C" +fn divide(left: BST, right: BST) -> BST { + binary!(Div, left, right) +} + +#[no_mangle] +pub extern "C" +fn less(left: BST, right: BST) -> BST { + binary!(Less, left, right) +} + +#[no_mangle] +pub extern "C" +fn more(left: BST, right: BST) -> BST { + binary!(More, left, right) +} + +#[no_mangle] +pub extern "C" +fn neg(operand: BST) -> BST { + BST::Negate { operand: Box::new(operand) } +} + +#[no_mangle] +pub extern "C" +fn sqrt(operand: BST) -> BST { + BST::Root(Box::new(operand)) +} + +#[no_mangle] +pub extern "C" +fn print(p: MathResult) { + match p { + Ok(n) => println!(">>> {n}"), + other => println!("{other}"), + } +} + +/// # Safety +/// you must be able to form a slice from `name` and `len` +#[no_mangle] +pub unsafe extern "C" +fn number(name: *const u8, len: usize) -> BST { + let slice = unsafe { std::slice::from_raw_parts(name, len) }; + let str = unsafe { std::str::from_utf8_unchecked(slice) }; + BST::Literal(unsafe { str.parse().unwrap_unchecked() }) +} + +/// # Safety +/// you must be able to form a slice from `name` and `len` +#[no_mangle] +pub unsafe extern "C" +fn variable(name: *const u8, len: usize) -> BST { + let slice = unsafe { std::slice::from_raw_parts(name, len) }; + let str = unsafe { std::str::from_utf8_unchecked(slice) }; + let string = str.to_owned(); + BST::Variable(Box::new(string)) +} + +#[no_mangle] +pub extern "C" +fn two_var_distr(mu: BST, sig: BST) -> BST { + BST::TwoVar(Box::new(mu), Box::new(sig)) +} + +#[no_mangle] +pub extern "C" +fn empty_list() -> BST { + BST::ListDistr(Box::default()) +} + +#[no_mangle] +pub extern "C" +fn np_pair(left: BST, right: BST) -> BST { + BST::ListDistr(Box::new(vec![ NPPair { n: left, p: right } ])) +} + +#[no_mangle] +pub extern "C" +fn np_pair_push(list: BST, left: BST, right: BST) -> BST { + let BST::ListDistr(mut l) = list else { unsafe { std::hint::unreachable_unchecked() } }; + l.push( NPPair { n: left, p: right } ); + BST::ListDistr(l) +} + diff --git a/src/result.rs b/src/result.rs new file mode 100644 index 0000000..db20db0 --- /dev/null +++ b/src/result.rs @@ -0,0 +1,193 @@ +use std::convert::Infallible; +use std::hint::unreachable_unchecked; +/// Define the FFI-safe MathResult type +/// this would be a regular type alias, but +/// Result isn't #[repr(C)], and +/// therefore not FFI-safe + +use std::ops::{FromResidual, Try, ControlFlow}; +use std::fmt::Display; + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub enum MathResult { + /// Ok(T) is the Ok variation. Analagous to + /// std::result::Result::Ok. + Ok(T), + /// One of the referenced variables in an expression has not been + /// assigned yet. + VariableNotFound, + /// This is no an error per-se, but rather the result of return with + /// no expression to return. This is early-returned to simulate the + /// fact that any expression involving null would also return null. + Void, + /// The denominator of a division was 0. + ZeroDivision, + /// Bad type for operands + TypeError, + /// a probability was not in the range [0, 1] + OutOfRange, + /// probabilities don't sum to 1 + InsufficientSum, + /// operation not supported + NotImplemented, +} + +impl FromResidual> for MathResult { + /// This implementation of residual is a bit weird. The residual + /// is of type `MathResult` because the error variants are effectively + /// untyped, in that they do not have any generic or concrete types. + /// # Undefined behaviour + /// This function assumes that it will only be called by Rust to + /// reconstitute a `MathResult` from one of its error variants, and + /// as such it is **Undefined Behaviour** to call this on a `MathResult::Ok` + fn from_residual(this: MathResult) -> Self { + match this { + MathResult::Ok(_) => unsafe { unreachable_unchecked() }, + MathResult::VariableNotFound => Self::VariableNotFound, + MathResult::Void => Self::Void, + MathResult::ZeroDivision => Self::ZeroDivision, + MathResult::TypeError => Self::TypeError, + MathResult::InsufficientSum => Self::InsufficientSum, + MathResult::OutOfRange => Self::OutOfRange, + MathResult::NotImplemented => Self::NotImplemented, + } + } +} + +impl Try for MathResult { + type Output = T; + type Residual = Self; + + fn from_output(data: T) -> Self { + Self::Ok(data) + } + + fn branch(self) -> ControlFlow { + // this is a weird match, but essentialy the Ok data can + // get `Continue`d, while others get `Break`ed. + match self { + MathResult::Ok(data) => ControlFlow::Continue(data), + other => ControlFlow::Break(other), + } + } +} + +/// For a given Result>, convert into a +/// MathResult by turning Ok(t) into MathResult::Ok(t) and Err(n) +/// into Self::n. E in the Result must be of type MathResult +/// to avoid the possibility of Err(Ok(_)). +impl From< Result< OkType, MathResult > > for MathResult { + fn from(value: Result< OkType, MathResult >) -> MathResult { + match value { + Ok(data) => Self::Ok(data), + // this is unreachable because this is Err(Ok(Infalliable)) + Err(MathResult::Ok(_)) => unsafe { unreachable_unchecked() }, + Err(MathResult::VariableNotFound) => Self::VariableNotFound, + Err(MathResult::ZeroDivision) => Self::ZeroDivision, + Err(MathResult::Void) => Self::Void, + Err(MathResult::TypeError) => Self::TypeError, + Err(MathResult::OutOfRange) => Self::OutOfRange, + Err(MathResult::InsufficientSum) => Self::InsufficientSum, + Err(MathResult::NotImplemented) => Self::NotImplemented, + } + } +} + +/// For a given MathResult, turn it into a +/// Result by mapping Self::Ok(data) to +/// Ok(data), and any error value to Err(value) +impl From< MathResult > for Result> { + fn from(val: MathResult) -> Self { + match val { + MathResult::Ok(data) => Ok(data), + MathResult::VariableNotFound => Err(MathResult::VariableNotFound), + MathResult::ZeroDivision => Err(MathResult::ZeroDivision), + MathResult::Void => Err(MathResult::Void), + MathResult::TypeError => Err(MathResult::TypeError), + MathResult::OutOfRange => Err(MathResult::OutOfRange) , + MathResult::InsufficientSum => Err(MathResult::InsufficientSum), + MathResult::NotImplemented => Err(MathResult::NotImplemented), + } + } +} + +impl +FromIterator< MathResult > for MathResult +where + Collection: FromIterator +{ + fn from_iter(iter: I) -> MathResult + where + I: IntoIterator> + { + let result_iterator = iter.into_iter().map( + // this hellish generic selects the Into implementation that converts + // a MathResult into a Result>, so that I + // can use Result's implementation of FromIterator later. + as Into< Result> >>::into + ); + // Use Result's implementation of FromIterator to build up the + // collection as a Result. + let collected_result: Result> = result_iterator.collect(); + // convert the result into a MathResult. The Err variant is + // of type MathResult so this is a flawless conversion. + collected_result.into() + } +} + + +impl Display for MathResult +where T: Display { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // pretty self-explanatory. write the right message for each variant + match self { + Self::Ok(data) => write!(f, "{data}"), + Self::VariableNotFound => write!(f, "variable not found"), + Self::ZeroDivision => write!(f, "division by zero"), + Self::Void => write!(f, "(void)"), + Self::TypeError => write!(f, "bad type"), + Self::InsufficientSum => write!(f, "probabilities must sum to 1"), + Self::OutOfRange => write!(f, "probabilities must be in the range [0, 1]"), + Self::NotImplemented => write!(f, "operation not supported"), + } + } +} + +mod seal_mod { + /// This is just a sanity check to make sure I never try to ipmlement + /// IntoMath elsewhere in the codebase. + pub trait Sealed {} +} +/// Ugly hack to simplify moving between rust's error-handling types `Option` +/// and `Result` and my `MathResult`. +/// This is `Sealed` for sanity reasons. +pub trait IntoMath: seal_mod::Sealed { + type Target; + + fn into_math(self, fallback: MathResult) -> MathResult; +} + +impl seal_mod::Sealed for Option {} +impl IntoMath for Option { + type Target = T; + + fn into_math(self, fallback: MathResult) -> MathResult { + match self { + Some(data) => MathResult::Ok(data), + None => fallback, + } + } +} + +impl seal_mod::Sealed for Result {} +impl IntoMath for Result { + type Target = T; + + fn into_math(self, fallback: MathResult) -> MathResult { + match self { + Ok(data) => MathResult::Ok(data), + Err(_) => fallback, + } + } +} \ No newline at end of file diff --git a/src/variables.rs b/src/variables.rs new file mode 100644 index 0000000..2458028 --- /dev/null +++ b/src/variables.rs @@ -0,0 +1,33 @@ +use std::collections::HashMap; +use std::hint::unreachable_unchecked; + +use crate::{eval::Evaluation, BST}; +use crate::result::MathResult; + +#[allow(clippy::box_collection)] +#[repr(C)] +#[derive(Debug, Clone, Default)] +pub struct Variables(pub Box>); + +#[no_mangle] +pub extern "C" +fn new_variables() -> Variables { + Variables::default() +} + +#[no_mangle] +pub extern "C" +fn insert_variable(vars: &mut Variables, name: BST, value: MathResult) { + let MathResult::Ok(v) = value else { + eprintln!("{value}"); + return + }; + let BST::Variable(k) = name else {unsafe { unreachable_unchecked() }}; + vars.0.insert(*k, v); +} + +#[no_mangle] +pub extern "C" +fn drop_variables(v: Variables) { + drop(v) +} \ No newline at end of file diff --git a/tokens.lex b/tokens.lex new file mode 100644 index 0000000..6d88202 --- /dev/null +++ b/tokens.lex @@ -0,0 +1,80 @@ +%{ +#include "rust.h" +#define YYSTYPE BST +#include "grammar.tab.h" // PLUS MINUS etc. + +extern YYSTYPE yylval; + +#include +#include +#define YY_INPUT(buf,result,max_size) result = mygetinput(buf, max_size); + +static int mygetinput(char *buf, int size) { + char *line; + + if (feof(yyin)) + return YY_NULL; + + line = readline("> "); + + if(line == NULL) + return YY_NULL; + + size_t len = strnlen(line, size); + if (len > size-2) + return YY_NULL; + + snprintf(buf, size, "%s\n", line); + + add_history(line); + free(line); + + return strlen(buf); +} + +%} + +%option noyywrap + +white [ \t]+ +sign (\+?|-) +digits [0-9]+ +number {digits}(\.{digits}|e{sign}{digits})? +ident [a-zA-Z][a-zA-Z0-9_]* + +%% + +{white} { ; } + +{number} { + yylval = number(yytext, yyleng); + return NUMBER; +} + +{ident} { + yylval = variable(yytext, yyleng); + return IDENT; +} + +"+" return PLUS; +"-" return MINUS; +"*" return MUL; +"/" return DIV; +"=" return ASSIGN; +"(" return LEFT; +")" return RIGHT; +"{" return LCURL; +"}" return RCURL; +"[" return LSQUARE; +"]" return RSQUARE; +"," return COMMA; +":" return COLON; +"<" return LESSER; +">" return GREATER; +"~" return SQRT; + +"\n" return END; + +. return SYNTAXERROR; + +%%