scripts: csv.py: RExpr decorators to help simplify func/uop/bop parsing

This was trickier than expected since Python's class scope is so
funky (I just ended up using lazy cached __get__ functions that
scan the RExpr class for tagged members), but these decorators help avoid
repeated boilerplate for common expr patterns.

We can even deduplicate binary expr parsing without sacrificing
precedence.
This commit is contained in:
Christopher Haster
2024-11-11 16:14:13 -06:00
parent 4061891a02
commit ac0aa3633e

View File

@@ -272,7 +272,7 @@ class RExpr:
def __init__(self, reason):
self.reason = reason
# expr nodes
# expr node base class
class Expr:
def __init__(self, *args):
for k, v in zip('abcdefghijklmnopqrstuvwxyz', args):
@@ -303,6 +303,9 @@ class RExpr:
def eval(self, fields={}):
return self.a.eval(fields)
# expr nodes
# field expr
class Field(Expr):
def fields(self):
return {self.a}
@@ -322,6 +325,7 @@ class RExpr:
raise RExpr.Error("unknown field? %s" % self.a)
return fields[self.a]
# literal exprs
class StrLit(Expr):
def fields(self):
return set()
@@ -355,6 +359,23 @@ class RExpr:
def eval(self, fields={}):
return self.a
# func expr helper
# decorator factory, func(name) returns a decorator that tags its target
# with a _func attribute holding name and hands the target back unchanged
#
# the tag is what the lazy funcs lookup scans for
def func(name):
    def tag(target):
        # remember which function name this expr is registered under
        target._func = name
        return target
    return tag
# lazy name -> expr lookup
#
# this is a descriptor, accessing it through the owning class scans that
# class's own __dict__ for members carrying a _func tag, the scan is
# deferred to first access and the result is memoized by ft.cache
class Funcs:
    @ft.cache
    def __get__(self, _, cls):
        table = {}
        for member in cls.__dict__.values():
            # only tagged members take part
            if hasattr(member, '_func'):
                table[member._func] = member
        return table
funcs = Funcs()
# type exprs
@func('int')
class Int(Expr):
def type(self, types={}):
return RInt
@@ -362,6 +383,7 @@ class RExpr:
def eval(self, fields={}):
return RInt(self.a.eval(fields))
@func('float')
class Float(Expr):
def type(self, types={}):
return RFloat
@@ -369,6 +391,7 @@ class RExpr:
def eval(self, fields={}):
return RFloat(self.a.eval(fields))
@func('frac')
class Frac(Expr):
def type(self, types={}):
return RFrac
@@ -376,151 +399,206 @@ class RExpr:
def eval(self, fields={}):
return RFrac(self.a.eval(fields), self.b.eval(fields))
# fold exprs
@func('sum')
class Sum(Expr):
def fold(self, types={}):
return RSum, self.a.type(types)
@func('prod')
class Prod(Expr):
def fold(self, types={}):
return RProd, self.a.type(types)
@func('min')
class Min(Expr):
def fold(self, types={}):
return RMin, self.a.type(types)
@func('max')
class Max(Expr):
def fold(self, types={}):
return RMax, self.a.type(types)
@func('avg')
class Avg(Expr):
def fold(self, types={}):
return RAvg, RFloat
@func('stddev')
class Stddev(Expr):
def fold(self, types={}):
return RStddev, RFloat
@func('gmean')
class GMean(Expr):
def fold(self, types={}):
return RGMean, RFloat
# gstddev fold expr
#
# note this must be tagged 'gstddev', tagging it 'stddev' collides with
# Stddev's tag, and since the funcs lookup is a dict keyed by tag, the
# later class would replace Stddev and no 'gstddev' entry would exist
@func('gstddev')
class GStddev(Expr):
    # returns the (RGStddev, RFloat) pair regardless of types
    def fold(self, types={}):
        return RGStddev, RFloat
# functions
@func('ratio')
class Ratio(Expr):
pass
@func('total')
class Total(Expr):
pass
@func('ceil')
class Ceil(Expr):
pass
@func('floor')
class Floor(Expr):
pass
@func('log')
class Log(Expr):
pass
@func('pow')
class Pow(Expr):
pass
@func('sqrt')
class Sqrt(Expr):
pass
funcs = {
# types
'int': Int,
'float': Float,
'frac': Frac,
# unary expr helper
# decorator factory, uop(op) returns a decorator that tags its target
# with a _uop attribute holding the operator string and hands the target
# back unchanged
def uop(op):
    def tag(target):
        # remember which unary operator this expr is registered under
        target._uop = op
        return target
    return tag
# functions
'ratio': Ratio,
'total': Total,
'ceil': Ceil,
'floor': Floor,
'log': Log,
'pow': Pow,
'sqrt': Sqrt,
# mergers
'sum': Sum,
'prod': Prod,
'min': Min,
'max': Max,
'avg': Avg,
'stddev': Stddev,
'gmean': GMean,
'gstddev': GStddev,
}
# lazy op -> expr lookup for unary ops
#
# this is a descriptor, accessing it through the owning class scans that
# class's own __dict__ for members carrying a _uop tag, the scan is
# deferred to first access and the result is memoized by ft.cache
class UOps:
    @ft.cache
    def __get__(self, _, cls):
        table = {}
        for member in cls.__dict__.values():
            # only tagged members take part
            if hasattr(member, '_uop'):
                table[member._uop] = member
        return table
