use core::fmt;
use std::num::{ParseFloatError, ParseIntError};

use crate::lstring::{LStr, LString};

use super::Span;

/// An error produced while parsing, tagged with the [`Span`] it refers to.
#[derive(Clone, Debug)]
pub struct ParserError {
	/// Span of the source that the error refers to.
	pub span: Span,
	/// Human-readable description of the problem.
	pub msg: String,
}

impl std::error::Error for ParserError {}

/// Renders the error as `<span> | <msg>`.
impl fmt::Display for ParserError {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let Self { span, msg } = self;
		write!(f, "{span} | {msg}")
	}
}

/// Extension trait for attaching a [`Span`] to a value's error.
pub trait SpanParserError {
	/// The value produced once the span has been attached.
	type Output;
	/// Attaches `span` to `self`'s error.
	fn span_err(self, span: Span) -> Self::Output;
}

/// For a `Result` whose error implements [`std::error::Error`], an `Ok` value is
/// passed through untouched. An `Err` is replaced by a [`ParserError`] located at
/// `span`, whose message is the original error's `Display` text.
impl<T, E: std::error::Error> SpanParserError for Result<T, E> {
	type Output = Result<T, ParserError>;
	fn span_err(self, span: Span) -> Self::Output {
		match self {
			Ok(value) => Ok(value),
			Err(err) => Err(ParserError {
				span,
				msg: err.to_string(),
			}),
		}
	}
}

/// Reasons a string escape sequence can fail to decode in `parse_str_escapes`.
#[derive(Clone, Copy, Debug)]
pub enum StrEscapeError {
	/// The input ended immediately after a backslash.
	Eof,
	/// The input ended before the two hex digits of a `\x` escape.
	HexEof,
	/// The input ended inside a `\u{...}` escape, before the closing brace.
	UnicodeEof,
	/// The character after the backslash is not a recognized escape.
	Invalid(char),
	/// A character that is not a hex digit appeared where one was expected.
	InvalidHex(char),
	/// `\u` was not followed by `{`.
	MissingBrace,
	/// The `\u{...}` value does not name a valid `char` (e.g. a surrogate).
	InvalidCodepoint(u32),
	/// The `\u{...}` value is too large to be a Unicode codepoint.
	CodepointTooLarge,
}

impl std::error::Error for StrEscapeError {}

impl fmt::Display for StrEscapeError {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		match self {
			StrEscapeError::Eof => write!(f, "EOF in string escape"),
			StrEscapeError::HexEof => write!(f, "EOF in string escape \\x"),
			StrEscapeError::UnicodeEof => write!(f, "EOF in string escape \\u"),
			StrEscapeError::Invalid(c) => write!(f, "invalid string escape \\{c}"),
			StrEscapeError::InvalidHex(x) => {
				write!(f, "invalid hex digit '{x}' in string escape")
			}
			StrEscapeError::MissingBrace => {
				write!(f, "missing brace after string escape \\u")
			}
			StrEscapeError::InvalidCodepoint(n) => {
				// Zero-padded to at least four hex digits, like the usual U+XXXX notation.
				write!(f, "invalid codepoint in string escape: U+{n:0>4x}")
			}
			StrEscapeError::CodepointTooLarge => {
				write!(f, "codepoint in string escape too large")
			}
		}
	}
}

/// Decodes backslash escape sequences in `src`, returning the unescaped string.
///
/// Recognized escapes:
/// - `\"`, `\'`, `\\`: the escaped character itself
/// - `\0` (NUL), `\a` (BEL), `\b` (BS), `\t`, `\n`, `\v` (VT), `\f` (FF),
///   `\r`, `\e` (ESC)
/// - `\xHH`: exactly two hex digits, pushed as a single raw byte via
///   `push_byte` rather than as a `char`
/// - `\u{H..}`: the Unicode scalar value given by one or more hex digits
///
/// Characters outside an escape are copied through unchanged.
///
/// # Errors
/// Returns the [`StrEscapeError`] for the first malformed escape:
/// - [`StrEscapeError::Eof`] for a trailing backslash.
/// - [`StrEscapeError::Invalid`] for an unknown escape letter.
/// - [`StrEscapeError::InvalidHex`] for a non-hex digit, including an empty `\u{}`.
/// - [`StrEscapeError::CodepointTooLarge`] for a `\u{...}` value above `0x10ffff`.
/// - [`StrEscapeError::InvalidCodepoint`] for a value that is not a valid `char`
///   (a surrogate).
pub fn parse_str_escapes(src: &str) -> Result<LString, StrEscapeError> {
	let mut s = LString::with_capacity(src.len());
	let mut chars = src.chars();

	while let Some(c) = chars.next() {
		if c != '\\' {
			s.push_char(c);
			continue
		}
		let c = chars.next().ok_or(StrEscapeError::Eof)?;
		match c {
			'"' | '\'' | '\\' => s.push_char(c),
			'0' => s.push_char('\0'),
			'a' => s.push_char('\x07'),
			'b' => s.push_char('\x08'),
			't' => s.push_char('\t'),
			'n' => s.push_char('\n'),
			'v' => s.push_char('\x0b'),
			'f' => s.push_char('\x0c'),
			'r' => s.push_char('\r'),
			'e' => s.push_char('\x1b'),
			'x' => {
				let c = chars.next().ok_or(StrEscapeError::HexEof)?;
				let n1 = c.to_digit(16).ok_or(StrEscapeError::InvalidHex(c))?;
				let c = chars.next().ok_or(StrEscapeError::HexEof)?;
				let n2 = c.to_digit(16).ok_or(StrEscapeError::InvalidHex(c))?;
				// Two hex digits give at most 0xff, so the cast is lossless.
				s.push_byte((n1 * 16 + n2) as u8);
			}
			'u' => {
				let Some('{') = chars.next() else {
					return Err(StrEscapeError::MissingBrace)
				};
				let mut n = 0_u32;
				let mut saw_digit = false;
				loop {
					let c = chars.next().ok_or(StrEscapeError::UnicodeEof)?;
					if c == '}' {
						break
					}
					let d = c.to_digit(16).ok_or(StrEscapeError::InvalidHex(c))?;
					// Invariant: n <= 0x10ffff before this step, so n * 16 + d
					// stays far below u32::MAX and cannot overflow.
					n = n * 16 + d;
					if n > 0x10ffff {
						return Err(StrEscapeError::CodepointTooLarge)
					}
					saw_digit = true;
				}
				if !saw_digit {
					// `\u{}` contains no digits; report the closing brace as the
					// character found where a hex digit was required.
					return Err(StrEscapeError::InvalidHex('}'))
				}
				// Values above 0x10ffff were rejected above, so only surrogates
				// can fail this conversion.
				let ch = char::from_u32(n).ok_or(StrEscapeError::InvalidCodepoint(n))?;
				s.push_char(ch);
			}
			c => return Err(StrEscapeError::Invalid(c)),
		}
	}

	Ok(s)
}

