use std::fmt;
use std::fmt::{Debug, Display, Formatter};

use std::str::FromStr;

use winnow::prelude::*;
use winnow::{
    ascii::{digit1 as digits, multispace0 as multispaces},
    combinator::alt,
    combinator::dispatch,
    combinator::eof,
    combinator::fail,
    combinator::peek,
    combinator::repeat,
    combinator::{delimited, preceded, terminated},
    token::any,
    token::one_of,
};
18
/// Abstract syntax tree for an integer arithmetic expression.
///
/// `PartialEq`/`Eq` are derived so trees can be compared directly (e.g. in tests)
/// instead of through their `Debug` strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Expr {
    /// An integer literal.
    Value(i64),
    /// `lhs + rhs`
    Add(Box<Expr>, Box<Expr>),
    /// `lhs - rhs`
    Sub(Box<Expr>, Box<Expr>),
    /// `lhs * rhs`
    Mul(Box<Expr>, Box<Expr>),
    /// `lhs / rhs`
    Div(Box<Expr>, Box<Expr>),
    /// A parenthesized sub-expression, kept as its own node so the source grouping
    /// survives in the tree.
    Paren(Box<Expr>),
}
28
29 impl Expr {
eval(&self) -> i6430 pub fn eval(&self) -> i64 {
31 match self {
32 Self::Value(v) => *v,
33 Self::Add(lhs, rhs) => lhs.eval() + rhs.eval(),
34 Self::Sub(lhs, rhs) => lhs.eval() - rhs.eval(),
35 Self::Mul(lhs, rhs) => lhs.eval() * rhs.eval(),
36 Self::Div(lhs, rhs) => lhs.eval() / rhs.eval(),
37 Self::Paren(expr) => expr.eval(),
38 }
39 }
40 }
41
42 impl Display for Expr {
fmt(&self, format: &mut Formatter<'_>) -> fmt::Result43 fn fmt(&self, format: &mut Formatter<'_>) -> fmt::Result {
44 use Expr::{Add, Div, Mul, Paren, Sub, Value};
45 match *self {
46 Value(val) => write!(format, "{}", val),
47 Add(ref left, ref right) => write!(format, "{} + {}", left, right),
48 Sub(ref left, ref right) => write!(format, "{} - {}", left, right),
49 Mul(ref left, ref right) => write!(format, "{} * {}", left, right),
50 Div(ref left, ref right) => write!(format, "{} / {}", left, right),
51 Paren(ref expr) => write!(format, "({})", expr),
52 }
53 }
54 }
55
56 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
57 pub enum Token {
58 Value(i64),
59 Oper(Oper),
60 OpenParen,
61 CloseParen,
62 }
63
/// A binary arithmetic operator recognized by the lexer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Oper {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
}
71
72 impl winnow::stream::ContainsToken<Token> for Token {
73 #[inline(always)]
contains_token(&self, token: Token) -> bool74 fn contains_token(&self, token: Token) -> bool {
75 *self == token
76 }
77 }
78
79 impl winnow::stream::ContainsToken<Token> for &'_ [Token] {
80 #[inline]
contains_token(&self, token: Token) -> bool81 fn contains_token(&self, token: Token) -> bool {
82 self.iter().any(|t| *t == token)
83 }
84 }
85
86 impl<const LEN: usize> winnow::stream::ContainsToken<Token> for &'_ [Token; LEN] {
87 #[inline]
contains_token(&self, token: Token) -> bool88 fn contains_token(&self, token: Token) -> bool {
89 self.iter().any(|t| *t == token)
90 }
91 }
92
93 impl<const LEN: usize> winnow::stream::ContainsToken<Token> for [Token; LEN] {
94 #[inline]
contains_token(&self, token: Token) -> bool95 fn contains_token(&self, token: Token) -> bool {
96 self.iter().any(|t| *t == token)
97 }
98 }
99
100 #[allow(dead_code)]
expr2(i: &mut &str) -> PResult<Expr>101 pub fn expr2(i: &mut &str) -> PResult<Expr> {
102 let tokens = lex.parse_next(i)?;
103 expr.parse_next(&mut tokens.as_slice())
104 }
105
lex(i: &mut &str) -> PResult<Vec<Token>>106 pub fn lex(i: &mut &str) -> PResult<Vec<Token>> {
107 preceded(multispaces, repeat(1.., terminated(token, multispaces))).parse_next(i)
108 }
109
token(i: &mut &str) -> PResult<Token>110 fn token(i: &mut &str) -> PResult<Token> {
111 dispatch! {peek(any);
112 '0'..='9' => digits.try_map(FromStr::from_str).map(Token::Value),
113 '(' => '('.value(Token::OpenParen),
114 ')' => ')'.value(Token::CloseParen),
115 '+' => '+'.value(Token::Oper(Oper::Add)),
116 '-' => '-'.value(Token::Oper(Oper::Sub)),
117 '*' => '*'.value(Token::Oper(Oper::Mul)),
118 '/' => '/'.value(Token::Oper(Oper::Div)),
119 _ => fail,
120 }
121 .parse_next(i)
122 }
123
expr(i: &mut &[Token]) -> PResult<Expr>124 pub fn expr(i: &mut &[Token]) -> PResult<Expr> {
125 let init = term.parse_next(i)?;
126
127 repeat(
128 0..,
129 (
130 one_of([Token::Oper(Oper::Add), Token::Oper(Oper::Sub)]),
131 term,
132 ),
133 )
134 .fold(
135 move || init.clone(),
136 |acc, (op, val): (Token, Expr)| {
137 if op == Token::Oper(Oper::Add) {
138 Expr::Add(Box::new(acc), Box::new(val))
139 } else {
140 Expr::Sub(Box::new(acc), Box::new(val))
141 }
142 },
143 )
144 .parse_next(i)
145 }
146
term(i: &mut &[Token]) -> PResult<Expr>147 fn term(i: &mut &[Token]) -> PResult<Expr> {
148 let init = factor.parse_next(i)?;
149
150 repeat(
151 0..,
152 (
153 one_of([Token::Oper(Oper::Mul), Token::Oper(Oper::Div)]),
154 factor,
155 ),
156 )
157 .fold(
158 move || init.clone(),
159 |acc, (op, val): (Token, Expr)| {
160 if op == Token::Oper(Oper::Mul) {
161 Expr::Mul(Box::new(acc), Box::new(val))
162 } else {
163 Expr::Div(Box::new(acc), Box::new(val))
164 }
165 },
166 )
167 .parse_next(i)
168 }
169
factor(i: &mut &[Token]) -> PResult<Expr>170 fn factor(i: &mut &[Token]) -> PResult<Expr> {
171 alt((
172 one_of(|t| matches!(t, Token::Value(_))).map(|t| match t {
173 Token::Value(v) => Expr::Value(v),
174 _ => unreachable!(),
175 }),
176 parens,
177 ))
178 .parse_next(i)
179 }
180
parens(i: &mut &[Token]) -> PResult<Expr>181 fn parens(i: &mut &[Token]) -> PResult<Expr> {
182 delimited(one_of(Token::OpenParen), expr, one_of(Token::CloseParen))
183 .map(|e| Expr::Paren(Box::new(e)))
184 .parse_next(i)
185 }
186
/// `lex` consumes all of each input and yields the expected token list.
#[test]
fn lex_test() {
    let cases = [
        ("3", r#"("", [Value(3)])"#),
        (" 24 ", r#"("", [Value(24)])"#),
        (
            " 12 *2 / 3",
            r#"("", [Value(12), Oper(Mul), Value(2), Oper(Div), Value(3)])"#,
        ),
        (
            " 2*2 / ( 5 - 1) + 3",
            r#"("", [Value(2), Oper(Mul), Value(2), Oper(Div), OpenParen, Value(5), Oper(Sub), Value(1), CloseParen, Oper(Add), Value(3)])"#,
        ),
    ];
    for (input, expected) in cases {
        let actual = lex.parse_peek(input).map(|e| format!("{e:?}"));
        assert_eq!(actual, Ok(String::from(expected)), "input: {input:?}");
    }
}
209
/// A lone number, with surrounding whitespace in any position, parses as a `Value`.
#[test]
fn factor_test() {
    let cases = [
        ("3", "Value(3)"),
        (" 12", "Value(12)"),
        ("537 ", "Value(537)"),
        (" 24 ", "Value(24)"),
    ];
    for (input, expected) in cases {
        let tokens = lex.parse(input).unwrap();
        assert_eq!(
            factor.map(|e| format!("{e:?}")).parse(&tokens),
            Ok(String::from(expected)),
            "input: {input:?}"
        );
    }
}
232
/// `*` and `/` chains parse left-associatively, in either operator order.
#[test]
fn term_test() {
    let cases = [
        (" 12 *2 / 3", "Div(Mul(Value(12), Value(2)), Value(3))"),
        // Division before multiplication: still grouped from the left.
        (" 12 / 2 * 3", "Mul(Div(Value(12), Value(2)), Value(3))"),
        (
            " 2* 3 *2 *2 / 3",
            "Div(Mul(Mul(Mul(Value(2), Value(3)), Value(2)), Value(2)), Value(3))",
        ),
        (" 48 / 3/2", "Div(Div(Value(48), Value(3)), Value(2))"),
    ];
    for (input, expected) in cases {
        let tokens = lex.parse(input).unwrap();
        assert_eq!(
            term.map(|e| format!("{e:?}")).parse(&tokens),
            Ok(String::from(expected)),
            "input: {input:?}"
        );
    }
}
257
/// `+`/`-` chains are left-associative and bind looser than `*`.
#[test]
fn expr_test() {
    let cases = [
        (" 1 + 2 ", "Add(Value(1), Value(2))"),
        (
            " 12 + 6 - 4+ 3",
            "Add(Sub(Add(Value(12), Value(6)), Value(4)), Value(3))",
        ),
        (
            " 1 + 2*3 + 4",
            "Add(Add(Value(1), Mul(Value(2), Value(3))), Value(4))",
        ),
    ];
    for (input, expected) in cases {
        let tokens = lex.parse(input).unwrap();
        assert_eq!(
            expr.map(|e| format!("{e:?}")).parse(&tokens),
            Ok(String::from(expected)),
            "input: {input:?}"
        );
    }
}
279
/// Parenthesized groups become `Paren` nodes and override operator precedence.
#[test]
fn parens_test() {
    let cases = [
        (" ( 2 )", "Paren(Value(2))"),
        (
            " 2* ( 3 + 4 ) ",
            "Mul(Value(2), Paren(Add(Value(3), Value(4))))",
        ),
        (
            " 2*2 / ( 5 - 1) + 3",
            "Add(Div(Mul(Value(2), Value(2)), Paren(Sub(Value(5), Value(1)))), Value(3))",
        ),
    ];
    for (input, expected) in cases {
        let tokens = lex.parse(input).unwrap();
        assert_eq!(
            expr.map(|e| format!("{e:?}")).parse(&tokens),
            Ok(String::from(expected)),
            "input: {input:?}"
        );
    }
}
301