1 use std::fmt;
2 use std::fmt::{Debug, Display, Formatter};
3 
4 use std::str::FromStr;
5 
6 use winnow::prelude::*;
7 use winnow::{
8     ascii::{digit1 as digits, multispace0 as multispaces},
9     combinator::alt,
10     combinator::dispatch,
11     combinator::fail,
12     combinator::peek,
13     combinator::repeat,
14     combinator::{delimited, preceded, terminated},
15     token::any,
16     token::one_of,
17 };
18 
/// Abstract syntax tree of an integer arithmetic expression.
///
/// Every binary variant stores its left operand first and its right operand second.
#[derive(Debug, Clone)]
pub enum Expr {
    /// An integer literal.
    Value(i64),
    /// Left operand plus right operand.
    Add(Box<Expr>, Box<Expr>),
    /// Left operand minus right operand.
    Sub(Box<Expr>, Box<Expr>),
    /// Left operand times right operand.
    Mul(Box<Expr>, Box<Expr>),
    /// Left operand divided by right operand (integer division).
    Div(Box<Expr>, Box<Expr>),
    /// A sub-expression that was wrapped in parentheses.
    Paren(Box<Expr>),
}

impl Expr {
    /// Evaluates the expression recursively using the built-in `i64` operators.
    ///
    /// # Panics
    ///
    /// Panics when a divisor evaluates to zero and when `i64::MIN` is divided by `-1`.
    /// Overflow in `+`, `-` and `*` follows the build's integer-overflow setting
    /// (panic when overflow checks are enabled, wrap-around otherwise).
    /// Use [`Expr::checked_eval`] when the expression comes from untrusted input.
    pub fn eval(&self) -> i64 {
        match self {
            Self::Value(v) => *v,
            Self::Add(lhs, rhs) => lhs.eval() + rhs.eval(),
            Self::Sub(lhs, rhs) => lhs.eval() - rhs.eval(),
            Self::Mul(lhs, rhs) => lhs.eval() * rhs.eval(),
            Self::Div(lhs, rhs) => lhs.eval() / rhs.eval(),
            Self::Paren(expr) => expr.eval(),
        }
    }

