scripts: csv.py: RExpr decorators to help simplify func/uop/bop parsing

This was trickier than expected since Python's class scope is so
funky (I just ended up using lazy, cached __get__ functions that
scan the RExpr class for tagged members), but these decorators help avoid
repeated boilerplate for common expr patterns.

We can even deduplicate binary expr parsing without sacrificing
precedence.
This commit is contained in:
Christopher Haster
2024-11-11 16:14:13 -06:00
parent 4061891a02
commit ac0aa3633e

View File

@@ -272,7 +272,7 @@ class RExpr:
def __init__(self, reason): def __init__(self, reason):
self.reason = reason self.reason = reason
# expr nodes # expr node base class
class Expr: class Expr:
def __init__(self, *args): def __init__(self, *args):
for k, v in zip('abcdefghijklmnopqrstuvwxyz', args): for k, v in zip('abcdefghijklmnopqrstuvwxyz', args):
@@ -303,6 +303,9 @@ class RExpr:
def eval(self, fields={}): def eval(self, fields={}):
return self.a.eval(fields) return self.a.eval(fields)
# expr nodes
# field expr
class Field(Expr): class Field(Expr):
def fields(self): def fields(self):
return {self.a} return {self.a}
@@ -322,6 +325,7 @@ class RExpr:
raise RExpr.Error("unknown field? %s" % self.a) raise RExpr.Error("unknown field? %s" % self.a)
return fields[self.a] return fields[self.a]
# literal exprs
class StrLit(Expr): class StrLit(Expr):
def fields(self): def fields(self):
return set() return set()
@@ -355,6 +359,23 @@ class RExpr:
def eval(self, fields={}): def eval(self, fields={}):
return self.a return self.a
# func expr helper
# decorator factory: @func(name) tags the decorated object with a
# _func attribute holding name, and otherwise returns it unchanged
def func(name):
# the inner decorator shadows the outer name, only the outer func is
# visible to callers
def func(f):
f._func = name
return f
return func
# descriptor backing the funcs lookup table
#
# on attribute access this scans the owning class's own __dict__ (base
# classes are not searched) for members carrying a _func tag and returns
# a dict mapping each tag to its member, if two members share a tag the
# one defined later in the class body wins
class Funcs:
# the scan is memoized by ft.cache, keyed on the __get__ arguments
# NOTE(review): ft is presumably functools, the import is not visible
# in this view
@ft.cache
def __get__(self, _, cls):
return {x._func: x
for x in cls.__dict__.values()
if hasattr(x, '_func')}
funcs = Funcs()
# type exprs
@func('int')
class Int(Expr): class Int(Expr):
def type(self, types={}): def type(self, types={}):
return RInt return RInt
@@ -362,6 +383,7 @@ class RExpr:
def eval(self, fields={}): def eval(self, fields={}):
return RInt(self.a.eval(fields)) return RInt(self.a.eval(fields))
@func('float')
class Float(Expr): class Float(Expr):
def type(self, types={}): def type(self, types={}):
return RFloat return RFloat
@@ -369,6 +391,7 @@ class RExpr:
def eval(self, fields={}): def eval(self, fields={}):
return RFloat(self.a.eval(fields)) return RFloat(self.a.eval(fields))
@func('frac')
class Frac(Expr): class Frac(Expr):
def type(self, types={}): def type(self, types={}):
return RFrac return RFrac
@@ -376,151 +399,206 @@ class RExpr:
def eval(self, fields={}): def eval(self, fields={}):
return RFrac(self.a.eval(fields), self.b.eval(fields)) return RFrac(self.a.eval(fields), self.b.eval(fields))
# fold exprs
@func('sum')
class Sum(Expr): class Sum(Expr):
def fold(self, types={}): def fold(self, types={}):
return RSum, self.a.type(types) return RSum, self.a.type(types)
@func('prod')
class Prod(Expr): class Prod(Expr):
def fold(self, types={}): def fold(self, types={}):
return RProd, self.a.type(types) return RProd, self.a.type(types)
@func('min')
class Min(Expr): class Min(Expr):
def fold(self, types={}): def fold(self, types={}):
return RMin, self.a.type(types) return RMin, self.a.type(types)
@func('max')
class Max(Expr): class Max(Expr):
def fold(self, types={}): def fold(self, types={}):
return RMax, self.a.type(types) return RMax, self.a.type(types)
@func('avg')
class Avg(Expr): class Avg(Expr):
def fold(self, types={}): def fold(self, types={}):
return RAvg, RFloat return RAvg, RFloat
@func('stddev')
class Stddev(Expr): class Stddev(Expr):
def fold(self, types={}): def fold(self, types={}):
return RStddev, RFloat return RStddev, RFloat
@func('gmean')
class GMean(Expr): class GMean(Expr):
def fold(self, types={}): def fold(self, types={}):
return RGMean, RFloat return RGMean, RFloat
# geometric stddev fold
#
# note this must be tagged 'gstddev', tagging it 'stddev' collides with
# Stddev in the funcs lookup (the later definition wins), which would
# make stddev(...) resolve to GStddev and leave gstddev(...) unknown
@func('gstddev')
class GStddev(Expr):
    def fold(self, types={}):
        # returns the (folder, result type) pair for this expr
        return RGStddev, RFloat
