This commit is contained in:
trimill 2024-08-30 10:02:02 -04:00
parent 4fcd317390
commit 8702cddf44
11 changed files with 820 additions and 518 deletions

852
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -5,4 +5,4 @@ members = [
"cxgraph-desktop",
"cxgraph-web",
]
resolver = "2"

View file

@ -3,6 +3,15 @@
cxgraph is a complex function graphing tool built around WebGPU available
[on the web](https://cx.trimill.xyz/) or (slightly) for desktop.
## building (web)
install `wasm-pack` through your package manager or with `cargo install wasm-pack`.
```sh
cd cxgraph-web
wasm-pack build --no-typescript --no-pack --target web
```
## documentation
- [language](docs/language.md)
- [web interface](docs/web.md)

View file

@ -4,9 +4,9 @@ import init, * as cxgraph from "./pkg/cxgraph_web.js";
await init();
let graphView = {
xoff: 0,
yoff: 0,
scale: 3,
xoff: 0.00001,
yoff: 0.00001,
scale: 2.99932736,
res_mult: 1,
varNames: [],
};
@ -217,13 +217,23 @@ let charMap = {
"Psi": "\u03a8",
"Omega": "\u03a9",
"vartheta": "\u03d1",
"0": "\u2080",
"1": "\u2081",
"2": "\u2082",
"3": "\u2083",
"4": "\u2084",
"5": "\u2085",
"6": "\u2086",
"7": "\u2087",
"8": "\u2088",
"9": "\u2089",
};
let specialChars = new RegExp(
`\\\\(${Object.keys(charMap).join("|")})`
);
console.log(specialChars);
source_text.addEventListener("input", () => {
source_text.addEventListener("input", (event) => {
if(event.isComposing) return;
let e = source_text.selectionEnd;
let amnt = 0;
source_text.value = source_text.value.replace(

View file

@ -13,6 +13,7 @@ lalrpop-util = { version = "0.20.0", features = ["lexer", "unicode"] }
num-complex = "0.4"
wgpu = "0.16"
raw-window-handle = "0.5"
unicode-xid = "0.2"
[build-dependencies]
lalrpop = "0.20.0"

View file

@ -5,6 +5,7 @@ use num_complex::Complex64 as Complex;
#[derive(Clone, Copy, Debug)]
pub enum BinaryOp {
Add, Sub, Mul, Div, Pow,
Gt, Lt, Ge, Le, Eq, Ne,
}
#[derive(Clone, Copy, Debug)]
@ -21,6 +22,7 @@ pub enum ExpressionType<'a> {
Unary(UnaryOp),
FnCall(&'a str),
Store(&'a str),
If,
Sum { countvar: &'a str, min: i32, max: i32 },
Prod { countvar: &'a str, min: i32, max: i32 },
Iter { itervar: &'a str, count: i32 },
@ -61,25 +63,31 @@ impl<'a> Expression<'a> {
Self { ty: ExpressionType::Store(name), children: vec![expr] }
}
pub fn new_sum(countvar: &'a str, min: i32, max: i32, body: Vec<Self>) -> Self {
pub fn new_if(cond: Self, t: Self, f: Self) -> Self {
Self {
ty: ExpressionType::If,
children: vec![cond, t, f],
}
}
pub fn new_sum(countvar: &'a str, min: i32, max: i32, body: Self) -> Self {
Self {
ty: ExpressionType::Sum { countvar, min, max },
children: body,
children: vec![body],
}
}
pub fn new_prod(accvar: &'a str, min: i32, max: i32, body: Vec<Self>) -> Self {
pub fn new_prod(accvar: &'a str, min: i32, max: i32, body: Self) -> Self {
Self {
ty: ExpressionType::Prod { countvar: accvar, min, max },
children: body,
children: vec![body],
}
}
pub fn new_iter(itervar: &'a str, count: i32, init: Self, mut body: Vec<Self>) -> Self {
body.push(init);
pub fn new_iter(itervar: &'a str, count: i32, init: Self, body: Self) -> Self {
Self {
ty: ExpressionType::Iter { itervar, count },
children: body,
children: vec![init, body],
}
}
}
@ -99,6 +107,7 @@ fn display_expr(w: &mut impl fmt::Write, expr: &Expression, depth: usize) -> fmt
ExpressionType::Unary(op) => write!(w, "{:indent$}OP {op:?}", "", indent=indent)?,
ExpressionType::FnCall(f) => write!(w, "{:indent$}CALL {f}", "", indent=indent)?,
ExpressionType::Store(n) => write!(w, "{:indent$}STORE {n}", "", indent=indent)?,
ExpressionType::If => write!(w, "{:indent$}IF", "", indent=indent)?,
ExpressionType::Sum { countvar, min, max } => write!(w, "{:indent$}SUM {countvar} {min} {max}", "", indent=indent)?,
ExpressionType::Prod { countvar, min, max } => write!(w, "{:indent$}PROD {countvar} {min} {max}", "", indent=indent)?,
ExpressionType::Iter { itervar, count } => write!(w, "{:indent$}ITER {itervar} {count}", "", indent=indent)?,

View file

@ -10,18 +10,13 @@ thread_local! {
m.insert("recip", ("c_recip", 1));
m.insert("conj", ("c_conj", 1));
m.insert("ifgt", ("c_ifgt", 4));
m.insert("iflt", ("c_iflt", 4));
m.insert("ifge", ("c_ifge", 4));
m.insert("ifle", ("c_ifle", 4));
m.insert("ifeq", ("c_ifeq", 4));
m.insert("ifne", ("c_ifne", 4));
m.insert("ifnan", ("c_ifnan", 3));
m.insert("re", ("c_re", 1));
m.insert("im", ("c_im", 1));
m.insert("signre", ("c_signre", 1));
m.insert("signim", ("c_signim", 1));
m.insert("absre", ("c_absre", 1));
m.insert("absim", ("c_absim", 1));
m.insert("isnan", ("c_isnan", 1));
m.insert("abs_sq", ("c_abs_sq", 1));
m.insert("abs", ("c_abs", 1));
m.insert("arg", ("c_arg", 1));
@ -48,7 +43,6 @@ thread_local! {
m.insert("sinh", ("c_sinh", 1));
m.insert("cosh", ("c_cosh", 1));
m.insert("tanh", ("c_tanh", 1));
m.insert("asin", ("c_asin", 1));
m.insert("acos", ("c_acos", 1));
m.insert("atan", ("c_atan", 1));
@ -64,6 +58,8 @@ thread_local! {
m.insert("log\u{0393}", ("c_loggamma", 1));
m.insert("digamma", ("c_digamma", 1));
m.insert("\u{03C8}", ("c_digamma", 1));
m.insert("lambertw", ("c_lambertw", 1));
m.insert("lambertwbr", ("c_lambertwbr", 2));
m
};

View file

@ -29,6 +29,11 @@ fn format_char(buf: &mut String, c: char) {
match c {
'_' => buf.push_str("u_"),
'\'' => buf.push_str("p_"),
'\u{2080}'..='\u{2089}' => {
buf.push('s');
buf.push((c as u32 - 0x2080 + '0' as u32).try_into().expect("invalid codepoint"));
buf.push('_');
}
c => buf.push(c),
}
}
@ -190,6 +195,12 @@ impl<'w, 'i, W: fmt::Write> Compiler<'w, 'i, W> {
BinaryOp::Mul => writeln!(self.buf, "var {name} = c_mul({a}, {b});")?,
BinaryOp::Div => writeln!(self.buf, "var {name} = c_div({a}, {b});")?,
BinaryOp::Pow => writeln!(self.buf, "var {name} = c_pow({a}, {b});")?,
BinaryOp::Gt => writeln!(self.buf, "var {name} = select(C_ZERO, C_ONE, {a}.x > {b}.x);")?,
BinaryOp::Lt => writeln!(self.buf, "var {name} = select(C_ZERO, C_ONE, {a}.x < {b}.x);")?,
BinaryOp::Ge => writeln!(self.buf, "var {name} = select(C_ZERO, C_ONE, {a}.x >= {b}.x);")?,
BinaryOp::Le => writeln!(self.buf, "var {name} = select(C_ZERO, C_ONE, {a}.x <= {b}.x);")?,
BinaryOp::Eq => writeln!(self.buf, "var {name} = select(C_ZERO, C_ONE, ({a}.x == {b}.x) && ({a}.y == {b}.y));")?,
BinaryOp::Ne => writeln!(self.buf, "var {name} = select(C_ZERO, C_ONE, ({a}.x != {b}.x) || ({a}.y != {b}.y));")?,
}
Ok(name)
@ -226,6 +237,19 @@ impl<'w, 'i, W: fmt::Write> Compiler<'w, 'i, W> {
Ok(name)
},
ExpressionType::If => {
let cond = self.compile_expr(local, &expr.children[0])?;
let result = local.next_tmp();
writeln!(self.buf, "var {result}: vec2f;")?;
writeln!(self.buf, "if {cond}.x > 0.0 {{")?;
let t = self.compile_expr(local, &expr.children[1])?;
writeln!(self.buf, "{result} = {t};")?;
writeln!(self.buf, "}} else {{")?;
let f = self.compile_expr(local, &expr.children[2])?;
writeln!(self.buf, "{result} = {f};")?;
writeln!(self.buf, "}}")?;
Ok(result)
},
ExpressionType::Sum { countvar, min, max }
| ExpressionType::Prod { countvar, min, max } => {
let acc = local.next_tmp();
@ -239,19 +263,17 @@ impl<'w, 'i, W: fmt::Write> Compiler<'w, 'i, W> {
writeln!(self.buf, "var {} = vec2f(f32({ivar}), 0.0);", format_local(countvar))?;
let mut loop_local = local.clone();
loop_local.local_vars.insert(countvar);
let mut last = String::new();
for child in &expr.children {
last = self.compile_expr(&mut loop_local, child)?;
}
let body = self.compile_expr(&mut loop_local, &expr.children[0])?;
if matches!(expr.ty, ExpressionType::Sum { .. }) {
writeln!(self.buf, "{acc} = {acc} + {last};\n}}")?;
writeln!(self.buf, "{acc} = {acc} + {body};")?;
} else {
writeln!(self.buf, "{acc} = c_mul({acc}, {last});\n}}")?;
writeln!(self.buf, "{acc} = c_mul({acc}, {body});")?;
}
writeln!(self.buf, "}}")?;
Ok(acc)
},
ExpressionType::Iter { itervar, count } => {
let init = expr.children.last().unwrap();
let init = &expr.children[0];
let itervar_fmt = format_local(itervar);
let v = self.compile_expr(local, init)?;
writeln!(self.buf, "var {itervar_fmt} = {v};")?;
@ -259,11 +281,9 @@ impl<'w, 'i, W: fmt::Write> Compiler<'w, 'i, W> {
writeln!(self.buf, "for(var {ivar}: i32 = 0; {ivar} < {count}; {ivar}++) {{")?;
let mut loop_local = local.clone();
loop_local.local_vars.insert(itervar);
let mut last = String::new();
for child in &expr.children[..expr.children.len() - 1] {
last = self.compile_expr(&mut loop_local, child)?;
}
writeln!(self.buf, "{itervar_fmt} = {last};\n}}")?;
let body = self.compile_expr(&mut loop_local, &expr.children[1])?;
writeln!(self.buf, "{itervar_fmt} = {body};")?;
writeln!(self.buf, "}}")?;
Ok(itervar_fmt)
}
}

View file

@ -21,10 +21,17 @@ extern {
"->" => Token::Arrow,
"=" => Token::Equal,
":" => Token::Colon,
">" => Token::Greater,
"<" => Token::Less,
">=" => Token::GreaterEqual,
"<=" => Token::LessEqual,
"==" => Token::EqualEqual,
"!=" => Token::BangEqual,
"\n" => Token::Newline,
"sum" => Token::Sum,
"prod" => Token::Prod,
"iter" => Token::Iter,
"if" => Token::If,
Float => Token::Float(<f64>),
Int => Token::Int(<i32>),
Name => Token::Name(<&'input str>),
@ -60,7 +67,21 @@ Exprs: Vec<Expression<'input>> = {
Expr: Expression<'input> = Store;
Store: Expression<'input> = {
<a:Store> "->" <n:Name> => Expression::new_store(a, n),
<a:Equality> "->" <n:Name> => Expression::new_store(a, n),
Equality,
}
Equality: Expression<'input> = {
<a:Compare> "==" <b:Compare> => Expression::new_binary(BinaryOp::Eq, a, b),
<a:Compare> "!=" <b:Compare> => Expression::new_binary(BinaryOp::Ne, a, b),
Compare,
}
Compare: Expression<'input> = {
<a:Sum> ">" <b:Sum> => Expression::new_binary(BinaryOp::Gt, a, b),
<a:Sum> "<" <b:Sum> => Expression::new_binary(BinaryOp::Lt, a, b),
<a:Sum> ">=" <b:Sum> => Expression::new_binary(BinaryOp::Ge, a, b),
<a:Sum> "<=" <b:Sum> => Expression::new_binary(BinaryOp::Le, a, b),
Sum,
}
@ -105,17 +126,23 @@ PreJuxtapose: Expression<'input> = {
"(" <Expr> ")",
}
Block: Expression<'input> = {
"{" <exs:Exprs> "}" => Expression::new_block(exs),
}
Item: Expression<'input> = {
Number,
<n:Name> => Expression::new_name(n),
"(" <Expr> ")",
"{" <exs:Exprs> "}" => Expression::new_block(exs),
"sum" "(" <name:Name> ":" <min:Int> "," <max:Int> ")" "{" <exs:Exprs> "}"
=> Expression::new_sum(name, min, max, exs),
"prod" "(" <name:Name> ":" <min:Int> "," <max:Int> ")" "{" <exs:Exprs> "}"
=> Expression::new_prod(name, min, max, exs),
"iter" "(" <count:Int> "," <name:Name> ":" <init:Expr> ")" "{" <exs:Exprs> "}"
=> Expression::new_iter(name, count, init, exs),
Block,
"sum" "(" <name:Name> ":" <min:Int> "," <max:Int> ")" <body:Block>
=> Expression::new_sum(name, min, max, body),
"prod" "(" <name:Name> ":" <min:Int> "," <max:Int> ")" <body:Block>
=> Expression::new_prod(name, min, max, body),
"iter" "(" <count:Int> "," <name:Name> ":" <init:Expr> ")" <body:Block>
=> Expression::new_iter(name, count, init, body),
"if" "(" <cond:Expr> ")" <t:Block> <f:Block>
=> Expression::new_if(cond, t, f),
}
Number: Expression<'input> = {

View file

@ -1,15 +1,19 @@
use std::{str::CharIndices, iter::Peekable, fmt};
use unicode_xid::UnicodeXID;
#[derive(Clone, Copy, Debug)]
pub enum Token<'i> {
Float(f64),
Int(i32),
Name(&'i str),
Sum, Prod, Iter,
Sum, Prod, Iter, If,
LParen, RParen,
LBrace, RBrace,
Plus, Minus, Star, Slash, Caret,
Greater, Less, GreaterEqual, LessEqual,
EqualEqual, BangEqual,
Comma, Arrow, Equal, Colon,
Newline,
}
@ -23,6 +27,7 @@ impl<'i> fmt::Display for Token<'i> {
Token::Sum => f.write_str("sum"),
Token::Prod => f.write_str("prod"),
Token::Iter => f.write_str("iter"),
Token::If => f.write_str("if"),
Token::LParen => f.write_str("("),
Token::RParen => f.write_str(")"),
Token::LBrace => f.write_str("{"),
@ -36,6 +41,12 @@ impl<'i> fmt::Display for Token<'i> {
Token::Arrow => f.write_str("->"),
Token::Equal => f.write_str("="),
Token::Colon => f.write_str(":"),
Token::Greater => f.write_str(">"),
Token::Less => f.write_str("<"),
Token::GreaterEqual => f.write_str(">="),
Token::LessEqual => f.write_str("<="),
Token::EqualEqual => f.write_str("=="),
Token::BangEqual => f.write_str("!="),
Token::Newline => f.write_str("newline")
}
}
@ -44,12 +55,14 @@ impl<'i> fmt::Display for Token<'i> {
#[derive(Clone, Copy, Debug)]
pub enum LexerError {
Unexpected(usize, char),
UnexpectedEof,
InvalidNumber(usize, usize),
}
impl fmt::Display for LexerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LexerError::UnexpectedEof => write!(f, "Unexpected EOF during lexing"),
LexerError::Unexpected(i, c) => write!(f, "Unexpected character {c:?} at {i}"),
LexerError::InvalidNumber(i, j) => write!(f, "Invalid number at {i}:{j}"),
}
@ -61,15 +74,15 @@ pub type Spanned<T, L, E> = Result<(L, T, L), E>;
pub struct Lexer<'i> {
src: &'i str,
chars: Peekable<CharIndices<'i>>,
bracket_depth: usize,
bracket_depth: isize,
}
fn is_ident_begin(c: char) -> bool {
c.is_alphabetic()
c.is_xid_start()
}
fn is_ident_middle(c: char) -> bool {
c.is_alphanumeric() || c == '_' || c == '\''
c.is_xid_continue() || matches!(c, '\'' | '\u{2080}'..='\u{2089}')
}
impl<'i> Lexer<'i> {
@ -118,6 +131,7 @@ impl<'i> Lexer<'i> {
"sum" => Ok((i, Token::Sum, j)),
"prod" => Ok((i, Token::Prod, j)),
"iter" => Ok((i, Token::Iter, j)),
"if" => Ok((i, Token::If, j)),
_ => Ok((i, Token::Name(s), j)),
}
}
@ -141,25 +155,48 @@ impl<'i> Lexer<'i> {
}
self.next_token()?
}
(i, '\n') => Ok((i, Token::Newline, i + 1)),
(i, '(') => { self.bracket_depth += 1; Ok((i, Token::LParen, i + 1)) },
(i, ')') => { self.bracket_depth -= 1; Ok((i, Token::RParen, i + 1)) },
(i, '{') => { self.bracket_depth += 1; Ok((i, Token::LBrace, i + 1)) },
(i, '}') => { self.bracket_depth -= 1; Ok((i, Token::RBrace, i + 1)) },
(i, '+') => Ok((i, Token::Plus, i + 1)),
(i, '-') => match self.chars.peek() {
Some((_, '>')) => {
self.chars.next();
Ok((i, Token::Arrow, i + 2))
},
(i, '-') => match self.chars.next_if(|(_, c)| *c == '>') {
Some(_) => Ok((i, Token::Arrow, i + 2)),
_ => Ok((i, Token::Minus, i + 1)),
}
},
(i, '*') => Ok((i, Token::Star, i + 1)),
(i, '\u{22C5}') => Ok((i, Token::Star, i + '\u{22C5}'.len_utf8())),
(i, '/') => Ok((i, Token::Slash, i + 1)),
(i, '^') => Ok((i, Token::Caret, i + 1)),
(i, '<') => match self.chars.next_if(|(_, c)| *c == '=') {
Some(_) => Ok((i, Token::LessEqual, i + 2)),
_ => Ok((i, Token::Less, i + 1)),
},
(i, '\u{2264}') => Ok((i, Token::LessEqual, i + '\u{2264}'.len_utf8())),
(i, '>') => match self.chars.next_if(|(_, c)| *c == '=') {
Some(_) => Ok((i, Token::GreaterEqual, i + 2)),
_ => Ok((i, Token::Greater, i + 1)),
},
(i, '\u{2265}') => Ok((i, Token::GreaterEqual, i + '\u{2265}'.len_utf8())),
(i, '=') => match self.chars.next_if(|(_, c)| *c == '=') {
Some(_) => Ok((i, Token::EqualEqual, i + 2)),
_ => Ok((i, Token::Equal, i + 1)),
}
(i, '!') => match self.chars.next() {
Some((_, '=')) => Ok((i, Token::BangEqual, i + 2)),
Some((_, c)) => Err(LexerError::Unexpected(i+1, c)),
None => Err(LexerError::UnexpectedEof),
}
(i, '\u{2260}') => Ok((i, Token::BangEqual, i + '\u{2260}'.len_utf8())),
(i, ',') => Ok((i, Token::Comma, i + 1)),
(i, '=') => Ok((i, Token::Equal, i + 1)),
(i, ':') => Ok((i, Token::Colon, i + 1)),
(i, '\n') => Ok((i, Token::Newline, i + 1)),
(i, '0'..='9') => self.next_number(i, false),
(i, '.') => self.next_number(i, true),
(i, c) if is_ident_begin(c) => self.next_word(i, i + c.len_utf8()),

View file

@ -15,6 +15,26 @@ struct Uniforms {
@group(0) @binding(1) var<uniform> uniforms: Uniforms;
/////////////////
// constants //
/////////////////
const TAU = 6.283185307179586;
const E = 2.718281828459045;
const RECIP_SQRT2 = 0.7071067811865475;
const LOG_TAU = 1.8378770664093453;
const LOG_2 = 0.6931471805599453;
const RECIP_SQRT29 = 0.18569533817705186;
const C_TAU = vec2f(TAU, 0.0);
const C_E = vec2f(E, 0.0);
const C_I = vec2f(0.0, 1.0);
const C_EMGAMMA = vec2f(0.5772156649015329, 0.0);
const C_PHI = vec2f(1.618033988749895, 0.0);
const C_ZERO = vec2f(0.0, 0.0);
const C_ONE = vec2f(1.0, 0.0);
///////////////
// utility //
///////////////
@ -43,23 +63,6 @@ fn vlength(v: vec2f) -> f32 {
return max(max(a, b), max(c, d));
}
/////////////////
// constants //
/////////////////
const TAU = 6.283185307179586;
const E = 2.718281828459045;
const RECIP_SQRT2 = 0.7071067811865475;
const LOG_TAU = 1.8378770664093453;
const LOG_2 = 0.6931471805599453;
const RECIP_SQRT29 = 0.18569533817705186;
const C_TAU = vec2f(TAU, 0.0);
const C_E = vec2f(E, 0.0);
const C_I = vec2f(0.0, 1.0);
const C_EMGAMMA = vec2f(0.5772156649015329, 0.0);
const C_PHI = vec2f(1.618033988749895, 0.0);
/////////////////////////
// complex functions //
/////////////////////////
@ -80,32 +83,16 @@ fn c_signim(z: vec2f) -> vec2f {
return vec2(sign(z.y), 0.0);
}
fn c_ifgt(p: vec2f, q: vec2f, z: vec2f, w: vec2f) -> vec2f {
return select(w, z, p.x > q.x);
fn c_absre(z: vec2f) -> vec2f {
return vec2(abs(z.x), 0.0);
}
fn c_iflt(p: vec2f, q: vec2f, z: vec2f, w: vec2f) -> vec2f {
return select(w, z, p.x < q.x);
fn c_absim(z: vec2f) -> vec2f {
return vec2(abs(z.y), 0.0);
}
fn c_ifge(p: vec2f, q: vec2f, z: vec2f, w: vec2f) -> vec2f {
return select(w, z, p.x >= q.x);
}
fn c_ifle(p: vec2f, q: vec2f, z: vec2f, w: vec2f) -> vec2f {
return select(w, z, p.x <= q.x);
}
fn c_ifeq(p: vec2f, q: vec2f, z: vec2f, w: vec2f) -> vec2f {
return select(w, z, p.x == q.x);
}
fn c_ifne(p: vec2f, q: vec2f, z: vec2f, w: vec2f) -> vec2f {
return select(w, z, p.x != q.x);
}
fn c_ifnan(p: vec2f, z: vec2f, w: vec2f) -> vec2f {
return select(w, z, p.x != p.x && p.y != p.y);
fn c_isnan(z: vec2f) -> vec2f {
return select(C_ZERO, C_ONE, z.x != z.x || z.y != z.y);
}
fn c_conj(z: vec2f) -> vec2f {
@ -128,9 +115,12 @@ fn c_arg(z: vec2f) -> vec2f {
}
fn c_argbr(z: vec2f, br: vec2f) -> vec2f {
if z.x < 0.0 && z.y == 0.0 {
return vec2(TAU/2.0 + floor(br.x/TAU) * TAU, 0.0);
}
let r = vec2(cos(-br.x), sin(-br.x));
let zr = c_mul(z, r);
return c_arg(zr) + vec2(br.x, 0.0);
return vec2(br.x + atan2(zr.y, zr.x), 0.0);
}
fn c_add(u: vec2f, v: vec2f) -> vec2f {
@ -222,15 +212,18 @@ fn c_tanh(z: vec2f) -> vec2f {
}
fn c_asin(z: vec2f) -> vec2f {
let m = select(-1.0, 1.0, z.y < 0.0 || (z.y == 0.0 && z.x > 0.0));
let u = c_sqrt(vec2(1.0, 0.0) - c_mul(z, z));
let v = c_log(u + vec2(-z.y, z.x));
return vec2(v.y, -v.x);
let v = c_log(u + m*vec2(-z.y, z.x));
return m*vec2(v.y, -v.x);
}
// TODO fix
fn c_acos(z: vec2f) -> vec2f {
let m = select(-1.0, 1.0, z.y < 0.0 || (z.y == 0.0 && z.x > 0.0));
let u = c_sqrt(vec2(1.0, 0.0) - c_mul(z, z));
let v = c_log(u + vec2(-z.y, z.x));
return vec2(TAU*0.25 - v.y, v.x);
let v = c_log(u + m*vec2(-z.y, z.x));
return C_TAU/4.0 + m*vec2(-v.y, v.x);
}
fn c_atan(z: vec2f) -> vec2f {
@ -241,19 +234,19 @@ fn c_atan(z: vec2f) -> vec2f {
}
fn c_asinh(z: vec2f) -> vec2f {
let m = select(-1.0, 1.0, z.x > 0.0 || (z.x == 0.0 && z.y > 0.0));
let u = c_sqrt(vec2(1.0, 0.0) + c_mul(z, z));
return c_log(u + z);
return c_log(u + z*m) * m;
}
fn c_acosh(z: vec2f) -> vec2f {
let u = c_sqrt(vec2(-1.0, 0.0) + c_mul(z, z));
let b = select(0.0, TAU, z.x < 0.0 || (z.x == 0.0 && z.y < 0.0));
let u = c_sqrtbr(vec2(-1.0, 0.0) + c_mul(z, z), vec2(b, 0.0));
return c_log(u + z);
}
fn c_atanh(z: vec2f) -> vec2f {
let u = vec2(1.0, 0.0) + z;
let v = vec2(1.0, 0.0) - z;
return 0.5 * c_log(c_div(u, v));
return 0.5 * (c_log(C_ONE + z) - c_log(C_ONE - z));
}
// log gamma //
@ -321,6 +314,56 @@ fn c_digamma_inner2(z: vec2f) -> vec2f {
return w - l;
}
// lambert w //
fn c_lambertw(z: vec2f) -> vec2f {
var w = c_lambertw_init(z, 0.0);
return c_lambertw_iter(z, w);
}
fn c_lambertwbr(z: vec2f, br: vec2f) -> vec2f {
// branch number
let br_n = br.x / TAU;
// if -TAU/2 < br < TAU/2 then use -1/e as the branch point,
// otherwise use 0
let branch_point = select(C_ZERO, vec2(-1.0/E, 0.0), abs(br.x) < TAU / 2.0);
let arg = c_arg(z - branch_point).x;
// if we're past the branch cut then take branch ceil(br_n),
// otherwise take branch floor(br_n)
let take_ceil = (br_n - floor(br_n) >= arg / TAU + 0.5);
var init_br = select(floor(br_n), ceil(br_n), take_ceil);
var w = c_lambertw_init(z, init_br);
// newton's method
return c_lambertw_iter(z, w);
}
fn c_lambertw_iter(z: vec2f, init: vec2f) -> vec2f {
var w = init;
for(var i = 0; i < 5; i++) {
w = c_div(c_mul(w, w) + c_mul(z, c_exp(-w)), w + C_ONE);
}
return w;
}
fn c_lambertw_init(z: vec2f, br: f32) -> vec2f {
let b = vec2(TAU * br, 0.0);
let oz = z + vec2(1.25, 0.0);
if br == 0.0 && dot(z, z) <= 50.0
|| br == 1.0 && z.y < 0.0 && dot(oz, oz) < 1.0
|| br == -1.0 && z.y > 0.0 && dot(oz, oz) < 1.0 {
// accurate near 0, near principle branch
let w = C_ONE + c_sqrtbr(C_ONE + E*z, b);
return c_div(c_mul(E*z, c_log(w)), w + E*z);
} else {
// accurate asymptotically
let logz = c_logbr(z, b);
return logz - c_log(logz);
}
}
/////////////////
// rendering //
/////////////////