81 lines
2.4 KiB
Rust
81 lines
2.4 KiB
Rust
use crate::complex::{Complex, cxfn};
|
|
|
|
use super::{parser::{Expr, UnaryOp, BinaryOp, Stmt, Defn}, compiler::{BUILTINS, Type}};
|
|
|
|
fn apply_unary(op: UnaryOp, arg: Complex) -> Complex {
|
|
match op {
|
|
UnaryOp::Pos => cxfn::pos(&[arg]),
|
|
UnaryOp::Neg => cxfn::neg(&[arg]),
|
|
UnaryOp::Recip => cxfn::recip(&[arg]),
|
|
}
|
|
}
|
|
|
|
fn apply_binary(op: BinaryOp, u: Complex, v: Complex) -> Complex {
|
|
match op {
|
|
BinaryOp::Add => cxfn::add(&[u, v]),
|
|
BinaryOp::Sub => cxfn::sub(&[u, v]),
|
|
BinaryOp::Mul => cxfn::mul(&[u, v]),
|
|
BinaryOp::Div => cxfn::div(&[u, v]),
|
|
BinaryOp::Pow => cxfn::pow(&[u, v]),
|
|
}
|
|
}
|
|
|
|
pub fn optimize(defns: Vec<Defn>) -> Vec<Defn> {
|
|
defns.into_iter().map(|s| match s {
|
|
Defn::Const { name, body } => Defn::Const { name, body: optimize_expr(body) },
|
|
Defn::Func { name, args, body } => {
|
|
let stmts = body.0.into_iter()
|
|
.map(optimize_stmt).collect();
|
|
let expr = optimize_expr(body.1);
|
|
Defn::Func { name, args, body: (stmts, expr) }
|
|
}
|
|
_ => s,
|
|
}).collect()
|
|
}
|
|
|
|
fn optimize_stmt(stmt: Stmt) -> Stmt {
|
|
match stmt {
|
|
Stmt::Expr(e) => Stmt::Expr(optimize_expr(e)),
|
|
Stmt::Store(v, e) => Stmt::Store(v, optimize_expr(e)),
|
|
Stmt::Iter(v, min, max, stmts) => Stmt::Iter(v, min, max,
|
|
stmts.into_iter().map(optimize_stmt).collect()
|
|
)
|
|
}
|
|
}
|
|
|
|
fn optimize_expr<'s>(e: Expr<'s>) -> Expr<'s> {
|
|
match e {
|
|
Expr::Unary(op, arg) => {
|
|
match optimize_expr(*arg) {
|
|
Expr::Number(z) => Expr::Number(apply_unary(op, z)),
|
|
e => Expr::Unary(op, Box::new(e)),
|
|
}
|
|
},
|
|
Expr::Binary(op, lhs, rhs) => {
|
|
match (optimize_expr(*lhs), optimize_expr(*rhs)) {
|
|
(Expr::Number(u), Expr::Number(v)) => Expr::Number(apply_binary(op, u, v)),
|
|
(u, v) => Expr::Binary(op, Box::new(u), Box::new(v))
|
|
}
|
|
},
|
|
Expr::FnCall(name, args) => {
|
|
let args: Vec<Expr<'s>> = args.into_iter().map(optimize_expr).collect();
|
|
if let Some((_, Type::Function(argc), Some(func))) = BUILTINS.with(|m| m.get(name).copied()) {
|
|
if argc as usize == args.len()
|
|
&& args.iter().all(|e| matches!(e, Expr::Number(_))) {
|
|
let ns: Vec<Complex> = args.into_iter().map(|a| match a { Expr::Number(n) => n, _ => unreachable!() }).collect();
|
|
return Expr::Number(func(&ns))
|
|
}
|
|
}
|
|
Expr::FnCall(name, args)
|
|
}
|
|
Expr::Name(name) => {
|
|
if let Some((_, Type::Number, Some(func))) = BUILTINS.with(|m| m.get(name).copied()) {
|
|
Expr::Number(func(&[]))
|
|
} else {
|
|
e
|
|
}
|
|
}
|
|
Expr::Number(_) => e,
|
|
}
|
|
}
|