147 lines
4.3 KiB
Python
147 lines
4.3 KiB
Python
import re
from collections.abc import Mapping
from datetime import date
from typing import Any

from .exceptions import EvaluationError
from .nodes import (
    BinaryOp,
    Comparison,
    Expression,
    Identifier,
    ListLiteral,
    Literal,
    UnaryOp,
)
|
|
|
|
# Private sentinel meaning "identifier could not be resolved". It is checked
# by identity (`is _MISSING`), so it cannot collide with any real variable
# value, including None.
_MISSING: object = object()

# Cheap shape check for YYYY-MM-DD strings; whether the date actually exists
# is left to `date.fromisoformat` in `_try_date`.
_DATE_RE: re.Pattern[str] = re.compile(r"^\d{4}-\d{2}-\d{2}$")
|
|
|
|
|
|
def _try_date(value: Any) -> date | None:
|
|
if isinstance(value, date):
|
|
return value
|
|
if isinstance(value, str) and _DATE_RE.match(value):
|
|
try:
|
|
return date.fromisoformat(value)
|
|
except ValueError:
|
|
return None
|
|
return None
|
|
|
|
|
|
def _coerce(left: Any, right: Any) -> tuple[Any, Any]:
    """Normalise a pair of operands so they can be compared.

    Operands of exactly the same type are returned untouched. Otherwise both
    sides are passed through ``_try_date``. If both convert, the two dates
    are returned; if either does not, the original operands come back as-is.

    NOTE(review): ``datetime`` subclasses ``date`` and ``_try_date`` returns
    it unchanged. A ``datetime`` paired with a ``YYYY-MM-DD`` string
    therefore becomes ``(datetime, date)``. Python never treats that pair as
    equal and raises ``TypeError`` when ordering it. Confirm whether this
    combination should be supported.
    """
    if type(left) is not type(right):
        left_as_date, right_as_date = _try_date(left), _try_date(right)
        if left_as_date is not None and right_as_date is not None:
            return left_as_date, right_as_date
    return left, right
|
|
|
|
|
|
def _values_equal(left: Any, right: Any) -> bool:
    """Report whether two operands are equal after ``_coerce`` normalisation.

    A ``TypeError`` raised by the equality check itself is treated as
    "not equal" rather than propagated.
    """
    lhs, rhs = _coerce(left, right)
    try:
        result = lhs == rhs
    except TypeError:
        return False
    return result
|
|
|
|
|
|
class Evaluator:
|
|
def __init__(self, variables: dict[str, Any]) -> None:
|
|
self._variables = variables
|
|
|
|
def evaluate(self, node: Expression) -> bool:
|
|
return self._as_bool(self._eval(node))
|
|
|
|
def _as_bool(self, value: Any) -> bool:
|
|
if value is _MISSING:
|
|
return False
|
|
return bool(value)
|
|
|
|
def _eval(self, node: Expression) -> Any:
|
|
match node:
|
|
case Literal(value=value):
|
|
return value
|
|
case Identifier(name=name):
|
|
return self._resolve(name)
|
|
case UnaryOp(operator="NOT", operand=operand):
|
|
value = self._eval(operand)
|
|
if value is _MISSING:
|
|
return False
|
|
return not self._as_bool(value)
|
|
case BinaryOp(operator="AND", left=left, right=right):
|
|
if not self._as_bool(self._eval(left)):
|
|
return False
|
|
return self._as_bool(self._eval(right))
|
|
case BinaryOp(operator="OR", left=left, right=right):
|
|
if self._as_bool(self._eval(left)):
|
|
return True
|
|
return self._as_bool(self._eval(right))
|
|
case Comparison() as comp:
|
|
return self._compare(comp)
|
|
case _:
|
|
msg = f"Unknown node: {type(node).__name__}"
|
|
raise EvaluationError(msg)
|
|
|
|
def _resolve(self, name: str) -> Any:
|
|
parts = name.split(".")
|
|
current: Any = self._variables
|
|
for part in parts:
|
|
if not isinstance(current, dict):
|
|
return _MISSING
|
|
if part not in current:
|
|
return _MISSING
|
|
current = current[part]
|
|
return current
|
|
|
|
def _compare(self, node: Comparison) -> bool:
|
|
left = self._eval(node.left)
|
|
if left is _MISSING:
|
|
return False
|
|
|
|
if node.operator in {"IN", "NOT IN"}:
|
|
return self._membership(left, node)
|
|
|
|
right = self._eval(node.right)
|
|
if right is _MISSING:
|
|
return False
|
|
|
|
left, right = _coerce(left, right)
|
|
return self._apply_operator(node.operator, left, right)
|
|
|
|
def _membership(self, left: Any, node: Comparison) -> bool:
|
|
if not isinstance(node.right, ListLiteral):
|
|
msg = "IN/NOT IN requires a list"
|
|
raise EvaluationError(msg)
|
|
|
|
items = [self._eval(item) for item in node.right.items]
|
|
found = any(_values_equal(left, item) for item in items)
|
|
|
|
if node.operator == "IN":
|
|
return found
|
|
return not found
|
|
|
|
def _apply_operator(self, operator: str, left: Any, right: Any) -> bool:
|
|
try:
|
|
match operator:
|
|
case "==":
|
|
return left == right
|
|
case "!=":
|
|
return left != right
|
|
case ">":
|
|
return left > right
|
|
case ">=":
|
|
return left >= right
|
|
case "<":
|
|
return left < right
|
|
case "<=":
|
|
return left <= right
|
|
case _:
|
|
msg = f"Unknown operator: {operator}"
|
|
raise EvaluationError(msg)
|
|
except TypeError:
|
|
return False
|