uops = UOps()
# unary ops
@uop('+')
class Pos(Expr):
pass
@uop('-')
class Neg(Expr):
pass
@uop('~')
class Not(Expr):
pass
@uop('!')
class Notnot(Expr):
pass
# binary expr helper
# decorator factory, bop(op, prec) returns a decorator that tags its
# target with the operator string (_bop) and its precedence (_bprec) and
# hands the target back unchanged
def bop(op, prec):
    def tag(target):
        # keep the operator and its precedence together on the expr
        target._bop = op
        target._bprec = prec
        return target
    return tag
# lazy op -> expr lookup for binary ops
#
# this is a descriptor, accessing it through the owning class scans that
# class's own __dict__ for members carrying a _bop tag, the scan is
# deferred to first access and the result is memoized by ft.cache
class BOps:
    @ft.cache
    def __get__(self, _, cls):
        table = {}
        for member in cls.__dict__.values():
            # only tagged members take part
            if hasattr(member, '_bop'):
                table[member._bop] = member
        return table
bops = BOps()
# lazy op -> precedence lookup for binary ops
#
# this is a descriptor, accessing it through the owning class scans that
# class's own __dict__ for members carrying a _bop tag and maps each
# operator to the member's _bprec, the scan is deferred to first access
# and the result is memoized by ft.cache
class BPrecs:
    @ft.cache
    def __get__(self, _, cls):
        table = {}
        for member in cls.__dict__.values():
            # membership is decided by the _bop tag, _bprec is then read
            # unconditionally
            if hasattr(member, '_bop'):
                table[member._bop] = member._bprec
        return table
