345 lines
11 KiB
Rust
345 lines
11 KiB
Rust
use std::{collections::{HashMap, HashSet}, fmt::Write};
|
|
|
|
use crate::complex::{Complex, cxfn};
|
|
|
|
use super::parser::{Expr, Stmt, UnaryOp, BinaryOp, Defn};
|
|
|
|
/// Errors produced while compiling definitions.
///
/// Every variant except `FmtError` carries the source-level name that
/// triggered the error.
#[derive(Debug, Clone)]
pub enum CompileError<'s> {
    /// Writing generated code to the output buffer failed.
    FmtError,
    /// A name was used with the wrong type, or a function with the wrong
    /// number of arguments.
    TypeError(&'s str),
    /// A name was referenced but is not a parameter, global, or builtin.
    UndefinedVar(&'s str),
    /// A definition tried to reuse a name that already exists.
    Reassignment(&'s str),
}
|
|
|
|
impl <'s> From<std::fmt::Error> for CompileError<'s> {
|
|
fn from(_: std::fmt::Error) -> Self {
|
|
Self::FmtError
|
|
}
|
|
}
|
|
|
|
thread_local! {
|
|
pub static BUILTINS: HashMap<&'static str, (&'static str, Type, Option<cxfn::Function>)> = {
|
|
let mut m: HashMap<&'static str, (&'static str, Type, Option<cxfn::Function>)> = HashMap::new();
|
|
m.insert("i", ("CONST_I", Type::Number, Some(|_| Complex::new(0.0, 1.0))));
|
|
m.insert("e", ("CONST_E", Type::Number, Some(|_| Complex::new(std::f64::consts::E, 0.0))));
|
|
m.insert("tau", ("CONST_TAU",Type::Number, Some(|_| Complex::new(std::f64::consts::TAU, 0.0))));
|
|
m.insert("re", ("c_re", Type::Function(1), Some(cxfn::re)));
|
|
m.insert("im", ("c_im", Type::Function(1), Some(cxfn::im)));
|
|
m.insert("conj", ("c_conj", Type::Function(1), Some(cxfn::conj)));
|
|
m.insert("abs_sq", ("c_abs_sq", Type::Function(1), Some(cxfn::abs_sq)));
|
|
m.insert("abs", ("c_abs", Type::Function(1), Some(cxfn::abs)));
|
|
m.insert("arg", ("c_arg", Type::Function(1), Some(cxfn::arg)));
|
|
m.insert("pos", ("c_pos", Type::Function(1), Some(cxfn::pos)));
|
|
m.insert("neg", ("c_neg", Type::Function(1), Some(cxfn::neg)));
|
|
m.insert("recip", ("c_recip", Type::Function(1), Some(cxfn::recip)));
|
|
m.insert("add", ("c_add", Type::Function(2), Some(cxfn::add)));
|
|
m.insert("sub", ("c_sub", Type::Function(2), Some(cxfn::sub)));
|
|
m.insert("mul", ("c_mul", Type::Function(2), Some(cxfn::mul)));
|
|
m.insert("div", ("c_div", Type::Function(2), Some(cxfn::div)));
|
|
m.insert("recip", ("c_recip", Type::Function(1), Some(cxfn::recip)));
|
|
m.insert("exp", ("c_exp", Type::Function(1), Some(cxfn::exp)));
|
|
m.insert("log", ("c_log", Type::Function(1), Some(cxfn::log)));
|
|
m.insert("sqrt", ("c_sqrt", Type::Function(1), Some(cxfn::sqrt)));
|
|
m.insert("sin", ("c_sin", Type::Function(1), Some(cxfn::sin)));
|
|
m.insert("cos", ("c_cos", Type::Function(1), Some(cxfn::cos)));
|
|
m.insert("tan", ("c_tan", Type::Function(1), Some(cxfn::tan)));
|
|
m.insert("sinh", ("c_sinh", Type::Function(1), Some(cxfn::sinh)));
|
|
m.insert("cosh", ("c_cosh", Type::Function(1), Some(cxfn::cosh)));
|
|
m.insert("tanh", ("c_tanh", Type::Function(1), Some(cxfn::tanh)));
|
|
m.insert("gamma", ("c_gamma", Type::Function(1), None));
|
|
m
|
|
};
|
|
}
|
|
|
|
/// Static type of a name in the source language.
///
/// Derives `Debug`/`PartialEq`/`Eq` so types can be printed in diagnostics
/// and compared directly (e.g. arity checks) instead of only via `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Type {
    /// A single complex value (compiled to a `vec2f`).
    Number,
    /// A function taking the given number of complex arguments.
    Function(u32),
}
|
|
|
|
/// Where a name was resolved: a function parameter, a user-defined global,
/// or a compiler builtin.
#[derive(Copy, Clone)]
enum NameScope {
    Local,
    Global,
    Builtin,
}
|
|
|
|
struct NameInfo<'s> {
|
|
scope: NameScope,
|
|
name: &'s str,
|
|
ty: Type
|
|
}
|
|
|
|
impl <'s> NameInfo<'s> {
|
|
pub fn get_cname(&self) -> String {
|
|
let name = self.name;
|
|
match (self.scope, self.ty) {
|
|
(NameScope::Local, _) => format!("arg_{name}"),
|
|
(NameScope::Global, Type::Number) => format!("VAR_{name}"),
|
|
(NameScope::Global, Type::Function(_)) => format!("func_{name}"),
|
|
(NameScope::Builtin, _) => name.to_owned(),
|
|
}
|
|
}
|
|
}
|
|
|
|
type LocalTable<'s> = HashSet<&'s str>;
|
|
|
|
type CompileResult<'s,T=()> = Result<T, CompileError<'s>>;
|
|
|
|
pub fn compile<'s>(buf: &mut impl Write, defns: &[Defn<'s>]) -> CompileResult<'s> {
|
|
let mut compiler = Compiler::new(buf);
|
|
for defn in defns {
|
|
compiler.compile_defn(defn)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
struct Compiler<'s, 'w, W> where W: Write {
|
|
buf: &'w mut W,
|
|
globals: HashMap<&'s str, Type>,
|
|
}
|
|
|
|
impl <'s, 'w, W: Write> Compiler<'s, 'w, W> {
|
|
fn new(buf: &'w mut W) -> Self {
|
|
Self {
|
|
buf,
|
|
globals: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
//////////////////
|
|
// Statements //
|
|
//////////////////
|
|
|
|
fn compile_defn(&mut self, defn: &Defn<'s>) -> CompileResult<'s> {
|
|
let res = match defn {
|
|
Defn::Const { name, body } => self.defn_const(name, body),
|
|
Defn::Func { name, args, body } => self.defn_func(name, args, body),
|
|
};
|
|
writeln!(self.buf)?;
|
|
res
|
|
}
|
|
|
|
fn defn_const(&mut self, name: &'s str, body: &Expr<'s>) -> CompileResult<'s> {
|
|
if self.name_info(name, None).is_some() {
|
|
return Err(CompileError::Reassignment(name))
|
|
}
|
|
|
|
self.globals.insert(name, Type::Number);
|
|
write!(self.buf, "const VAR_{name} = ")?;
|
|
|
|
let locals = LocalTable::with_capacity(0);
|
|
self.compile_expr(&locals, body)?;
|
|
|
|
write!(self.buf, ";")?;
|
|
Ok(())
|
|
}
|
|
|
|
fn defn_func(&mut self, name: &'s str, args: &[&'s str], body: &(Vec<Stmt<'s>>, Expr<'s>)) -> CompileResult<'s> {
|
|
if self.name_info(name, None).is_some() {
|
|
return Err(CompileError::Reassignment(name))
|
|
}
|
|
|
|
self.globals.insert(name, Type::Function(args.len() as u32));
|
|
write!(self.buf, "fn func_{name}(")?;
|
|
|
|
let mut locals = LocalTable::with_capacity(args.len());
|
|
for arg in args {
|
|
write!(self.buf, "arg_{arg}:vec2f,")?;
|
|
locals.insert(arg);
|
|
}
|
|
write!(self.buf, ")->vec2f{{return ")?;
|
|
self.compile_expr(&locals, &body.1)?;
|
|
write!(self.buf, ";}}")?;
|
|
Ok(())
|
|
}
|
|
|
|
fn stmt_deriv(&mut self, name: &'s str, func: &'s str) -> CompileResult<'s> {
|
|
if self.name_info(name, None).is_some() {
|
|
return Err(CompileError::Reassignment(name))
|
|
}
|
|
|
|
let Some(name_info) = self.name_info(func, None) else {
|
|
return Err(CompileError::UndefinedVar(name))
|
|
};
|
|
let Type::Function(argc) = name_info.ty else {
|
|
return Err(CompileError::TypeError(name))
|
|
};
|
|
|
|
let func_name = name_info.get_cname();
|
|
|
|
self.globals.insert(name, Type::Function(argc));
|
|
|
|
write!(self.buf, "fn func_{name}(")?;
|
|
|
|
for i in 0..argc {
|
|
write!(self.buf, "arg_{i}:vec2f,")?;
|
|
}
|
|
let args: String = (1..argc).map(|i| format!(",arg_{i}")).collect();
|
|
|
|
write!(self.buf, ")->vec2f{{\
|
|
let a = c_mul({func_name}(arg_0 + vec2( D_EPS, 0.0){args}), vec2( 0.25/D_EPS, 0.0));\
|
|
let b = c_mul({func_name}(arg_0 + vec2(-D_EPS, 0.0){args}), vec2(-0.25/D_EPS, 0.0));\
|
|
let c = c_mul({func_name}(arg_0 + vec2(0.0, D_EPS){args}), vec2(0.0, -0.25/D_EPS));\
|
|
let d = c_mul({func_name}(arg_0 + vec2(0.0, -D_EPS){args}), vec2(0.0, 0.25/D_EPS));\
|
|
return a + b + c + d;}}\
|
|
")?;
|
|
Ok(())
|
|
}
|
|
|
|
fn stmt_iter(&mut self, name: &'s str, func: &'s str, count: u32) -> CompileResult<'s> {
|
|
if self.name_info(name, None).is_some() {
|
|
return Err(CompileError::Reassignment(name))
|
|
}
|
|
|
|
let Some(name_info) = self.name_info(func, None) else {
|
|
return Err(CompileError::UndefinedVar(name))
|
|
};
|
|
let Type::Function(argc) = name_info.ty else {
|
|
return Err(CompileError::TypeError(name))
|
|
};
|
|
|
|
let func_name = name_info.get_cname();
|
|
|
|
self.globals.insert(name, Type::Function(argc));
|
|
|
|
write!(self.buf, "fn func_{name}(")?;
|
|
|
|
for i in 0..argc {
|
|
write!(self.buf, "arg_{i}:vec2f,")?;
|
|
}
|
|
let args: String = (1..argc).map(|i| format!(",arg_{i}")).collect();
|
|
|
|
write!(self.buf, ")->vec2f{{\
|
|
var r=arg_0;\
|
|
for(var i=0;i<{count};i++){{\
|
|
r={func_name}(r{args});\
|
|
}}\
|
|
return r;}}\
|
|
")?;
|
|
Ok(())
|
|
}
|
|
|
|
///////////////////
|
|
// Expressions //
|
|
///////////////////
|
|
|
|
fn compile_expr(&mut self, locals: &LocalTable<'s>, expr: &Expr<'s>) -> CompileResult<'s> {
|
|
match expr {
|
|
Expr::Number(z) => self.expr_number(*z),
|
|
Expr::Name(name) => self.expr_var(locals, name),
|
|
Expr::Unary(op, arg) => self.expr_unary(locals, *op, arg),
|
|
Expr::Binary(op, lhs, rhs) => self.expr_binary(locals, *op, lhs, rhs),
|
|
Expr::FnCall(name, args) => self.expr_fncall(locals, name, args),
|
|
}
|
|
}
|
|
|
|
fn expr_number(&mut self, z: Complex) -> CompileResult<'s> {
|
|
write!(self.buf, "vec2f({:?},{:?})", z.re, z.im)?;
|
|
Ok(())
|
|
}
|
|
|
|
fn expr_unary(&mut self, locals: &LocalTable<'s>, op: UnaryOp, arg: &Expr<'s>) -> CompileResult<'s> {
|
|
let strings = unop_strings(op);
|
|
write!(self.buf, "{}", strings[0])?;
|
|
self.compile_expr(locals, arg)?;
|
|
write!(self.buf, "{}", strings[1])?;
|
|
Ok(())
|
|
}
|
|
|
|
fn expr_binary(&mut self, locals: &LocalTable<'s>, op: BinaryOp, lhs: &Expr<'s>, rhs: &Expr<'s>) -> CompileResult<'s> {
|
|
let strings = binop_strings(op);
|
|
write!(self.buf, "{}", strings[0])?;
|
|
self.compile_expr(locals, lhs)?;
|
|
write!(self.buf, "{}", strings[1])?;
|
|
self.compile_expr(locals, rhs)?;
|
|
write!(self.buf, "{}", strings[2])?;
|
|
Ok(())
|
|
}
|
|
|
|
fn expr_var(&mut self, locals: &LocalTable<'s>, name: &'s str) -> CompileResult<'s> {
|
|
let Some(name_info) = self.name_info(name, Some(locals)) else {
|
|
return Err(CompileError::UndefinedVar(name))
|
|
};
|
|
if !matches!(name_info.ty, Type::Number) {
|
|
return Err(CompileError::TypeError(name))
|
|
}
|
|
write!(self.buf, "{}", name_info.get_cname())?;
|
|
Ok(())
|
|
}
|
|
|
|
fn expr_fncall(&mut self, locals: &LocalTable<'s>, name: &'s str, args: &Vec<Expr<'s>>) -> CompileResult<'s> {
|
|
let Some(name_info) = self.name_info(name, Some(locals)) else {
|
|
return Err(CompileError::UndefinedVar(name))
|
|
};
|
|
if !matches!(name_info.ty, Type::Function(n) if n as usize == args.len()) {
|
|
return Err(CompileError::TypeError(name))
|
|
}
|
|
write!(self.buf, "{}", name_info.get_cname())?;
|
|
write!(self.buf, "(")?;
|
|
for arg in args {
|
|
self.compile_expr(locals, arg)?;
|
|
write!(self.buf, ",")?;
|
|
}
|
|
write!(self.buf, ")")?;
|
|
Ok(())
|
|
}
|
|
|
|
/////////////
|
|
// Names //
|
|
/////////////
|
|
|
|
fn name_info(&self, name: &'s str, locals: Option<&LocalTable<'s>>) -> Option<NameInfo> {
|
|
if let Some(locals) = locals {
|
|
if locals.contains(name) {
|
|
return Some(NameInfo { scope: NameScope::Local, name, ty: Type::Number });
|
|
}
|
|
}
|
|
if let Some(ty) = self.globals.get(name).copied() {
|
|
return Some(NameInfo { scope: NameScope::Global, name, ty })
|
|
}
|
|
if let Some((bname, ty, _)) = BUILTINS.with(|m| m.get(name).copied()) {
|
|
return Some(NameInfo { scope: NameScope::Builtin, name: bname, ty })
|
|
}
|
|
None
|
|
}
|
|
|
|
// fn generate_iter(&mut self, argc: u32) -> CompileResult<'s> {
|
|
// if !self.generate.contains_key(&format!("invoke{argc}")) {
|
|
// self.generate.insert(format!("invoke{argc}"), Generate::Iter { argc });
|
|
// self.generate_invoke(argc)?;
|
|
// writeln!(self.buf)?;
|
|
// }
|
|
// write!(self.buf, "fn iter{argc}(func:vec2f,n:vec2f,")?;
|
|
// for i in 0..argc {
|
|
// write!(self.buf, "arg_{i}:vec2f")?;
|
|
// write!(self.buf, ",")?;
|
|
// }
|
|
// write!(self.buf, ")->vec2f{{var result=arg_0;")?;
|
|
// write!(self.buf, "for(var i=0;i<i32(n.x);i++){{result=invoke{argc}(func,result,")?;
|
|
// for i in 1..argc {
|
|
// write!(self.buf, "arg_{i},")?;
|
|
// }
|
|
// write!(self.buf, ");}}return result;}}")?;
|
|
// Ok(())
|
|
// }
|
|
}
|
|
|
|
|
|
const fn unop_strings(op: UnaryOp) -> [&'static str; 2] {
|
|
match op {
|
|
UnaryOp::Pos => ["+(", ")"],
|
|
UnaryOp::Neg => ["-(", ")"],
|
|
UnaryOp::Recip => ["c_recip(", ")"],
|
|
}
|
|
}
|
|
|
|
const fn binop_strings(op: BinaryOp) -> [&'static str; 3] {
|
|
match op {
|
|
BinaryOp::Add => ["(", ")+(", ")"],
|
|
BinaryOp::Sub => ["(", ")-(", ")"],
|
|
BinaryOp::Mul => ["c_mul(", ",", ")"],
|
|
BinaryOp::Div => ["c_div(", ",", ")"],
|
|
BinaryOp::Pow => ["c_pow(", ",", ")"],
|
|
}
|
|
}
|