use std::{
	cmp::Ordering,
	collections::HashMap,
	rc::Rc,
	sync::{atomic::AtomicBool, Arc},
};

use crate::{
	chunk::Instruction,
	exception::{throw, Exception, Result},
	lstring::LStr,
	parser::ast::{BinaryOp, UnaryOp},
	symbol::{
		Symbol, SYM_CALL_STACK_OVERFLOW, SYM_INTERRUPTED, SYM_NAME_ERROR, SYM_TYPE_ERROR,
	},
	value::{
		function::{FuncAttrs, Function, NativeFunc},
		Value,
	},
};

/// An active `try` region inside a call frame, recorded by `BeginTry` and
/// removed by `EndTry` or by exception unwinding.
struct TryFrame {
	/// Index into the owning function's `chunk.try_tables`, which lists the
	/// catch clauses for this region.
	idx: usize,
	/// Length of the temporary value stack when the region was entered; the
	/// stack is truncated back to this length before jumping to a catch.
	stack_len: usize,
}

/// Execution state for one invocation of a bytecode function.
struct CallFrame {
	/// The function whose chunk is being executed.
	func: Rc<Function>,
	/// Local variable slots. Initialized from the call's argument vector
	/// (by calling convention the callee itself is at index 0).
	locals: Vec<Value>,
	/// `try` regions currently active in this frame, innermost last.
	try_frames: Vec<TryFrame>,
	/// Index of the next instruction to execute in `func.chunk.instrs`.
	ip: usize,
	/// True for the frame created directly by `Vm::run_function`; returning
	/// from it, or failing to catch an exception in it, leaves that
	/// `run_function` call.
	root: bool,
}

impl CallFrame {
	/// Builds a non-root frame for `func` that starts at instruction 0 with
	/// `locals` as its initial local slots and no active `try` regions.
	fn new(func: Rc<Function>, locals: Vec<Value>) -> Self {
		let try_frames = Vec::default();
		Self {
			root: false,
			ip: 0,
			try_frames,
			locals,
			func,
		}
	}
}

/// The bytecode interpreter: temporary value stack, call stack and globals.
pub struct Vm {
	/// Temporary value stack shared by all frames.
	stack: Vec<Value>,
	/// Frames of suspended callers. The frame currently executing is held
	/// separately by `run_function`; it is parked here only while it waits on
	/// a bytecode callee or a native function call.
	call_stack: Vec<CallFrame>,
	/// Call-depth limit checked against `call_stack.len()`; exceeding it
	/// throws a call-stack-overflow exception.
	stack_max: usize,
	/// Global variables, keyed by symbol.
	globals: HashMap<Symbol, Value>,
	/// Shared flag that, when set, makes the VM throw `SYM_INTERRUPTED` at
	/// its next interrupt check; the check clears the flag.
	interrupt: Arc<AtomicBool>,
}

pub fn binary_op(o: BinaryOp, a: Value, b: Value) -> Result<Value> {
	match o {
		BinaryOp::Add => a + b,
		BinaryOp::Sub => a - b,
		BinaryOp::Mul => a * b,
		BinaryOp::Div => a / b,
		BinaryOp::Mod => a.modulo(b),
		BinaryOp::IntDiv => a.int_div(b),
		BinaryOp::Pow => a.pow(b),
		BinaryOp::Shl => a << b,
		BinaryOp::Shr => a >> b,
		BinaryOp::BitAnd => a & b,
		BinaryOp::BitXor => a ^ b,
		BinaryOp::BitOr => a | b,
		BinaryOp::Eq => Ok(Value::Bool(a == b)),
		BinaryOp::Ne => Ok(Value::Bool(a != b)),
		BinaryOp::Gt => a.val_cmp(&b).map(|o| Value::Bool(o == Ordering::Greater)),
		BinaryOp::Ge => a.val_cmp(&b).map(|o| Value::Bool(o != Ordering::Less)),
		BinaryOp::Lt => a.val_cmp(&b).map(|o| Value::Bool(o == Ordering::Less)),
		BinaryOp::Le => a.val_cmp(&b).map(|o| Value::Bool(o != Ordering::Greater)),
		BinaryOp::Range => a.range(&b, false),
		BinaryOp::RangeIncl => a.range(&b, true),
		BinaryOp::Concat => a.concat(&b),
		BinaryOp::Append => a.append(b),
	}
}