bprecs = BPrecs()
# binary ops
@bop('*', 10)
class Mul(Expr):
pass
@bop('/', 10)
class Div(Expr):
pass
@bop('%', 10)
class Mod(Expr):
pass
@bop('+', 9)
class Add(Expr):
pass
@bop('-', 9)
class Sub(Expr):
pass
@bop('<<', 8)
class Shl(Expr):
pass
@bop('>>', 8)
class Shr(Expr):
pass
@bop('&', 7)
class And(Expr):
pass
@bop('^', 6)
class Xor(Expr):
pass
@bop('|', 5)
class Or(Expr):
pass
class Lt(Expr):
pass
class Le(Expr):
pass
class Gt(Expr):
pass
class Ge(Expr):
pass
class Ne(Expr):
pass
@bop('==', 4)
class Eq(Expr):
pass
@bop('!=', 4)
class Ne(Expr):
pass
@bop('<', 4)
class Lt(Expr):
pass
@bop('<=', 4)
class Le(Expr):
pass
@bop('>', 4)
class Gt(Expr):
pass
@bop('>=', 4)
class Ge(Expr):
pass
@bop('&&', 3)
class Andand(Expr):
pass
@bop('||', 2)
class Oror(Expr):
pass
# ternary ops
class Ife(Expr):
def type(self, types={}):
return self.b.type(types)
@@ -528,18 +606,20 @@ class RExpr:
def fold(self, types={}):
return self.b.fold(types)
# parse and expr
# parse an expr
def __init__(self, expr):
self.expr = expr.strip()
# parse the expression into a tree
def p_expr(expr, prec=0):
# parens
if expr.startswith('('):
a, tail = p_expr(expr[1:].lstrip())
if not tail.startswith(')'):
raise RExpr.Error("mismatched parens? %s" % tail)
tail = tail[1:].lstrip()
# fields/functions
elif re.match('[_a-zA-Z][_a-zA-Z0-9]*', expr):
m = re.match('[_a-zA-Z][_a-zA-Z0-9]*', expr)
tail = expr[len(m.group()):].lstrip()
@@ -566,113 +646,56 @@ class RExpr:
else:
a = RExpr.Field(m.group())
# strings
elif re.match('(?:"(?:\\.|[^"])*"|\'(?:\\.|[^\'])\')', expr):
m = re.match('(?:"(?:\\.|[^"])*"|\'(?:\\.|[^\'])\')', expr)
a = RExpr.StrLit(m.group()[1:-1])
tail = expr[len(m.group()):].lstrip()
# floats
elif re.match('[+-]?[_0-9]*\.[_0-9eE]', expr):
m = re.match('[+-]?[_0-9]*\.[_0-9eE]', expr)
a = RExpr.FloatLit(RFloat(m.group()))
tail = expr[len(m.group()):].lstrip()
# ints
elif re.match('[+-]?(?:(?:0[bBoOxX])?[_0-9a-fA-F]+|∞|inf)', expr):
m = re.match('[+-]?(?:(?:0[bBoOxX])?[_0-9a-fA-F]+|∞|inf)', expr)
a = RExpr.IntLit(RInt(m.group()))
tail = expr[len(m.group()):].lstrip()
elif expr.startswith('+'):
a, tail = p_expr(expr[1:].lstrip(), 12)
a = RExpr.Pos(a)
elif expr.startswith('-'):
a, tail = p_expr(expr[1:].lstrip(), 12)
a = RExpr.Neg(a)
elif expr.startswith('~'):
a, tail = p_expr(expr[1:].lstrip(), 12)
a = RExpr.Not(a)
elif expr.startswith('!'):
a, tail = p_expr(expr[1:].lstrip(), 4)
a = RExpr.Notnot(a)
# unary ops
elif any(expr.startswith(op) for op in RExpr.uops.keys()):
# sort by len to avoid ambiguities
for op in sorted(RExpr.uops.keys(), reverse=True):
if expr.startswith(op):
a, tail = p_expr(expr[len(op):].lstrip(), mt.inf())
a = RExpr.uops[op](a)
break
else:
assert False
# unknown expr?
else:
raise RExpr.Error("unknown expr? %s" % expr)
# parse tail
while True:
if tail.startswith('*') and prec < 11:
b, tail = p_expr(tail[1:].lstrip(), 11)
a = RExpr.Mul(a, b)
elif tail.startswith('/') and prec < 11:
b, tail = p_expr(tail[1:].lstrip(), 11)
a = RExpr.Div(a, b)
elif tail.startswith('%') and prec < 11:
b, tail = p_expr(tail[1:].lstrip(), 11)
a = RExpr.Mod(a, b)
elif tail.startswith('+') and prec < 10:
b, tail = p_expr(tail[1:].lstrip(), 10)
a = RExpr.Add(a, b)
elif tail.startswith('-') and prec < 10:
b, tail = p_expr(tail[1:].lstrip(), 10)
a = RExpr.Sub(a, b)
elif tail.startswith('<<') and prec < 9:
b, tail = p_expr(tail[2:].lstrip(), 9)
a = RExpr.Shl(a, b)
elif tail.startswith('>>') and prec < 9:
b, tail = p_expr(tail[2:].lstrip(), 9)
a = RExpr.Shr(a, b)
elif tail.startswith('&') and prec < 8:
b, tail = p_expr(tail[1:].lstrip(), 8)
a = RExpr.And(a, b)
elif tail.startswith('^') and prec < 7:
b, tail = p_expr(tail[1:].lstrip(), 7)
a = RExpr.Xor(a, b)
elif tail.startswith('|') and prec < 6:
b, tail = p_expr(tail[1:].lstrip(), 6)
a = RExpr.Or(a, b)
elif tail.startswith('<') and prec < 5:
b, tail = p_expr(tail[1:].lstrip(), 5)
a = RExpr.Lt(a, b)
elif tail.startswith('<=') and prec < 5:
b, tail = p_expr(tail[2:].lstrip(), 5)
a = RExpr.Le(a, b)
elif tail.startswith('>') and prec < 5:
b, tail = p_expr(tail[1:].lstrip(), 5)
a = RExpr.Gt(a, b)
elif tail.startswith('>=') and prec < 5:
b, tail = p_expr(tail[2:].lstrip(), 5)
a = RExpr.Ge(a, b)
elif tail.startswith('!=') and prec < 5:
b, tail = p_expr(tail[2:].lstrip(), 5)
a = RExpr.Ne(a, b)
elif tail.startswith('==') and prec < 5:
b, tail = p_expr(tail[2:].lstrip(), 5)
a = RExpr.Eq(a, b)
elif tail.startswith('&&') and prec < 3:
b, tail = p_expr(tail[2:].lstrip(), 3)
a = RExpr.Andand(a, b)
elif tail.startswith('||') and prec < 2:
b, tail = p_expr(tail[2:].lstrip(), 2)
a = RExpr.Oror(a, b)
# binary ops
if any(tail.startswith(op) and prec < RExpr.bprecs[op]
for op in RExpr.bops.keys()):
# sort by len to avoid ambiguities
for op in sorted(RExpr.bops.keys(), reverse=True):
if tail.startswith(op) and prec < RExpr.bprecs[op]:
b, tail = p_expr(
tail[len(op):].lstrip(),
RExpr.bprecs[op])
a = RExpr.bops[op](a, b)
break
else:
assert False
# ternary ops, this is intentionally right associative
elif tail.startswith('?') and prec <= 1:
b, tail = p_expr(tail[1:].lstrip(), 1)
if not tail.startswith(':'):
@@ -680,6 +703,7 @@ class RExpr:
c, tail = p_expr(tail[1:].lstrip(), 1)
a = RExpr.Ife(a, b, c)
# no tail
else:
return a, tail