/// Parses a floating-point literal into an `f64`, ignoring `_` digit separators.
///
/// # Errors
/// Returns the [`ParseFloatError`] from `str::parse` when the text, with all
/// underscores removed, is not a valid `f64`.
pub fn parse_float<'a, S: Into<&'a LStr>>(f: S) -> Result<f64, ParseFloatError> {
	let cleaned: String = f.into().chars().filter(|&ch| ch != '_').collect();
	cleaned.parse::<f64>()
}

/// Parses an integer in the given `radix` into an `i64`, ignoring `_` digit
/// separators. No radix prefix (such as `0x`) is accepted here.
///
/// # Errors
/// Returns the [`ParseIntError`] from `i64::from_str_radix` when the text, with
/// all underscores removed, is empty, has an invalid digit, or overflows `i64`.
///
/// # Panics
/// Panics (inside `i64::from_str_radix`) if `radix` is not in `2..=36`.
pub fn parse_int<'a, S: Into<&'a LStr>>(f: S, radix: u32) -> Result<i64, ParseIntError> {
	let cleaned = f
		.into()
		.chars()
		.filter(|&ch| ch != '_')
		.collect::<String>();
	i64::from_str_radix(&cleaned, radix)
}

/// Parses an integer literal, choosing the radix from its prefix.
///
/// - `0x` → hexadecimal (16)
/// - `0o` → octal (8)
/// - `0s` → senary (6)
/// - `0b` → binary (2)
/// - anything else → decimal (10)
///
/// `_` separators are ignored (see [`parse_int`]).
///
/// # Errors
/// Returns the [`ParseIntError`] from [`parse_int`] if the digits after the
/// prefix are empty, invalid for the radix, or overflow `i64`.
pub fn parse_int_literal<'a, S: Into<&'a LStr>>(f: S) -> Result<i64, ParseIntError> {
	let f = f.into();
	let mut prefix = f.chars();
	// The radix letter is the *second* character, directly after the leading `0`.
	let radix = match (prefix.next(), prefix.next()) {
		(Some('0'), Some('x')) => 16,
		(Some('0'), Some('o')) => 8,
		(Some('0'), Some('s')) => 6,
		(Some('0'), Some('b')) => 2,
		_ => return parse_int(f, 10),
	};
	// Skip the two-character prefix (`0` plus the radix letter).
	parse_int(&f[2..], radix)
}

/// Formats `n` in the given `radix` as an [`LString`].
///
/// Negative numbers get a leading `-`. No radix prefix (such as `0x`) is
/// emitted. Digits above 9 are lowercase ASCII letters, or uppercase when
/// `upper` is true.
///
/// # Panics
/// Panics if `radix` is not in `2..=36`.
pub fn to_lstring_radix(n: i64, radix: u32, upper: bool) -> LString {
	// radix 0 would divide by zero and radix 1 would never terminate; > 36 has no digits.
	assert!((2..=36).contains(&radix), "radix must be in 2..=36, got {radix}");
	let base = u64::from(radix);

	// `unsigned_abs` also handles `i64::MIN`, whose negation overflows `i64`.
	let mut x = n.unsigned_abs();
	let mut buf = Vec::new();

	// Digits are produced least-significant first; the buffer is reversed at the end.
	loop {
		// x % base < radix <= 36, so the narrowing cast is lossless.
		let digit = (x % base) as u32;
		x /= base;

		let mut c = char::from_digit(digit, radix).expect("digit is always less than radix");
		if upper {
			c.make_ascii_uppercase();
		}
		// from_digit only yields ASCII '0'-'9' / 'a'-'z', so this cast is lossless.
		buf.push(c as u8);
		if x == 0 {
			break
		}
	}
	if n < 0 {
		// Pushed last so it ends up first after the reversal.
		buf.push(b'-');
	}
	buf.reverse();
	LString::from(buf)
}