    /// Evaluates the expression without ever panicking on arithmetic.
    ///
    /// Returns `None` as soon as any step overflows `i64` or divides by zero;
    /// otherwise returns the same value [`Expr::eval`] would produce.
    pub fn checked_eval(&self) -> Option<i64> {
        match self {
            Self::Value(v) => Some(*v),
            Self::Add(lhs, rhs) => lhs.checked_eval()?.checked_add(rhs.checked_eval()?),
            Self::Sub(lhs, rhs) => lhs.checked_eval()?.checked_sub(rhs.checked_eval()?),
            Self::Mul(lhs, rhs) => lhs.checked_eval()?.checked_mul(rhs.checked_eval()?),
            // `checked_div` covers both a zero divisor and `i64::MIN / -1`.
            Self::Div(lhs, rhs) => lhs.checked_eval()?.checked_div(rhs.checked_eval()?),
            Self::Paren(inner) => inner.checked_eval(),
        }
    }
}
41 
42 impl Display for Expr {
fmt(&self, format: &mut Formatter<'_>) -> fmt::Result43     fn fmt(&self, format: &mut Formatter<'_>) -> fmt::Result {
44         use Expr::{Add, Div, Mul, Paren, Sub, Value};
45         match *self {
46             Value(val) => write!(format, "{}", val),
47             Add(ref left, ref right) => write!(format, "{} + {}", left, right),
48             Sub(ref left, ref right) => write!(format, "{} - {}", left, right),
49             Mul(ref left, ref right) => write!(format, "{} * {}", left, right),
50             Div(ref left, ref right) => write!(format, "{} / {}", left, right),
51             Paren(ref expr) => write!(format, "({})", expr),
52         }
53     }
54 }
55 
/// A lexical token produced by `lex` and consumed by the token-level parsers.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Token {
    /// An integer literal.
    Value(i64),
    /// A binary operator.
    Oper(Oper),
    /// The `(` character.
    OpenParen,
    /// The `)` character.
    CloseParen,
}
63 
/// The binary operators recognised by the lexer.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Oper {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
}
71 
72 impl winnow::stream::ContainsToken<Token> for Token {
73     #[inline(always)]
contains_token(&self, token: Token) -> bool74     fn contains_token(&self, token: Token) -> bool {
75         *self == token
76     }
77 }
78 
79 impl winnow::stream::ContainsToken<Token> for &'_ [Token] {
80     #[inline]
contains_token(&self, token: Token) -> bool81     fn contains_token(&self, token: Token) -> bool {
82         self.iter().any(|t| *t == token)
83     }
84 }
85 
86 impl<const LEN: usize> winnow::stream::ContainsToken<Token> for &'_ [Token; LEN] {
87     #[inline]
contains_token(&self, token: Token) -> bool88     fn contains_token(&self, token: Token) -> bool {
89         self.iter().any(|t| *t == token)
90     }
91 }
92 
93 impl<const LEN: usize> winnow::stream::ContainsToken<Token> for [Token; LEN] {
94     #[inline]
contains_token(&self, token: Token) -> bool95     fn contains_token(&self, token: Token) -> bool {
96         self.iter().any(|t| *t == token)
97     }
98 }
99 
100 #[allow(dead_code)]
expr2(i: &mut &str) -> PResult<Expr>101 pub fn expr2(i: &mut &str) -> PResult<Expr> {
102     let tokens = lex.parse_next(i)?;
103     expr.parse_next(&mut tokens.as_slice())
104 }
105 
lex(i: &mut &str) -> PResult<Vec<Token>>106 pub fn lex(i: &mut &str) -> PResult<Vec<Token>> {
107     preceded(multispaces, repeat(1.., terminated(token, multispaces))).parse_next(i)
108 }
109 
token(i: &mut &str) -> PResult<Token>110 fn token(i: &mut &str) -> PResult<Token> {
111     dispatch! {peek(any);
112         '0'..='9' => digits.try_map(FromStr::from_str).map(Token::Value),
113         '(' => '('.value(Token::OpenParen),
114         ')' => ')'.value(Token::CloseParen),
115         '+' => '+'.value(Token::Oper(Oper::Add)),
116         '-' => '-'.value(Token::Oper(Oper::Sub)),
117         '*' => '*'.value(Token::Oper(Oper::Mul)),
118         '/' => '/'.value(Token::Oper(Oper::Div)),
119         _ => fail,
120     }
121     .parse_next(i)
122 }
123 
expr(i: &mut &[Token]) -> PResult<Expr>124 pub fn expr(i: &mut &[Token]) -> PResult<Expr> {
125     let init = term.parse_next(i)?;
126 
127     repeat(
128         0..,
129         (
130             one_of([Token::Oper(Oper::Add), Token::Oper(Oper::Sub)]),
131             term,
132         ),
133     )
134     .fold(
135         move || init.clone(),
136         |acc, (op, val): (Token, Expr)| {
137             if op == Token::Oper(Oper::Add) {
138                 Expr::Add(Box::new(acc), Box::new(val))
139             } else {
140                 Expr::Sub(Box::new(acc), Box::new(val))
141             }
142         },
143     )
144     .parse_next(i)
145 }
146 
term(i: &mut &[Token]) -> PResult<Expr>147 fn term(i: &mut &[Token]) -> PResult<Expr> {
148     let init = factor.parse_next(i)?;
149 
150     repeat(
151         0..,
152         (
153             one_of([Token::Oper(Oper::Mul), Token::Oper(Oper::Div)]),
154             factor,
155         ),
156     )
157     .fold(
158         move || init.clone(),
159         |acc, (op, val): (Token, Expr)| {
160             if op == Token::Oper(Oper::Mul) {
161                 Expr::Mul(Box::new(acc), Box::new(val))
162             } else {
163                 Expr::Div(Box::new(acc), Box::new(val))
164             }
165         },
166     )
167     .parse_next(i)
168 }
169 
factor(i: &mut &[Token]) -> PResult<Expr>170 fn factor(i: &mut &[Token]) -> PResult<Expr> {
171     alt((
172         one_of(|t| matches!(t, Token::Value(_))).map(|t| match t {
173             Token::Value(v) => Expr::Value(v),
174             _ => unreachable!(),
175         }),
176         parens,
177     ))
178     .parse_next(i)
179 }
180 
parens(i: &mut &[Token]) -> PResult<Expr>181 fn parens(i: &mut &[Token]) -> PResult<Expr> {
182     delimited(one_of(Token::OpenParen), expr, one_of(Token::CloseParen))
183         .map(|e| Expr::Paren(Box::new(e)))
184         .parse_next(i)
185 }
186 
// Checks that `lex` consumes the whole input and yields the expected tokens,
// comparing against the `Debug` rendering of `(remaining_input, tokens)`.
#[test]
fn lex_test() {
    let cases = [
        ("3", r#"("", [Value(3)])"#),
        ("  24     ", r#"("", [Value(24)])"#),
        (
            " 12 *2 /  3",
            r#"("", [Value(12), Oper(Mul), Value(2), Oper(Div), Value(3)])"#,
        ),
        (
            "  2*2 / ( 5 - 1) + 3",
            r#"("", [Value(2), Oper(Mul), Value(2), Oper(Div), OpenParen, Value(5), Oper(Sub), Value(1), CloseParen, Oper(Add), Value(3)])"#,
        ),
    ];

    for (input, expected) in cases {
        let actual = lex.parse_peek(input).map(|e| format!("{e:?}"));
        assert_eq!(actual, Ok(String::from(expected)));
    }
}
209 
// Checks that `factor` turns a single literal token into `Expr::Value`,
// regardless of the whitespace around it in the source text.
#[test]
fn factor_test() {
    let cases = [
        ("3", "Value(3)"),
        (" 12", "Value(12)"),
        ("537 ", "Value(537)"),
        ("  24     ", "Value(24)"),
    ];

    for (source, expected) in cases {
        let tokens = lex.parse(source).unwrap();
        let actual = factor.map(|e| format!("{e:?}")).parse(&tokens);
        assert_eq!(actual, Ok(String::from(expected)));
    }
}
232 
// Checks that `term` builds left-associative `Mul`/`Div` trees.
#[test]
fn term_test() {
    let cases = [
        (" 12 *2 /  3", "Div(Mul(Value(12), Value(2)), Value(3))"),
        (
            " 2* 3  *2 *2 /  3",
            "Div(Mul(Mul(Mul(Value(2), Value(3)), Value(2)), Value(2)), Value(3))",
        ),
        (" 48 /  3/2", "Div(Div(Value(48), Value(3)), Value(2))"),
    ];

    for (source, expected) in cases {
        let tokens = lex.parse(source).unwrap();
        let actual = term.map(|e| format!("{e:?}")).parse(&tokens);
        assert_eq!(actual, Ok(String::from(expected)));
    }
}
257 
// Checks that `expr` builds left-associative `Add`/`Sub` trees and that
// multiplication binds tighter than addition.
#[test]
fn expr_test() {
    let cases = [
        (" 1 +  2 ", "Add(Value(1), Value(2))"),
        (
            " 12 + 6 - 4+  3",
            "Add(Sub(Add(Value(12), Value(6)), Value(4)), Value(3))",
        ),
        (
            " 1 + 2*3 + 4",
            "Add(Add(Value(1), Mul(Value(2), Value(3))), Value(4))",
        ),
    ];

    for (source, expected) in cases {
        let tokens = lex.parse(source).unwrap();
        let actual = expr.map(|e| format!("{e:?}")).parse(&tokens);
        assert_eq!(actual, Ok(String::from(expected)));
    }
}
279 
// Checks that parenthesized sub-expressions are kept as `Paren` nodes and
// combine correctly with the surrounding operators.
#[test]
fn parens_test() {
    let cases = [
        (" (  2 )", "Paren(Value(2))"),
        (
            " 2* (  3 + 4 ) ",
            "Mul(Value(2), Paren(Add(Value(3), Value(4))))",
        ),
        (
            "  2*2 / ( 5 - 1) + 3",
            "Add(Div(Mul(Value(2), Value(2)), Paren(Sub(Value(5), Value(1)))), Value(3))",
        ),
    ];

    for (source, expected) in cases {
        let tokens = lex.parse(source).unwrap();
        let actual = expr.map(|e| format!("{e:?}")).parse(&tokens);
        assert_eq!(actual, Ok(String::from(expected)));
    }
}
301