# functions
@func('ratio')
class Ratio(Expr): class Ratio(Expr):
pass pass
@func('total')
class Total(Expr): class Total(Expr):
pass pass
@func('ceil')
class Ceil(Expr): class Ceil(Expr):
pass pass
@func('floor')
class Floor(Expr): class Floor(Expr):
pass pass
@func('log')
class Log(Expr): class Log(Expr):
pass pass
@func('pow')
class Pow(Expr): class Pow(Expr):
pass pass
@func('sqrt')
class Sqrt(Expr): class Sqrt(Expr):
pass pass
funcs = { # unary expr helper
# types def uop(op):
'int': Int, def uop(f):
'float': Float, f._uop = op
'frac': Frac, return f
return uop
# functions class UOps:
'ratio': Ratio, @ft.cache
'total': Total, def __get__(self, _, cls):
'ceil': Ceil, return {x._uop: x
'floor': Floor, for x in cls.__dict__.values()
'log': Log, if hasattr(x, '_uop')}
'pow': Pow, uops = UOps()
'sqrt': Sqrt,
# mergers
'sum': Sum,
'prod': Prod,
'min': Min,
'max': Max,
'avg': Avg,
'stddev': Stddev,
'gmean': GMean,
'gstddev': GStddev,
}
# unary ops
@uop('+')
class Pos(Expr): class Pos(Expr):
pass pass
@uop('-')
class Neg(Expr): class Neg(Expr):
pass pass
@uop('~')
class Not(Expr): class Not(Expr):
pass pass
@uop('!')
class Notnot(Expr): class Notnot(Expr):
pass pass
# binary expr helper
# decorator factory: @bop(op, prec) tags the decorated object with
# _bop (the operator string) and _bprec (its precedence), and otherwise
# returns it unchanged
def bop(op, prec):
# the inner decorator shadows the outer name, only the outer bop is
# visible to callers
def bop(f):
f._bop = op
f._bprec = prec
return f
return bop
# descriptor backing the bops lookup table
#
# on attribute access this scans the owning class's own __dict__ for
# members carrying a _bop tag and returns a dict mapping each operator
# string to its member
class BOps:
# the scan is memoized by ft.cache, keyed on the __get__ arguments
# NOTE(review): ft is presumably functools, the import is not visible
# in this view
@ft.cache
def __get__(self, _, cls):
return {x._bop: x
for x in cls.__dict__.values()
if hasattr(x, '_bop')}
bops = BOps()
# descriptor backing the bprecs lookup table
#
# same scan as BOps, but maps each operator string to the _bprec
# precedence stored alongside it, so the parser can compare precedences
# without touching the expr classes
class BPrecs:
# the scan is memoized by ft.cache, keyed on the __get__ arguments
# NOTE(review): ft is presumably functools, the import is not visible
# in this view
@ft.cache
def __get__(self, _, cls):
return {x._bop: x._bprec
for x in cls.__dict__.values()
if hasattr(x, '_bop')}
bprecs = BPrecs()
# binary ops
@bop('*', 10)
class Mul(Expr): class Mul(Expr):
pass pass
@bop('/', 10)
class Div(Expr): class Div(Expr):
pass pass
@bop('%', 10)
class Mod(Expr): class Mod(Expr):
pass pass
@bop('+', 9)
class Add(Expr): class Add(Expr):
pass pass
@bop('-', 9)
class Sub(Expr): class Sub(Expr):
pass pass
@bop('<<', 8)
class Shl(Expr): class Shl(Expr):
pass pass
@bop('>>', 8)
class Shr(Expr): class Shr(Expr):
pass pass
@bop('&', 7)
class And(Expr): class And(Expr):
pass pass
@bop('^', 6)
class Xor(Expr): class Xor(Expr):
pass pass
@bop('|', 5)
class Or(Expr): class Or(Expr):
pass pass
class Lt(Expr): @bop('==', 4)
pass
class Le(Expr):
pass
class Gt(Expr):
pass
class Ge(Expr):
pass
class Ne(Expr):
pass
class Eq(Expr): class Eq(Expr):
pass pass
@bop('!=', 4)
class Ne(Expr):
pass
@bop('<', 4)
class Lt(Expr):
pass
@bop('<=', 4)
class Le(Expr):
pass
@bop('>', 4)
class Gt(Expr):
pass
@bop('>=', 4)
class Ge(Expr):
pass
@bop('&&', 3)
class Andand(Expr): class Andand(Expr):
pass pass
@bop('||', 2)
class Oror(Expr): class Oror(Expr):
pass pass
# ternary ops
class Ife(Expr): class Ife(Expr):
def type(self, types={}): def type(self, types={}):
return self.b.type(types) return self.b.type(types)
@@ -528,18 +606,20 @@ class RExpr:
def fold(self, types={}): def fold(self, types={}):
return self.b.fold(types) return self.b.fold(types)
# parse and expr # parse an expr
def __init__(self, expr): def __init__(self, expr):
self.expr = expr.strip() self.expr = expr.strip()
# parse the expression into a tree # parse the expression into a tree
def p_expr(expr, prec=0): def p_expr(expr, prec=0):
# parens
if expr.startswith('('): if expr.startswith('('):
a, tail = p_expr(expr[1:].lstrip()) a, tail = p_expr(expr[1:].lstrip())
if not tail.startswith(')'): if not tail.startswith(')'):
raise RExpr.Error("mismatched parens? %s" % tail) raise RExpr.Error("mismatched parens? %s" % tail)
tail = tail[1:].lstrip() tail = tail[1:].lstrip()
# fields/functions
elif re.match('[_a-zA-Z][_a-zA-Z0-9]*', expr): elif re.match('[_a-zA-Z][_a-zA-Z0-9]*', expr):
m = re.match('[_a-zA-Z][_a-zA-Z0-9]*', expr) m = re.match('[_a-zA-Z][_a-zA-Z0-9]*', expr)
tail = expr[len(m.group()):].lstrip() tail = expr[len(m.group()):].lstrip()
@@ -566,113 +646,56 @@ class RExpr:
else: else:
a = RExpr.Field(m.group()) a = RExpr.Field(m.group())
# strings
elif re.match('(?:"(?:\\.|[^"])*"|\'(?:\\.|[^\'])\')', expr): elif re.match('(?:"(?:\\.|[^"])*"|\'(?:\\.|[^\'])\')', expr):
m = re.match('(?:"(?:\\.|[^"])*"|\'(?:\\.|[^\'])\')', expr) m = re.match('(?:"(?:\\.|[^"])*"|\'(?:\\.|[^\'])\')', expr)
a = RExpr.StrLit(m.group()[1:-1]) a = RExpr.StrLit(m.group()[1:-1])
tail = expr[len(m.group()):].lstrip() tail = expr[len(m.group()):].lstrip()
# floats
elif re.match('[+-]?[_0-9]*\.[_0-9eE]', expr): elif re.match('[+-]?[_0-9]*\.[_0-9eE]', expr):
m = re.match('[+-]?[_0-9]*\.[_0-9eE]', expr) m = re.match('[+-]?[_0-9]*\.[_0-9eE]', expr)
a = RExpr.FloatLit(RFloat(m.group())) a = RExpr.FloatLit(RFloat(m.group()))
tail = expr[len(m.group()):].lstrip() tail = expr[len(m.group()):].lstrip()
# ints
elif re.match('[+-]?(?:(?:0[bBoOxX])?[_0-9a-fA-F]+|∞|inf)', expr): elif re.match('[+-]?(?:(?:0[bBoOxX])?[_0-9a-fA-F]+|∞|inf)', expr):
m = re.match('[+-]?(?:(?:0[bBoOxX])?[_0-9a-fA-F]+|∞|inf)', expr) m = re.match('[+-]?(?:(?:0[bBoOxX])?[_0-9a-fA-F]+|∞|inf)', expr)
a = RExpr.IntLit(RInt(m.group())) a = RExpr.IntLit(RInt(m.group()))
tail = expr[len(m.group()):].lstrip() tail = expr[len(m.group()):].lstrip()
elif expr.startswith('+'): # unary ops
a, tail = p_expr(expr[1:].lstrip(), 12) elif any(expr.startswith(op) for op in RExpr.uops.keys()):
a = RExpr.Pos(a) # reverse-sort so longer ops are tried before their prefixes
for op in sorted(RExpr.uops.keys(), reverse=True):
elif expr.startswith('-'): if expr.startswith(op):
a, tail = p_expr(expr[1:].lstrip(), 12) a, tail = p_expr(expr[len(op):].lstrip(), mt.inf())
a = RExpr.Neg(a) a = RExpr.uops[op](a)
break
elif expr.startswith('~'): else:
a, tail = p_expr(expr[1:].lstrip(), 12) assert False
a = RExpr.Not(a)
elif expr.startswith('!'):
a, tail = p_expr(expr[1:].lstrip(), 4)
a = RExpr.Notnot(a)
# unknown expr?
else: else:
raise RExpr.Error("unknown expr? %s" % expr) raise RExpr.Error("unknown expr? %s" % expr)
# parse tail
while True: while True:
if tail.startswith('*') and prec < 11: # binary ops
b, tail = p_expr(tail[1:].lstrip(), 11) if any(tail.startswith(op) and prec < RExpr.bprecs[op]
a = RExpr.Mul(a, b) for op in RExpr.bops.keys()):
# reverse-sort so longer ops are tried before their prefixes
elif tail.startswith('/') and prec < 11: for op in sorted(RExpr.bops.keys(), reverse=True):
b, tail = p_expr(tail[1:].lstrip(), 11) if tail.startswith(op) and prec < RExpr.bprecs[op]:
a = RExpr.Div(a, b) b, tail = p_expr(
tail[len(op):].lstrip(),
elif tail.startswith('%') and prec < 11: RExpr.bprecs[op])
b, tail = p_expr(tail[1:].lstrip(), 11) a = RExpr.bops[op](a, b)
a = RExpr.Mod(a, b) break
else:
elif tail.startswith('+') and prec < 10: assert False
b, tail = p_expr(tail[1:].lstrip(), 10)
a = RExpr.Add(a, b)
elif tail.startswith('-') and prec < 10:
b, tail = p_expr(tail[1:].lstrip(), 10)
a = RExpr.Sub(a, b)
elif tail.startswith('<<') and prec < 9:
b, tail = p_expr(tail[2:].lstrip(), 9)
a = RExpr.Shl(a, b)
elif tail.startswith('>>') and prec < 9:
b, tail = p_expr(tail[2:].lstrip(), 9)
a = RExpr.Shr(a, b)
elif tail.startswith('&') and prec < 8:
b, tail = p_expr(tail[1:].lstrip(), 8)
a = RExpr.And(a, b)
elif tail.startswith('^') and prec < 7:
b, tail = p_expr(tail[1:].lstrip(), 7)
a = RExpr.Xor(a, b)
elif tail.startswith('|') and prec < 6:
b, tail = p_expr(tail[1:].lstrip(), 6)
a = RExpr.Or(a, b)
elif tail.startswith('<') and prec < 5:
b, tail = p_expr(tail[1:].lstrip(), 5)
a = RExpr.Lt(a, b)
elif tail.startswith('<=') and prec < 5:
b, tail = p_expr(tail[2:].lstrip(), 5)
a = RExpr.Le(a, b)
elif tail.startswith('>') and prec < 5:
b, tail = p_expr(tail[1:].lstrip(), 5)
a = RExpr.Gt(a, b)
elif tail.startswith('>=') and prec < 5:
b, tail = p_expr(tail[2:].lstrip(), 5)
a = RExpr.Ge(a, b)
elif tail.startswith('!=') and prec < 5:
b, tail = p_expr(tail[2:].lstrip(), 5)
a = RExpr.Ne(a, b)
elif tail.startswith('==') and prec < 5:
b, tail = p_expr(tail[2:].lstrip(), 5)
a = RExpr.Eq(a, b)
elif tail.startswith('&&') and prec < 3:
b, tail = p_expr(tail[2:].lstrip(), 3)
a = RExpr.Andand(a, b)
elif tail.startswith('||') and prec < 2:
b, tail = p_expr(tail[2:].lstrip(), 2)
a = RExpr.Oror(a, b)
# ternary ops, this is intentionally right associative
elif tail.startswith('?') and prec <= 1: elif tail.startswith('?') and prec <= 1:
b, tail = p_expr(tail[1:].lstrip(), 1) b, tail = p_expr(tail[1:].lstrip(), 1)
if not tail.startswith(':'): if not tail.startswith(':'):
@@ -680,6 +703,7 @@ class RExpr:
c, tail = p_expr(tail[1:].lstrip(), 1) c, tail = p_expr(tail[1:].lstrip(), 1)
a = RExpr.Ife(a, b, c) a = RExpr.Ife(a, b, c)
# no tail
else: else:
return a, tail return a, tail