/// Evaluates the unary operator `o` on `a`.
///
/// `Not` always succeeds and yields the negation of `a`'s truthiness; `Neg`
/// and `RangeEndless` delegate to `Value` and may fail.
pub fn unary_op(o: UnaryOp, a: Value) -> Result<Value> {
	match o {
		UnaryOp::Not => {
			let negated = !a.truthy();
			Ok(Value::Bool(negated))
		}
		UnaryOp::Neg => -a,
		UnaryOp::RangeEndless => a.range_endless(),
	}
}

/// Result of checking a call's argument count against the callee's arity.
enum CallOutcome {
	/// The argument count matches: perform the call with these arguments
	/// (callee at index 0).
	Call(Vec<Value>),
	/// Too few arguments: a native function that collects the remaining
	/// arguments and then calls the original callee.
	Partial(Value),
}

/// Decides how a call with `args` proceeds. `args[0]` is the callee and the
/// remaining elements are its arguments.
///
/// Returns `CallOutcome::Call(args)` when the argument count equals the
/// callee's arity. When too few arguments were supplied, returns
/// `CallOutcome::Partial` wrapping a native function whose arity is the
/// number of missing arguments. When called, that function appends its own
/// arguments to the ones bound here and calls the original callee.
///
/// # Errors
/// Throws `SYM_TYPE_ERROR` if `args[0]` is not callable or if too many
/// arguments were supplied.
///
/// # Panics
/// Panics if `args` is empty.
fn get_call_outcome(args: Vec<Value>) -> Result<CallOutcome> {
	let f = &args[0];
	let Some(attrs) = f.func_attrs() else {
		throw!(*SYM_TYPE_ERROR, "cannot call non-function {f:#}")
	};
	let argc = args.len() - 1;

	if argc > attrs.arity {
		throw!(*SYM_TYPE_ERROR, "too many arguments for function")
	}
	if argc == attrs.arity {
		return Ok(CallOutcome::Call(args))
	}

	// Too few arguments: bind what we have and wait for the rest.
	let remaining = attrs.arity - argc;
	let callee = f.clone();
	let bound = args;
	let partial = move |vm: &mut Vm, more_args: Vec<Value>| {
		// By calling convention more_args[0] is the partial function itself;
		// skip it and append the rest after the bound arguments (which
		// already start with the original callee).
		let full_args: Vec<Value> = bound
			.iter()
			.cloned()
			.chain(more_args.into_iter().skip(1))
			.collect();
		vm.call_value(callee.clone(), full_args)
	};
	let native = NativeFunc {
		attrs: FuncAttrs {
			arity: remaining,
			name: None,
		},
		func: Box::new(partial),
	};
	Ok(CallOutcome::Partial(native.into()))
}

impl Vm {
	/// Creates an empty VM with no globals.
	///
	/// `stack_max` bounds the call depth: calls made once `call_stack` has
	/// reached this limit throw `SYM_CALL_STACK_OVERFLOW`.
	pub fn new(stack_max: usize) -> Self {
		Self {
			stack: Vec::with_capacity(16),
			call_stack: Vec::with_capacity(16),
			globals: HashMap::with_capacity(16),
			stack_max,
			interrupt: Arc::new(AtomicBool::new(false)),
		}
	}

	/// Returns a handle to the interrupt flag. Setting it to `true` (e.g. from
	/// another thread) makes the VM throw `SYM_INTERRUPTED` at its next
	/// interrupt check; the check clears the flag.
	pub fn get_interrupt(&self) -> Arc<AtomicBool> {
		self.interrupt.clone()
	}

	/// Defines or overwrites the global `name`.
	pub fn set_global(&mut self, name: Symbol, val: Value) {
		self.globals.insert(name, val);
	}

	/// Like [`Vm::set_global`], but converts `name` to a symbol via
	/// `Symbol::get` first.
	pub fn set_global_name<'a, S>(&mut self, name: S, val: Value)
	where
		S: Into<&'a LStr>,
	{
		self.globals.insert(Symbol::get(name.into()), val);
	}

