littlefs/scripts/prettyasserts.py
Christopher Haster a3ac512cc1 scripts: Adopted Parser class in prettyasserts.py
This ended up being a pretty in-depth rework of prettyasserts.py to
adopt the shared Parser class. But now prettyasserts.py should be both
more robust and faster.

The tricky parts:

- The Parser class eagerly munches whitespace by default. This is
  usually a good thing, but for prettyasserts.py we need to keep track
  of the whitespace somehow in order to write it to the output file.

  The solution here is a little bit hacky. Instead of complicating the
  Parser class, we implicitly add a regex group for whitespace when
  compiling our lexer.

  Unfortunately this does make last-minute patching of the lexer a bit
  messy (for things like -p/--prefix, etc.), thanks to Python's
  re.Pattern class not being extendable. To work around this, the Lexer
  class keeps track of the original patterns to allow recompilation.
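
  As a toy illustration of the trick (the real WS pattern below also
  munches comments):

    import re

    # every rule gets an implicit trailing-whitespace group, so a
    # single match yields both the token and the whitespace after it
    ASSERT = re.compile(r'(?P<token>\bassert\b)(?P<ws>\s*)')

    m = ASSERT.match('assert  (a == b)')
    m.group('token')  # -> 'assert'
    m.group('ws')     # -> '  ', preserved for the output file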

- Since we no longer tokenize in a separate pass, we can't use the
  None token to match any unmatched tokens.

  Fortunately this can be worked around with sufficiently ugly regex.
  See the 'STUFF' rule.

  It's a good thing Python has negative lookaheads.

  On the flip side, this means we no longer need to explicitly specify
  all possible tokens when multiple tokens overlap.
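
  A toy version of the idea (the real 'STUFF' rule below handles more
  operators):

    import re

    # match runs of uninteresting chars, plus operator chars that
    # overlap with real tokens, using negative lookaheads to avoid
    # eating '=>', '==', '->', etc.
    STUFF = re.compile(
        r'[^;{}?:,()"\'=!<>&|/#-]+'  # anything boring
        r'|->'                       # a real token starting with '-'
        r'|=(?![=>])'                # '=' but not '==' or '=>'
        r'|-(?!>)')                  # '-' but not '->'

    [m.group() for m in STUFF.finditer('a = b->c')]
    # -> ['a ', '=', ' b', '->', 'c']
    [m.group() for m in STUFF.finditer('a => b')]
    # -> ['a ', ' b'], leaving '=>' for the ARROW rule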

- Unlike stack.py/csv.py, prettyasserts.py needs multi-token lookahead.

  Fortunately this has a pretty straightforward solution with the
  addition of an optional stack to the Parser class.

  We can even have a bit of fun with Python's with statements (though I
  do wish with statements could have else clauses, so we wouldn't need
  double nesting to catch parser exceptions).
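
  A minimal sketch of the stack-based lookahead (the real Parser class
  below also saves the last match object):

    class MiniParser:
        def __init__(self, data):
            self.data, self.i = data, 0
            self.stack = []

        def lookahead(self):
            # save our position before speculatively parsing
            self.stack.append(self.i)
            return self

        def __enter__(self):
            return self

        def __exit__(self, et, ev, tb):
            i = self.stack.pop()
            # rewind if a rule failed, otherwise keep the new state
            if et is not None:
                self.i = i

  which gives the parse rules their shape (note the double nesting):

    try:
        with p.lookahead():
            p_assert(p, f)
    except Parser.Error:
        f.write(p.chomp())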

---

In addition to adopting the new Parser class, I also made sure to
eliminate intermediate string allocation through heavy use of Python's
io.StringIO class.

This, plus Parser's cheap shallow chomp/slice operations, gives
prettyasserts.py a much needed speed boost.
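
The pattern used throughout the parse rules: sub-expressions get
written into a growable buffer, with only a single str allocation when
the value is actually needed:

  lh = io.StringIO()
  p_expr(p, lh)       # parse rules write tokens directly into lh
  lh = lh.getvalue()  # one final allocation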

(Honestly, the original prettyasserts.py was pretty naive, with the
assumption that it wouldn't be the bottleneck during compilation. This
turned out to be wrong.)

These changes cut total compile time in ~half:

                                          real      user      sys
  before (time make test-runner -j): 0m56.202s 2m31.853s 0m2.827s
  after  (time make test-runner -j): 0m26.836s 1m51.213s 0m2.338s

Keep in mind this includes both prettyasserts.py and gcc -Os (and other
Makefile stuff).
2024-12-17 15:34:44 -06:00


#!/usr/bin/env python3
#
# Preprocessor that makes asserts easier to debug.
#
# Example:
# ./scripts/prettyasserts.py -p LFS_ASSERT lfs.c -o lfs.a.c
#
# Copyright (c) 2022, The littlefs authors.
# Copyright (c) 2020, Arm Limited. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# prevent local imports
__import__('sys').path.pop(0)
import io
import os
import re
import sys

# default prettyassert limit
LIMIT = 16
# comparison ops
CMP = {
    '==': 'eq',
    '!=': 'ne',
    '<=': 'le',
    '>=': 'ge',
    '<': 'lt',
    '>': 'gt',
}
# helper class for lexical regexes
class Lexer:
    def __init__(self):
        self.patterns = {}

    def lex(self, k, *patterns):
        # compile with whitespace
        l = re.compile(
                '(?P<token>%s)(?P<ws>(?:%s)*)' % (
                    '|'.join(patterns) if patterns
                        # force a failure if we have no patterns
                        else '(?!)',
                    '|'.join(self.patterns.get('WS', ''))),
                re.DOTALL)
        # add to class members
        setattr(self, k, l)
        # keep track of patterns
        self.patterns[k] = patterns

    def extend(self, k, *patterns):
        self.lex(k, *self.patterns.get(k, []), *patterns)

L = Lexer()
L.lex('WS', r'(?:\s|\n|#.*?(?<!\\)\n|//.*?(?<!\\)\n|/\*.*?\*/)+')
L.lex('ASSERT', r'\bassert\b', r'\b__builtin_assert\b')
L.lex('UNREACHABLE', r'\bunreachable\b', r'\b__builtin_unreachable\b')
L.lex('MEMCMP', r'\bmemcmp\b', r'\b__builtin_memcmp\b')
L.lex('STRCMP', r'\bstrcmp\b', r'\b__builtin_strcmp\b')
L.lex('ARROW', '=>')
L.lex('STR', r'"(?:\\.|[^"])*"', r"'(?:\\.|[^'])'")
L.lex('LPAREN', r'\(')
L.lex('RPAREN', r'\)')
L.lex('ZERO', r'\b0\b')
L.lex('CMP', *CMP.keys())
L.lex('LOGIC', r'\&\&', r'\|\|')
L.lex('TERN', r'\?', r':')
L.lex('COMMA', r',')
L.lex('TERM', r';', r'\{', r'\}')
L.lex('STUFF', r'[^;{}?:,()"\'=!<>\-&|/#]+',
        # these need special handling because we're only
        # using regex
        r'->', r'>>', r'<<', r'-(?!>)',
        r'=(?![=>])', r'!(?!=)', r'&(?!&)', r'\|(?!\|)',
        r'/(?![/*])')

def openio(path, mode='r', buffering=-1):
    # allow '-' for stdin/stdout
    if path == '-':
        if 'r' in mode:
            return os.fdopen(os.dup(sys.stdin.fileno()), mode, buffering)
        else:
            return os.fdopen(os.dup(sys.stdout.fileno()), mode, buffering)
    else:
        return open(path, mode, buffering)

def mkheader(f, limit=LIMIT):
    f.writeln("// Generated by %s:" % sys.argv[0])
    f.writeln("//")
    f.writeln("// %s" % ' '.join(sys.argv))
    f.writeln("//")
    f.writeln()
    f.writeln("#include <stdbool.h>")
    f.writeln("#include <stdint.h>")
    f.writeln("#include <inttypes.h>")
    f.writeln("#include <stdio.h>")
    f.writeln("#include <string.h>")
    # give source a chance to define feature macros
    f.writeln("#undef _FEATURES_H")
    f.writeln()

    # write print macros
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_bool(")
    f.writeln("        const void *v, size_t size) {")
    f.writeln("    (void)size;")
    f.writeln("    printf(\"%s\", *(const bool*)v ? \"true\" : \"false\");")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_int(")
    f.writeln("        const void *v, size_t size) {")
    f.writeln("    (void)size;")
    f.writeln("    printf(\"%\"PRIiMAX, *(const intmax_t*)v);")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_mem(")
    f.writeln("        const void *v, size_t size) {")
    f.writeln("    const uint8_t *v_ = v;")
    f.writeln("    printf(\"\\\"\");")
    f.writeln("    for (size_t i = 0; i < size && i < %d; i++) {" % limit)
    f.writeln("        if (v_[i] >= ' ' && v_[i] <= '~') {")
    f.writeln("            printf(\"%c\", v_[i]);")
    f.writeln("        } else {")
    f.writeln("            printf(\"\\\\x%02x\", v_[i]);")
    f.writeln("        }")
    f.writeln("    }")
    f.writeln("    if (size > %d) {" % limit)
    f.writeln("        printf(\"...\");")
    f.writeln("    }")
    f.writeln("    printf(\"\\\"\");")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_str(")
    f.writeln("        const void *v, size_t size) {")
    f.writeln("    __pretty_assert_mem(v, size);")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_print(")
    f.writeln("        const char *file, int line,")
    f.writeln("        void (*type_print_cb)(const void*, size_t),")
    f.writeln("        const char *cmp,")
    f.writeln("        const void *lh, size_t lsize,")
    f.writeln("        const void *rh, size_t rsize) {")
    f.writeln("    printf(\"%s:%d:assert: assert failed with \", file, line);")
    f.writeln("    type_print_cb(lh, lsize);")
    f.writeln("    printf(\", expected %s \", cmp);")
    f.writeln("    type_print_cb(rh, rsize);")
    f.writeln("    printf(\"\\n\");")
    f.writeln("    fflush(NULL);")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_print_unreachable(")
    f.writeln("        const char *file, int line) {")
    f.writeln("    printf(\"%s:%d:unreachable: \"")
    f.writeln("            \"unreachable statement reached\\n\", file, line);")
    f.writeln("    fflush(NULL);")
    f.writeln("}")
    f.writeln()

    # write assert macros
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_BOOL_%s(lh, rh) do { \\" % (
                cmp.upper()))
        f.writeln("    bool _lh = !!(lh); \\")
        f.writeln("    bool _rh = !!(rh); \\")
        f.writeln("    if (!(_lh %s _rh)) { \\" % op)
        f.writeln("        __pretty_assert_print( \\")
        f.writeln("                __FILE__, __LINE__, \\")
        f.writeln("                __pretty_assert_bool, \"%s\", \\" % cmp)
        f.writeln("                &_lh, 0, \\")
        f.writeln("                &_rh, 0); \\")
        f.writeln("        __builtin_trap(); \\")
        f.writeln("    } \\")
        f.writeln("} while (0)")
        f.writeln()
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_INT_%s(lh, rh) do { \\" % (
                cmp.upper()))
        f.writeln("    __typeof__(lh) _lh = lh; \\")
        f.writeln("    __typeof__(lh) _rh = rh; \\")
        f.writeln("    if (!(_lh %s _rh)) { \\" % op)
        f.writeln("        __pretty_assert_print( \\")
        f.writeln("                __FILE__, __LINE__, \\")
        f.writeln("                __pretty_assert_int, \"%s\", \\" % cmp)
        f.writeln("                &(intmax_t){(intmax_t)_lh}, 0, \\")
        f.writeln("                &(intmax_t){(intmax_t)_rh}, 0); \\")
        f.writeln("        __builtin_trap(); \\")
        f.writeln("    } \\")
        f.writeln("} while (0)")
        f.writeln()
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_MEM_%s(lh, rh, size) do { \\" % (
                cmp.upper()))
        f.writeln("    const void *_lh = lh; \\")
        f.writeln("    const void *_rh = rh; \\")
        f.writeln("    if (!(memcmp(_lh, _rh, size) %s 0)) { \\" % op)
        f.writeln("        __pretty_assert_print( \\")
        f.writeln("                __FILE__, __LINE__, \\")
        f.writeln("                __pretty_assert_mem, \"%s\", \\" % cmp)
        f.writeln("                _lh, size, \\")
        f.writeln("                _rh, size); \\")
        f.writeln("        __builtin_trap(); \\")
        f.writeln("    } \\")
        f.writeln("} while (0)")
        f.writeln()
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_STR_%s(lh, rh) do { \\" % (
                cmp.upper()))
        f.writeln("    const char *_lh = lh; \\")
        f.writeln("    const char *_rh = rh; \\")
        f.writeln("    if (!(strcmp(_lh, _rh) %s 0)) { \\" % op)
        f.writeln("        __pretty_assert_print( \\")
        f.writeln("                __FILE__, __LINE__, \\")
        f.writeln("                __pretty_assert_str, \"%s\", \\" % cmp)
        f.writeln("                _lh, strlen(_lh), \\")
        f.writeln("                _rh, strlen(_rh)); \\")
        f.writeln("        __builtin_trap(); \\")
        f.writeln("    } \\")
        f.writeln("} while (0)")
        f.writeln()
    f.writeln("#define __PRETTY_ASSERT_UNREACHABLE() do { \\")
    f.writeln("    __pretty_assert_print_unreachable( \\")
    f.writeln("            __FILE__, __LINE__); \\")
    f.writeln("    __builtin_trap(); \\")
    f.writeln("} while (0)")
    f.writeln()
    f.writeln()

def mkassert(f, type, cmp, lh, rh, size=None):
    if size is not None:
        f.write("__PRETTY_ASSERT_%s_%s(%s, %s, %s)" % (
                type.upper(), cmp.upper(), lh, rh, size))
    else:
        f.write("__PRETTY_ASSERT_%s_%s(%s, %s)" % (
                type.upper(), cmp.upper(), lh, rh))

def mkunreachable(f):
    f.write("__PRETTY_ASSERT_UNREACHABLE()")

# a simple general-purpose parser class
#
# basically just because memoryview doesn't support strs
class Parser:
    def __init__(self, data, ws=r'\s*', ws_flags=0):
        self.data = data
        self.i = 0
        self.m = None
        # also consume whitespace
        self.ws = re.compile(ws, ws_flags)
        self.i = self.ws.match(self.data, self.i).end()

    def __repr__(self):
        if len(self.data) - self.i <= 32:
            return repr(self.data[self.i:])
        else:
            return "%s..." % repr(self.data[self.i:self.i+32])[:32]

    def __str__(self):
        return self.data[self.i:]

    def __len__(self):
        return len(self.data) - self.i

    def __bool__(self):
        return self.i != len(self.data)

    def match(self, pattern, flags=0):
        # compile so we can use the pos arg, this is still cached
        self.m = re.compile(pattern, flags).match(self.data, self.i)
        return self.m

    def group(self, *groups):
        return self.m.group(*groups)

    def chomp(self, *groups):
        g = self.group(*groups)
        self.i = self.m.end()
        # also consume whitespace
        self.i = self.ws.match(self.data, self.i).end()
        return g

    class Error(Exception):
        pass

    def chompmatch(self, pattern, flags=0, *groups):
        if not self.match(pattern, flags):
            raise Parser.Error("expected %r, found %r" % (pattern, self))
        return self.chomp(*groups)

    def unexpected(self):
        raise Parser.Error("unexpected %r" % self)

    def lookahead(self):
        # push state on the stack
        if not hasattr(self, 'stack'):
            self.stack = []
        self.stack.append((self.i, self.m))
        return self

    def consume(self):
        # pop and use new state
        self.stack.pop()

    def discard(self):
        # pop and discard new state
        self.i, self.m = self.stack.pop()

    def __enter__(self):
        return self

    def __exit__(self, et, ev, tb):
        # keep new state if no exception occurred
        if et is None:
            self.consume()
        else:
            self.discard()

# parse rules
def p_assert(p, f):
    # assert(memcmp(a,b,size) cmp 0)?
    try:
        with p.lookahead():
            p.chompmatch(L.ASSERT)
            p.chompmatch(L.LPAREN)
            p.chompmatch(L.MEMCMP)
            p.chompmatch(L.LPAREN)
            lh = io.StringIO()
            p_expr(p, lh)
            lh = lh.getvalue()
            p.chompmatch(L.COMMA)
            rh = io.StringIO()
            p_expr(p, rh)
            rh = rh.getvalue()
            p.chompmatch(L.COMMA)
            size = io.StringIO()
            p_expr(p, size)
            size = size.getvalue()
            p.chompmatch(L.RPAREN)
            cmp = p.chompmatch(L.CMP, 0, 'token')
            p.chompmatch(L.ZERO)
            ws = p.chompmatch(L.RPAREN, 0, 'ws')
            mkassert(f, 'mem', CMP[cmp], lh, rh, size)
            f.write(ws)
            return
    except Parser.Error:
        pass
    # assert(strcmp(a,b) cmp 0)?
    try:
        with p.lookahead():
            p.chompmatch(L.ASSERT)
            p.chompmatch(L.LPAREN)
            p.chompmatch(L.STRCMP)
            p.chompmatch(L.LPAREN)
            lh = io.StringIO()
            p_expr(p, lh)
            lh = lh.getvalue()
            p.chompmatch(L.COMMA)
            rh = io.StringIO()
            p_expr(p, rh)
            rh = rh.getvalue()
            p.chompmatch(L.RPAREN)
            cmp = p.chompmatch(L.CMP, 0, 'token')
            p.chompmatch(L.ZERO)
            ws = p.chompmatch(L.RPAREN, 0, 'ws')
            mkassert(f, 'str', CMP[cmp], lh, rh)
            f.write(ws)
            return
    except Parser.Error:
        pass
    # assert(a cmp b)?
    try:
        with p.lookahead():
            p.chompmatch(L.ASSERT)
            p.chompmatch(L.LPAREN)
            lh = io.StringIO()
            p_simpleexpr(p, lh)
            lh = lh.getvalue()
            cmp = p.chompmatch(L.CMP, 0, 'token')
            rh = io.StringIO()
            p_simpleexpr(p, rh)
            rh = rh.getvalue()
            ws = p.chompmatch(L.RPAREN, 0, 'ws')
            mkassert(f, 'int', CMP[cmp], lh, rh)
            f.write(ws)
            return
    except Parser.Error:
        pass
    # assert(a)?
    p.chompmatch(L.ASSERT)
    p.chompmatch(L.LPAREN)
    lh = io.StringIO()
    p_exprs(p, lh)
    lh = lh.getvalue()
    ws = p.chompmatch(L.RPAREN, 0, 'ws')
    mkassert(f, 'bool', 'eq', lh, 'true')
    f.write(ws)

def p_unreachable(p, f):
    # unreachable()?
    p.chompmatch(L.UNREACHABLE)
    p.chompmatch(L.LPAREN)
    ws = p.chompmatch(L.RPAREN, 0, 'ws')
    mkunreachable(f)
    f.write(ws)

def p_simpleexpr(p, f):
    while True:
        # parens
        if p.match(L.LPAREN):
            f.write(p.chomp())
            # allow terms in parens
            while True:
                p_exprs(p, f)
                if p.match(L.TERM):
                    f.write(p.chomp())
                else:
                    break
            f.write(p.chompmatch(L.RPAREN))
        # asserts
        elif p.match(L.ASSERT):
            try:
                with p.lookahead():
                    p_assert(p, f)
            except Parser.Error:
                f.write(p.chomp())
        # unreachables
        elif p.match(L.UNREACHABLE):
            try:
                with p.lookahead():
                    p_unreachable(p, f)
            except Parser.Error:
                f.write(p.chomp())
        # anything else
        elif p.match(L.STR) or p.match(L.STUFF):
            f.write(p.chomp())
        else:
            break

def p_expr(p, f):
    while True:
        p_simpleexpr(p, f)
        # continue if we hit a complex expr
        if p.match(L.CMP) or p.match(L.LOGIC) or p.match(L.TERN):
            f.write(p.chomp())
        else:
            break

def p_exprs(p, f):
    while True:
        p_expr(p, f)
        # continue if we hit a comma
        if p.match(L.COMMA):
            f.write(p.chomp())
        else:
            break

def p_stmt(p, f):
    # leading whitespace?
    if p.match(L.WS):
        f.write(p.chomp())
    # memcmp(lh,rh,size) => 0?
    if p.match(L.MEMCMP):
        try:
            with p.lookahead():
                p.chompmatch(L.MEMCMP)
                p.chompmatch(L.LPAREN)
                lh = io.StringIO()
                p_expr(p, lh)
                lh = lh.getvalue()
                p.chompmatch(L.COMMA)
                rh = io.StringIO()
                p_expr(p, rh)
                rh = rh.getvalue()
                p.chompmatch(L.COMMA)
                size = io.StringIO()
                p_expr(p, size)
                size = size.getvalue()
                p.chompmatch(L.RPAREN)
                p.chompmatch(L.ARROW)
                ws = p.chompmatch(L.ZERO, 0, 'ws')
                mkassert(f, 'mem', 'eq', lh, rh, size)
                f.write(ws)
                return
        except Parser.Error:
            pass
    # strcmp(lh,rh) => 0?
    if p.match(L.STRCMP):
        try:
            with p.lookahead():
                p.chompmatch(L.STRCMP)
                p.chompmatch(L.LPAREN)
                lh = io.StringIO()
                p_expr(p, lh)
                lh = lh.getvalue()
                p.chompmatch(L.COMMA)
                rh = io.StringIO()
                p_expr(p, rh)
                rh = rh.getvalue()
                p.chompmatch(L.RPAREN)
                p.chompmatch(L.ARROW)
                ws = p.chompmatch(L.ZERO, 0, 'ws')
                mkassert(f, 'str', 'eq', lh, rh)
                f.write(ws)
                return
        except Parser.Error:
            pass
    # lh => rh?
    lh = io.StringIO()
    p_exprs(p, lh)
    lh = lh.getvalue()
    if p.match(L.ARROW):
        p.chomp()
        rh = io.StringIO()
        p_exprs(p, rh)
        rh = rh.getvalue()
        mkassert(f, 'int', 'eq', lh, rh)
    else:
        f.write(lh)

def main(input=None, output=None, *,
        prefix=[],
        prefix_insensitive=[],
        assert_=[],
        unreachable=[],
        memcmp=[],
        strcmp=[],
        no_defaults=False,
        no_upper=False,
        no_arrows=False,
        limit=LIMIT):
    # modify lexer rules?
    if no_defaults:
        L.lex('ASSERT')
        L.lex('UNREACHABLE')
        L.lex('MEMCMP')
        L.lex('STRCMP')
    for p in prefix + prefix_insensitive:
        L.extend('ASSERT', r'\b%sassert\b' % p)
        L.extend('UNREACHABLE', r'\b%sunreachable\b' % p)
        L.extend('MEMCMP', r'\b%smemcmp\b' % p)
        L.extend('STRCMP', r'\b%sstrcmp\b' % p)
    for p in prefix_insensitive:
        L.extend('ASSERT', r'\b%sassert\b' % p.lower())
        L.extend('UNREACHABLE', r'\b%sunreachable\b' % p.lower())
        L.extend('MEMCMP', r'\b%smemcmp\b' % p.lower())
        L.extend('STRCMP', r'\b%sstrcmp\b' % p.lower())
        L.extend('ASSERT', r'\b%sASSERT\b' % p.upper())
        L.extend('UNREACHABLE', r'\b%sUNREACHABLE\b' % p.upper())
        L.extend('MEMCMP', r'\b%sMEMCMP\b' % p.upper())
        L.extend('STRCMP', r'\b%sSTRCMP\b' % p.upper())
    if assert_:
        L.extend('ASSERT', *[r'\b%s\b' % r for r in assert_])
    if unreachable:
        L.extend('UNREACHABLE', *[r'\b%s\b' % r for r in unreachable])
    if memcmp:
        L.extend('MEMCMP', *[r'\b%s\b' % r for r in memcmp])
    if strcmp:
        L.extend('STRCMP', *[r'\b%s\b' % r for r in strcmp])

    # start parsing
    with openio(input or '-', 'r') as in_f:
        p = Parser(in_f.read(), '')

    with openio(output or '-', 'w') as f:
        def writeln(s=''):
            f.write(s)
            f.write('\n')
        f.writeln = writeln

        # write extra verbose asserts
        mkheader(f, limit=limit)
        if input is not None:
            f.writeln("#line %d \"%s\"" % (1, input))

        # parse and write out stmt at a time
        try:
            while True:
                p_stmt(p, f)
                if p.match(L.TERM):
                    f.write(p.chomp())
                else:
                    break
            # trailing junk?
            if p:
                p.unexpected()
        except Parser.Error as e:
            # warn on error
            print('warning: %s' % e)
            # still write out the rest of the file so compiler
            # errors can be reported, these are usually more useful
            f.write(str(p))

if __name__ == "__main__":
    import argparse
    import sys
    parser = argparse.ArgumentParser(
            description="Preprocessor that makes asserts easier to debug.",
            allow_abbrev=False)
    parser.add_argument(
            'input',
            help="Input C file.")
    parser.add_argument(
            '-o', '--output',
            required=True,
            help="Output C file.")
    parser.add_argument(
            '-p', '--prefix',
            action='append',
            help="Additional prefixes for symbols.")
    parser.add_argument(
            '-P', '--prefix-insensitive',
            action='append',
            help="Additional prefixes for lower/upper case symbol variants.")
    parser.add_argument(
            '--assert',
            dest='assert_',
            action='append',
            help="Additional symbols for assert statements.")
    parser.add_argument(
            '--unreachable',
            action='append',
            help="Additional symbols for unreachable statements.")
    parser.add_argument(
            '--memcmp',
            action='append',
            help="Additional symbols for memcmp expressions.")
    parser.add_argument(
            '--strcmp',
            action='append',
            help="Additional symbols for strcmp expressions.")
    parser.add_argument(
            '-n', '--no-defaults',
            action='store_true',
            help="Disable default symbols.")
    parser.add_argument(
            '--no-arrows',
            action='store_true',
            help="Disable arrow (=>) expressions.")
    parser.add_argument(
            '-l', '--limit',
            type=lambda x: int(x, 0),
            default=LIMIT,
            help="Maximum number of characters to display in strcmp and "
                "memcmp. Defaults to %r." % LIMIT)
    sys.exit(main(**{k: v
            for k, v in vars(parser.parse_intermixed_args()).items()
            if v is not None}))