use std::{ cmp::Ordering, collections::HashMap, rc::Rc, sync::{atomic::AtomicBool, Arc}, }; use crate::{ chunk::Instruction, exception::{throw, Exception, Result}, lstring::LStr, parser::ast::{BinaryOp, UnaryOp}, symbol::{Symbol, SYM_CALL_STACK_OVERFLOW, SYM_INTERRUPTED, SYM_NAME_ERROR, SYM_TYPE_ERROR}, value::{ function::{FuncAttrs, Function, NativeFunc}, Value, }, }; struct TryFrame { idx: usize, stack_len: usize, } struct CallFrame { func: Rc, locals: Vec, try_frames: Vec, ip: usize, root: bool, } impl CallFrame { fn new(func: Rc, locals: Vec) -> Self { Self { func, locals, try_frames: Vec::new(), ip: 0, root: false, } } } pub struct Vm { stack: Vec, call_stack: Vec, stack_max: usize, globals: HashMap, interrupt: Arc, } pub fn binary_op(o: BinaryOp, a: Value, b: Value) -> Result { match o { BinaryOp::Add => a + b, BinaryOp::Sub => a - b, BinaryOp::Mul => a * b, BinaryOp::Div => a / b, BinaryOp::Mod => a.modulo(b), BinaryOp::IntDiv => a.int_div(b), BinaryOp::Pow => a.pow(b), BinaryOp::Shl => a << b, BinaryOp::Shr => a >> b, BinaryOp::BitAnd => a & b, BinaryOp::BitXor => a ^ b, BinaryOp::BitOr => a | b, BinaryOp::Eq => Ok(Value::Bool(a == b)), BinaryOp::Ne => Ok(Value::Bool(a != b)), BinaryOp::Gt => a.val_cmp(&b).map(|o| Value::Bool(o == Ordering::Greater)), BinaryOp::Ge => a.val_cmp(&b).map(|o| Value::Bool(o != Ordering::Less)), BinaryOp::Lt => a.val_cmp(&b).map(|o| Value::Bool(o == Ordering::Less)), BinaryOp::Le => a.val_cmp(&b).map(|o| Value::Bool(o != Ordering::Greater)), BinaryOp::Range => a.range(&b, false), BinaryOp::RangeIncl => a.range(&b, true), BinaryOp::Concat => a.concat(&b), BinaryOp::Append => a.append(b), } } pub fn unary_op(o: UnaryOp, a: Value) -> Result { match o { UnaryOp::Neg => -a, UnaryOp::Not => Ok(Value::Bool(!a.truthy())), UnaryOp::RangeEndless => a.range_endless(), } } enum CallOutcome { Call(Vec), Partial(Value), } fn get_call_outcome(args: Vec) -> Result { let f = &args[0]; let Some(attrs) = f.func_attrs() else { throw!(*SYM_TYPE_ERROR, "cannot call non-function {f:#}") }; let argc = args.len() - 1; match argc.cmp(&attrs.arity) { Ordering::Equal => Ok(CallOutcome::Call(args)), Ordering::Greater => throw!(*SYM_TYPE_ERROR, "too many arguments for function"), Ordering::Less => { let remaining = attrs.arity - argc; let f = f.clone(); let nf = move |vm: &mut Vm, inner_args: Vec| { let mut ia = inner_args.into_iter(); ia.next(); let args: Vec = args.clone().into_iter().chain(ia).collect(); vm.call_value(f.clone(), args) }; let nf = NativeFunc { attrs: FuncAttrs { arity: remaining, name: None, }, func: Box::new(nf), }; Ok(CallOutcome::Partial(nf.into())) } } } impl Vm { pub fn new(stack_max: usize) -> Self { Self { stack: Vec::with_capacity(16), call_stack: Vec::with_capacity(16), globals: HashMap::with_capacity(16), stack_max, interrupt: Arc::new(AtomicBool::new(false)), } } pub fn get_interrupt(&self) -> Arc { self.interrupt.clone() } pub fn set_global(&mut self, name: Symbol, val: Value) { self.globals.insert(name, val); } pub fn set_global_name<'a, S>(&mut self, name: S, val: Value) where S: Into<&'a LStr>, { self.globals.insert(Symbol::get(name.into()), val); } pub fn get_global(&self, name: Symbol) -> Option<&Value> { self.globals.get(&name) } pub fn globals(&self) -> &HashMap { &self.globals } pub fn call_value(&mut self, value: Value, args: Vec) -> Result { self.check_interrupt()?; match get_call_outcome(args)? { CallOutcome::Partial(v) => Ok(v), CallOutcome::Call(args) => match value { Value::Function(f) => self.run_function(f, args), Value::NativeFunc(f) => (f.func)(self, args), _ => unreachable!("already verified by calling get_call_type"), }, } } pub fn run_function(&mut self, func: Rc, args: Vec) -> Result { if func.attrs.arity + 1 != args.len() { throw!(*SYM_TYPE_ERROR, "function call with wrong argument count"); } let init_stack_len = self.stack.len(); let mut frame = CallFrame::new(func, args); frame.root = true; loop { let instr = frame.func.chunk.instrs[frame.ip]; frame.ip += 1; match self.run_instr(&mut frame, instr) { Ok(None) => (), Ok(Some(v)) => { self.stack.truncate(init_stack_len); return Ok(v) } Err(e) => { if let Err(e) = self.handle_exception(&mut frame, e) { self.stack.truncate(init_stack_len); return Err(e) } } } } } fn handle_exception(&mut self, frame: &mut CallFrame, exc: Exception) -> Result<()> { loop { while let Some(try_frame) = frame.try_frames.pop() { let table = &frame.func.chunk.try_tables[try_frame.idx]; for catch in &table.catches { if catch.types.is_none() || catch.types.as_ref().unwrap().contains(&exc.ty) { frame.ip = catch.addr; frame.locals.truncate(table.local_count); self.stack.truncate(try_frame.stack_len); self.stack.push(Value::Table(exc.to_table())); return Ok(()) } } } if frame.root { return Err(exc) } *frame = self.call_stack.pop().expect("no root frame"); } } #[inline] fn push(&mut self, v: Value) { self.stack.push(v); } #[inline] fn pop(&mut self) -> Value { self.stack.pop().expect("temporary stack underflow") } #[inline] fn pop_n(&mut self, n: usize) -> Vec { let res = self.stack.split_off(self.stack.len() - n); assert!(res.len() == n, "temporary stack underflow"); res } fn check_interrupt(&mut self) -> Result<()> { if self .interrupt .fetch_and(false, std::sync::atomic::Ordering::Relaxed) { throw!(*SYM_INTERRUPTED) } Ok(()) } fn run_instr(&mut self, frame: &mut CallFrame, instr: Instruction) -> Result> { use Instruction as I; match instr { // do nothing I::Nop => (), // [] -> [locals[n]] I::LoadLocal(n) => self.push(frame.locals[usize::from(n)].clone()), // [x] -> [], locals[n] = x I::StoreLocal(n) => frame.locals[usize::from(n)] = self.pop(), // [x] -> [], locals.push(x) I::NewLocal => frame.locals.push(self.pop()), // locals.pop_n(n) I::DropLocal(n) => frame.locals.truncate(frame.locals.len() - usize::from(n)), // [] -> [globals[s]] I::LoadGlobal(s) => { let sym = unsafe { s.to_symbol_unchecked() }; let v = match self.globals.get(&sym) { Some(v) => v.clone(), None => throw!(*SYM_NAME_ERROR, "undefined global {}", sym.name()), }; self.push(v); } // [x] -> [], globals[s] = x I::StoreGlobal(s) => { let sym = unsafe { s.to_symbol_unchecked() }; let v = self.pop(); self.globals.insert(sym, v); } I::CloseOver(n) => { let n = usize::from(n); let v = std::mem::replace(&mut frame.locals[n], Value::Nil); let v = v.to_cell(); frame.locals[n] = v.clone(); self.push(v); } I::Closure(n) => { let f = frame.func.chunk.consts[usize::from(n)].clone(); let Value::Function(f) = f else { panic!("attempt to build closure from non-closure constant") }; let mut f = f.as_ref().clone(); let captured: Vec<_> = self .pop_n(f.state.len()) .into_iter() .map(|v| { let Value::Cell(v) = v else { panic!("attempt to build closure from non-cell local"); }; v }) .collect(); f.state = captured.into_boxed_slice(); self.push(f.into()); } I::LoadUpvalue(n) => { let v = frame.func.state[usize::from(n)].clone(); self.push(v.borrow().clone()); } I::StoreUpvalue(n) => { let v = frame.func.state[usize::from(n)].clone(); *v.borrow_mut() = self.pop(); } I::ContinueUpvalue(n) => { let v = frame.func.state[usize::from(n)].clone(); self.push(Value::Cell(v)); } I::LoadClosedLocal(n) => { let Value::Cell(c) = &frame.locals[usize::from(n)] else { panic!("attempt to load from closed non-cell local"); }; self.push(c.borrow().clone()); } I::StoreClosedLocal(n) => { let Value::Cell(c) = &frame.locals[usize::from(n)] else { panic!("attempt to store to closed non-cell local"); }; *c.borrow_mut() = self.pop(); } // [] -> [consts[n]] I::Const(n) => self.push(frame.func.chunk.consts[usize::from(n)].clone()), // [] -> [nil] I::Nil => self.push(Value::Nil), // [] -> [b] I::Bool(b) => self.push(Value::Bool(b)), // [] -> [s] I::Symbol(s) => { let sym = unsafe { Symbol::from_id_unchecked(u32::from(s)) }; self.push(Value::Symbol(sym)); } // [] -> [n] I::Int(n) => self.push(Value::Int(i64::from(n))), // [x] -> [x,x] I::Dup => self.push(self.stack[self.stack.len() - 1].clone()), // [x,y] -> [x,y,x,y] I::DupTwo => { self.push(self.stack[self.stack.len() - 2].clone()); self.push(self.stack[self.stack.len() - 2].clone()); } // [a0,a1...an] -> [] I::Drop(n) => { for _ in 0..u32::from(n) { self.pop(); } } // [x,y] -> [y,x] I::Swap => { let len = self.stack.len(); self.stack.swap(len - 1, len - 2); } // [x,y] -> [y op x] I::BinaryOp(op) => { let b = self.pop(); let a = self.pop(); self.push(binary_op(op, a, b)?); } // [x] -> [op x] I::UnaryOp(op) => { let a = self.pop(); self.push(unary_op(op, a)?); } // [a0,a1...an] -.> [[a0,a1...an]] I::NewList(n) => { let list = self.pop_n(n as usize); self.push(list.into()); } // [l,a0,a1...an] -.> [l ++ [a0,a1...an]] I::GrowList(n) => { let ext = self.pop_n(n as usize); let list = self.pop(); let Value::List(list) = list else { panic!("not a list") }; list.borrow_mut().extend(ext); self.push(Value::List(list)); } // [k0,v0...kn,vn] -.> [{k0=v0...kn=vn}] I::NewTable(n) => { let mut table = HashMap::new(); for _ in 0..n { let v = self.pop(); let k = self.pop(); table.insert(k.try_into()?, v); } self.push(table.into()); } // [t,k0,v0...kn,vn] -> [t ++ {k0=v0...kn=vn}] I::GrowTable(n) => { let mut ext = self.pop_n(2 * n as usize); let table = self.pop(); let Value::Table(table) = table else { panic!("not a table") }; let mut table_ref = table.borrow_mut(); for _ in 0..n { // can't panic: pop_n checked that ext would have len 2*n let v = ext.pop().unwrap(); let k = ext.pop().unwrap(); table_ref.insert(k.try_into()?, v); } drop(table_ref); self.push(Value::Table(table)); } // [ct, idx] -> [ct!idx] I::Index => { let idx = self.pop(); let ct = self.pop(); self.push(ct.index(idx)?); } // [ct, idx, v] -> [v], ct!idx = v I::StoreIndex => { let v = self.pop(); let idx = self.pop(); let ct = self.pop(); ct.store_index(self, idx, v.clone())?; self.push(v); } // ip = n I::Jump(n) => { self.check_interrupt()?; frame.ip = usize::from(n); } // [v] ->, [], if v then ip = n I::JumpTrue(n) => { if self.pop().truthy() { self.check_interrupt()?; frame.ip = usize::from(n); } } // [v] ->, [], if not v then ip = n I::JumpFalse(n) => { if !self.pop().truthy() { self.check_interrupt()?; frame.ip = usize::from(n); } } // [v] -> [iter(v)] I::IterBegin => { let iter = self.pop().to_iter_function()?; self.push(iter); } // [i,cell(v)] -> [i,v] // [i,nil] -> [], ip = n I::IterTest(n) => { if let Some(v) = self.pop().iter_unpack() { self.push(v); } else { self.pop(); frame.ip = usize::from(n); } } // try_frames.push(t, stack.len()) I::BeginTry(t) => { let tryframe = TryFrame { idx: usize::from(t), stack_len: self.stack.len(), }; frame.try_frames.push(tryframe); } // try_frames.pop() I::EndTry => { frame.try_frames.pop().expect("no try to pop"); } // [f,a0,a1...an] -> [f(a0,a1...an)] I::Call(n) => { let n = usize::from(n); let args = self.pop_n(n + 1); let args = match get_call_outcome(args)? { CallOutcome::Call(args) => args, CallOutcome::Partial(v) => { self.push(v); return Ok(None) } }; if let Value::NativeFunc(nf) = &args[0] { let nf = nf.clone(); // safety: frame is restored immediately // after function call ends // ~25% performance improvement in // code heavy on native function calls unsafe { let f = std::ptr::read(frame); self.call_stack.push(f); } let res = (nf.func)(self, args); // safety: frame was referencing invalid memory due to // previous unsafe block, write will fix that unsafe { let f = self.call_stack.pop().expect("no frame to pop"); std::ptr::write(frame, f); } // make sure we restored the value of frame // before propagating exceptions let res = res?; self.push(res); } else if let Value::Function(func) = &args[0] { if self.call_stack.len() + 1 >= self.stack_max { throw!(*SYM_CALL_STACK_OVERFLOW, "call stack overflow") } let new_frame = CallFrame::new(func.clone(), args); let old_frame = std::mem::replace(frame, new_frame); self.call_stack.push(old_frame); } else { unreachable!("already verified by calling get_call_type"); } } // [v] -> [], return v I::Return if frame.root => return Ok(Some(self.pop())), // [v] -> [], return v I::Return => { self.check_interrupt()?; *frame = self.call_stack.pop().expect("no root frame"); } } Ok(None) } } #[macro_export] macro_rules! vmcall { ($vm:expr; $func:expr, $($arg:expr),*) => {{ let f = $func; $vm.call_value(f.clone(), vec![f, $($arg),*]) }}; ($vm:expr; $func:expr) => {{ let f = $func; $vm.call_value(f.clone(), vec![f]) }}; } #[macro_export] macro_rules! vmcalliter { ($($input:tt)*) => { $crate::vmcall!($($input)*).map(|v| v.iter_unpack()) } }