	/// Returns the value of the global `name`, if it is defined.
	pub fn get_global(&self, name: Symbol) -> Option<&Value> {
		self.globals.get(&name)
	}

	/// Returns all global bindings.
	pub fn globals(&self) -> &HashMap<Symbol, Value> {
		&self.globals
	}

	/// Calls `value` with `args`. By calling convention `args[0]` is the
	/// callee itself and the remaining elements are its arguments.
	///
	/// If fewer arguments than the callee's arity are given, returns a
	/// partial application instead of calling.
	///
	/// # Errors
	/// Throws if the VM was interrupted, if `args[0]` is not callable, if it
	/// was given too many arguments, or if the call itself throws.
	///
	/// # Panics
	/// Panics if `args` is empty, or if `args[0]` is callable but `value` is
	/// not.
	pub fn call_value(&mut self, value: Value, args: Vec<Value>) -> Result<Value> {
		self.check_interrupt()?;
		match get_call_outcome(args)? {
			CallOutcome::Partial(v) => Ok(v),
			CallOutcome::Call(args) => match value {
				Value::Function(f) => self.run_function(f, args),
				Value::NativeFunc(f) => (f.func)(self, args),
				_ => unreachable!("already verified by calling get_call_outcome"),
			},
		}
	}

	/// Runs the bytecode function `func` until it returns. `args` becomes its
	/// initial locals (by convention `args[0]` is the function itself).
	///
	/// The temporary stack is restored to its entry length on both success
	/// and failure.
	///
	/// # Errors
	/// Throws `SYM_TYPE_ERROR` if `args.len()` is not `arity + 1`, throws
	/// `SYM_CALL_STACK_OVERFLOW` if the call-depth limit has been reached, and
	/// propagates any exception the function raises without catching it.
	pub fn run_function(&mut self, func: Rc<Function>, args: Vec<Value>) -> Result<Value> {
		if func.attrs.arity + 1 != args.len() {
			throw!(*SYM_TYPE_ERROR, "function call with wrong argument count");
		}
		// Re-entry from native code (a native function calling back into
		// bytecode) parks the caller frame on `call_stack` but never passes
		// the depth check in `I::Call`. Check here as well, so that unbounded
		// recursion through native functions raises an exception instead of
		// overflowing the Rust stack.
		if self.call_stack.len() >= self.stack_max {
			throw!(*SYM_CALL_STACK_OVERFLOW, "call stack overflow")
		}
		let init_stack_len = self.stack.len();
		let mut frame = CallFrame::new(func, args);
		frame.root = true;

		loop {
			let instr = frame.func.chunk.instrs[frame.ip];
			frame.ip += 1;
			match self.run_instr(&mut frame, instr) {
				Ok(None) => (),
				Ok(Some(v)) => {
					self.stack.truncate(init_stack_len);
					return Ok(v)
				}
				Err(e) => {
					if let Err(e) = self.handle_exception(&mut frame, e) {
						self.stack.truncate(init_stack_len);
						return Err(e)
					}
				}
			}
		}
	}

	/// Unwinds from the exception `exc` raised while executing `frame`.
	///
	/// Searches `frame`'s active `try` regions, innermost first, for a catch
	/// clause that has no type filter or whose types include `exc.ty`. On a
	/// match it truncates the locals and the temporary stack back to the
	/// region's saved state, pushes the exception as a table, jumps to the
	/// catch address and returns `Ok(())`.
	///
	/// If no clause matches, it replaces `frame` with its caller from
	/// `call_stack` and keeps searching. The search stops at the root frame of
	/// this `run_function` invocation, and `exc` is then returned as `Err`.
	fn handle_exception(&mut self, frame: &mut CallFrame, exc: Exception) -> Result<()> {
		loop {
			while let Some(try_frame) = frame.try_frames.pop() {
				let table = &frame.func.chunk.try_tables[try_frame.idx];
				for catch in &table.catches {
					let matches = match &catch.types {
						None => true,
						Some(types) => types.contains(&exc.ty),
					};
					if matches {
						frame.ip = catch.addr;
						frame.locals.truncate(table.local_count);
						self.stack.truncate(try_frame.stack_len);
						self.stack.push(Value::Table(exc.to_table()));
						return Ok(())
					}
				}
			}
			if frame.root {
				return Err(exc)
			}
			*frame = self.call_stack.pop().expect("no root frame");
		}
	}

	#[inline]
	fn push(&mut self, v: Value) {
		self.stack.push(v);
	}

	#[inline]
	fn pop(&mut self) -> Value {
		self.stack.pop().expect("temporary stack underflow")
	}

	/// Removes the top `n` values from the temporary stack and returns them
	/// in stack order (deepest first).
	///
	/// # Panics
	/// Panics if fewer than `n` values are on the stack.
	#[inline]
	fn pop_n(&mut self, n: usize) -> Vec<Value> {
		// Check before subtracting: `len - n` would otherwise underflow and
		// panic (or wrap in release builds) before this message is reached.
		let len = self.stack.len();
		assert!(n <= len, "temporary stack underflow");
		self.stack.split_off(len - n)
	}

	/// Throws `SYM_INTERRUPTED` if the interrupt flag is set.
	fn check_interrupt(&mut self) -> Result<()> {
		// `swap` reads and clears the flag in one atomic step, so a single
		// interrupt request is reported once.
		if self
			.interrupt
			.swap(false, std::sync::atomic::Ordering::Relaxed)
		{
			throw!(*SYM_INTERRUPTED)
		}
		Ok(())
	}

	/// Executes one instruction `instr` in `frame`.
	///
	/// Returns `Ok(Some(v))` when the root frame returns `v`, `Ok(None)` to
	/// continue, or `Err` when the instruction throws. Stack effects are noted
	/// per instruction as `[before] -> [after]`, with the top of the stack
	/// last.
	fn run_instr(
		&mut self,
		frame: &mut CallFrame,
		instr: Instruction,
	) -> Result<Option<Value>> {
		use Instruction as I;

		match instr {
			// do nothing
			I::Nop => (),
			// [] -> [locals[n]]
			I::LoadLocal(n) => self.push(frame.locals[usize::from(n)].clone()),
			// [x] -> [], locals[n] = x
			I::StoreLocal(n) => frame.locals[usize::from(n)] = self.pop(),
			// [x] -> [], locals.push(x)
			I::NewLocal => frame.locals.push(self.pop()),
			// locals.pop_n(n)
			I::DropLocal(n) => frame.locals.truncate(frame.locals.len() - usize::from(n)),
			// [] -> [globals[s]]
			I::LoadGlobal(s) => {
				// NOTE(review): relies on the compiler only emitting ids of
				// existing symbols — confirm against `to_symbol_unchecked`'s contract
				let sym = unsafe { s.to_symbol_unchecked() };
				let v = match self.globals.get(&sym) {
					Some(v) => v.clone(),
					None => throw!(*SYM_NAME_ERROR, "undefined global {}", sym.name()),
				};
				self.push(v);
			}
			// [x] -> [], globals[s] = x
			I::StoreGlobal(s) => {
				// NOTE(review): same symbol-id assumption as LoadGlobal
				let sym = unsafe { s.to_symbol_unchecked() };
				let v = self.pop();
				self.globals.insert(sym, v);
			}
			// [] -> [cell], locals[n] = cell (the local is moved into a cell)
			I::CloseOver(n) => {
				let n = usize::from(n);
				let v = std::mem::replace(&mut frame.locals[n], Value::Nil);
				let v = v.to_cell();
				frame.locals[n] = v.clone();
				self.push(v);
			}
			// [c0...cm] -> [closure], where m is the constant's state length
			I::Closure(n) => {
				let f = frame.func.chunk.consts[usize::from(n)].clone();
				let Value::Function(f) = f else {
					panic!("attempt to build closure from non-closure constant")
				};
				let mut f = f.as_ref().clone();

				let captured: Vec<_> = self
					.pop_n(f.state.len())
					.into_iter()
					.map(|v| {
						let Value::Cell(v) = v else {
							panic!("attempt to build closure from non-cell local");
						};
						v
					})
					.collect();

				f.state = captured.into_boxed_slice();
				self.push(f.into());
			}
			// [] -> [*state[n]]
			I::LoadUpvalue(n) => {
				let v = frame.func.state[usize::from(n)].clone();
				self.push(v.borrow().clone());
			}
			// [x] -> [], *state[n] = x
			I::StoreUpvalue(n) => {
				let v = frame.func.state[usize::from(n)].clone();
				*v.borrow_mut() = self.pop();
			}
			// [] -> [state[n]] (the cell itself, for capture by a nested closure)
			I::ContinueUpvalue(n) => {
				let v = frame.func.state[usize::from(n)].clone();
				self.push(Value::Cell(v));
			}
			// [] -> [*locals[n]]
			I::LoadClosedLocal(n) => {
				let Value::Cell(c) = &frame.locals[usize::from(n)] else {
					panic!("attempt to load from closed non-cell local");
				};
				self.push(c.borrow().clone());
			}
			// [x] -> [], *locals[n] = x
			I::StoreClosedLocal(n) => {
				let Value::Cell(c) = &frame.locals[usize::from(n)] else {
					panic!("attempt to store to closed non-cell local");
				};
				*c.borrow_mut() = self.pop();
			}
			// [] -> [consts[n]]
			I::Const(n) => self.push(frame.func.chunk.consts[usize::from(n)].clone()),
			// [] -> [nil]
			I::Nil => self.push(Value::Nil),
			// [] -> [b]
			I::Bool(b) => self.push(Value::Bool(b)),
			// [] -> [s]
			I::Symbol(s) => {
				// NOTE(review): assumes the compiler only emits valid symbol ids — confirm
				let sym = unsafe { Symbol::from_id_unchecked(u32::from(s)) };
				self.push(Value::Symbol(sym));
			}
			// [] -> [n]
			I::Int(n) => self.push(Value::Int(i64::from(n))),
			// [x] -> [x,x]
			I::Dup => self.push(self.stack[self.stack.len() - 1].clone()),
			// [x,y] -> [x,y,x,y]
			I::DupTwo => {
				self.push(self.stack[self.stack.len() - 2].clone());
				self.push(self.stack[self.stack.len() - 2].clone());
			}
			// [a0,a1...an] -> []
			I::Drop(n) => {
				for _ in 0..u32::from(n) {
					self.pop();
				}
			}
			// [x,y] -> [y,x]
			I::Swap => {
				let len = self.stack.len();
				self.stack.swap(len - 1, len - 2);
			}
			// [x,y] -> [x op y]
			I::BinaryOp(op) => {
				let b = self.pop();
				let a = self.pop();
				self.push(binary_op(op, a, b)?);
			}
			// [x] -> [op x]
			I::UnaryOp(op) => {
				let a = self.pop();
				self.push(unary_op(op, a)?);
			}
			// [a0,a1...an] -.> [[a0,a1...an]]
			I::NewList(n) => {
				let list = self.pop_n(n as usize);
				self.push(list.into());
			}
			// [l,a0,a1...an] -.> [l ++ [a0,a1...an]]
			I::GrowList(n) => {
				let ext = self.pop_n(n as usize);
				let list = self.pop();
				let Value::List(list) = list else {
					panic!("not a list")
				};
				list.borrow_mut().extend(ext);
				self.push(Value::List(list));
			}
			// [k0,v0...kn,vn] -.> [{k0=v0...kn=vn}]
			I::NewTable(n) => {
				let mut table = HashMap::new();
				for _ in 0..n {
					let v = self.pop();
					let k = self.pop();
					table.insert(k.try_into()?, v);
				}
				self.push(table.into());
			}
			// [t,k0,v0...kn,vn] -> [t ++ {k0=v0...kn=vn}]
			I::GrowTable(n) => {
				let mut ext = self.pop_n(2 * n as usize);
				let table = self.pop();
				let Value::Table(table) = table else {
					panic!("not a table")
				};
				let mut table_ref = table.borrow_mut();
				for _ in 0..n {
					// can't panic: pop_n checked that ext would have len 2*n
					let v = ext.pop().unwrap();
					let k = ext.pop().unwrap();
					table_ref.insert(k.try_into()?, v);
				}
				drop(table_ref);
				self.push(Value::Table(table));
			}
			// [ct, idx] -> [ct!idx]
			I::Index => {
				let idx = self.pop();
				let ct = self.pop();
				self.push(ct.index(idx)?);
			}
			// [ct, idx, v] -> [v], ct!idx = v
			I::StoreIndex => {
				let v = self.pop();
				let idx = self.pop();
				let ct = self.pop();
				ct.store_index(self, idx, v.clone())?;
				self.push(v);
			}
			// ip = n
			I::Jump(n) => {
				self.check_interrupt()?;
				frame.ip = usize::from(n);
			}
			// [v] -> [], if v then ip = n
			I::JumpTrue(n) => {
				if self.pop().truthy() {
					self.check_interrupt()?;
					frame.ip = usize::from(n);
				}
			}
			// [v] -> [], if not v then ip = n
			I::JumpFalse(n) => {
				if !self.pop().truthy() {
					self.check_interrupt()?;
					frame.ip = usize::from(n);
				}
			}
			// [v] -> [iter(v)]
			I::IterBegin => {
				let iter = self.pop().to_iter_function()?;
				self.push(iter);
			}
			// [i,cell(v)] -> [i,v]
			// [i,nil] -> [], ip = n
			I::IterTest(n) => {
				if let Some(v) = self.pop().iter_unpack() {
					self.push(v);
				} else {
					self.pop();
					frame.ip = usize::from(n);
				}
			}
			// try_frames.push(t, stack.len())
			I::BeginTry(t) => {
				let tryframe = TryFrame {
					idx: usize::from(t),
					stack_len: self.stack.len(),
				};
				frame.try_frames.push(tryframe);
			}
			// try_frames.pop()
			I::EndTry => {
				frame.try_frames.pop().expect("no try to pop");
			}
			// [f,a0,a1...an] -> [f(a0,a1...an)]
			I::Call(n) => {
				let n = usize::from(n);

				let args = self.pop_n(n + 1);

				let args = match get_call_outcome(args)? {
					CallOutcome::Call(args) => args,
					CallOutcome::Partial(v) => {
						self.push(v);
						return Ok(None)
					}
				};

				if let Value::NativeFunc(nf) = &args[0] {
					let nf = nf.clone();

					// Park the caller frame on `call_stack` for the duration of
					// the native call. A placeholder frame (one `Rc` clone, no
					// allocation) is left in `frame`, so the caller frame always
					// has exactly one owner. Bitwise-copying it out with
					// `ptr::read` would double-drop it if the native function
					// panicked and the panic were caught.
					let placeholder = CallFrame::new(frame.func.clone(), Vec::new());
					let caller = std::mem::replace(frame, placeholder);
					self.call_stack.push(caller);

					let res = (nf.func)(self, args);

					// restore the caller frame before propagating exceptions
					*frame = self.call_stack.pop().expect("no frame to pop");

					self.push(res?);
				} else if let Value::Function(func) = &args[0] {
					if self.call_stack.len() + 1 >= self.stack_max {
						throw!(*SYM_CALL_STACK_OVERFLOW, "call stack overflow")
					}

					let new_frame = CallFrame::new(func.clone(), args);
					let old_frame = std::mem::replace(frame, new_frame);
					self.call_stack.push(old_frame);
				} else {
					unreachable!("already verified by calling get_call_outcome");
				}
			}
			// [v] -> [], return v
			I::Return if frame.root => return Ok(Some(self.pop())),
			// [v] -> [v], resume the caller (v stays as the call's result)
			I::Return => {
				self.check_interrupt()?;
				*frame = self.call_stack.pop().expect("no root frame");
			}
		}

		Ok(None)
	}
}

/// Calls `$func` on the VM `$vm` with the given arguments, following the VM
/// calling convention of passing the callee as the first element of the
/// argument vector: `vmcall!(vm; f, a, b)` is `vm.call_value(f, vec![f, a, b])`.
///
/// `$func` is evaluated once. A trailing comma after the last argument is
/// accepted.
#[macro_export]
macro_rules! vmcall {
	($vm:expr; $func:expr $(, $arg:expr)* $(,)?) => {{
		let f = $func;
		$vm.call_value(f.clone(), vec![f $(, $arg)*])
	}};
}

/// Like `vmcall!`, but maps a successful call result through its
/// `iter_unpack` method.
#[macro_export]
macro_rules! vmcalliter {
	($($call:tt)*) => {
		$crate::vmcall!($($call)*).map(|ret| ret.iter_unpack())
	};
}