#!/usr/bin/env python3
#
# Preprocessor that makes asserts easier to debug.
#
# Example:
# ./scripts/prettyasserts.py -P LFS_ lfs.c -o lfs.a.c
#
# Copyright (c) 2022, The littlefs authors.
# Copyright (c) 2020, Arm Limited. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#

# prevent local imports
if __name__ == "__main__":
    __import__('sys').path.pop(0)

import io
import os
import re
import sys


# default prettyassert limit
LIMIT = 16

# comparison ops
CMP = {
    '==': 'eq',
    '!=': 'ne',
    '<=': 'le',
    '>=': 'ge',
    '<': 'lt',
    '>': 'gt',
}


# helper class for lexical regexes
class Lexer:
    """Registry of compiled token regexes, exposed as attributes.

    Each rule compiled by lex() matches one token (named group 'token')
    followed by any trailing whitespace/comments (named group 'ws'), using
    the 'WS' patterns registered at the time the rule is compiled.
    """

    def __init__(self):
        # rule name -> tuple of the alternative regexes it was built from
        self.patterns = {}

    def lex(self, k, *patterns):
        """(Re)define rule `k` as the alternation of `patterns`.

        With no patterns the rule can never match (used to disable rules).
        """
        # compile with trailing whitespace
        rule = re.compile(
            '(?P<token>%s)(?P<ws>(?:%s)*)' % (
                '|'.join(patterns) if patterns
                    # force a failure if we have no patterns
                    else '(?!)',
                '|'.join(self.patterns.get('WS', ''))),
            re.DOTALL)
        # add to class members
        setattr(self, k, rule)
        # keep track of patterns
        self.patterns[k] = patterns

    def extend(self, k, *patterns):
        """Add alternative `patterns` to rule `k`, creating it if needed."""
        self.lex(k, *self.patterns.get(k, []), *patterns)


L = Lexer()
# whitespace, preprocessor lines (with \-continuations), // and /* */
# comments; must be defined first so every later rule absorbs it
L.lex('WS', r'(?:\s|\n|#.*?(?<!\\)\n|//.*?(?<!\\)\n|/\*.*?\*/)+')
L.lex('ASSERT', r'\bassert\b')
L.lex('UNREACHABLE', r'\bunreachable\b')
L.lex('MEMCMP', r'\bmemcmp\b')
L.lex('STRCMP', r'\bstrcmp\b')
L.lex('ARROW', '=>')
L.lex('STR', r'"(?:\\.|[^"])*"', r"'(?:\\.|[^'])\'")
L.lex('LPAREN', r'\(')
L.lex('RPAREN', r'\)')
L.lex('ZERO', r'\b0\b')
L.lex('CMP', *CMP.keys())
L.lex('LOGIC', r'\&\&', r'\|\|')
L.lex('TERN', r'\?', ':')
L.lex('COMMA', ',')
L.lex('TERM', ';', r'\{', r'\}')
L.lex('STUFF', r'[^;{}?:,()"\'=!<>\-&|/#]+',
    # these need special handling because we're only
    # using regex
    '->', '>>', '<<', '-(?!>)',
    '=(?![=>])', '!(?!=)', '&(?!&)', r'\|(?!\|)',
    # a lone '/' that does not start a // or /* comment
    '/(?![/*])')


def openio(path, mode='r', buffering=-1):
    """Open `path`, treating '-' as a dup of stdin (read) or stdout (write).

    The descriptor is dup'd so closing the returned file leaves the
    process's own stdin/stdout open.
    """
    # allow '-' for stdin/stdout
    if path == '-':
        if 'r' in mode:
            return os.fdopen(os.dup(sys.stdin.fileno()), mode, buffering)
        else:
            return os.fdopen(os.dup(sys.stdout.fileno()), mode, buffering)
    else:
        return open(path, mode, buffering)


def mkheader(f, limit=LIMIT):
    """Write the generated-file banner, C includes, print helpers, and the
    __PRETTY_ASSERT_* macros that mkassert()/mkunreachable() expand to.

    `f` must provide write() and a writeln(s='') helper (main() attaches
    one). `limit` caps how many bytes __pretty_assert_mem prints before
    appending "...".
    """
    f.writeln("// Generated by %s:" % sys.argv[0])
    f.writeln("//")
    f.writeln("// %s" % ' '.join(sys.argv))
    f.writeln("//")
    f.writeln()

    # headers required by the helpers/macros emitted below
    f.writeln("#include <stdbool.h>")   # bool
    f.writeln("#include <stddef.h>")    # size_t
    f.writeln("#include <stdint.h>")    # uint8_t, intmax_t
    f.writeln("#include <inttypes.h>")  # PRIiMAX
    f.writeln("#include <stdio.h>")     # printf, fflush
    f.writeln("#include <string.h>")    # memcmp, strcmp, strlen
    # give source a chance to define feature macros
    f.writeln("#undef _FEATURES_H")
    f.writeln()

    # write print macros
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_bool(")
    f.writeln(" const void *v, size_t size) {")
    f.writeln(" (void)size;")
    f.writeln(" printf(\"%s\", *(const bool*)v ? \"true\" : \"false\");")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_int(")
    f.writeln(" const void *v, size_t size) {")
    f.writeln(" (void)size;")
    f.writeln(" printf(\"%\"PRIiMAX, *(const intmax_t*)v);")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_mem(")
    f.writeln(" const void *v, size_t size) {")
    f.writeln(" const uint8_t *v_ = v;")
    f.writeln(" printf(\"\\\"\");")
    f.writeln(" for (size_t i = 0; i < size && i < %d; i++) {" % limit)
    f.writeln(" if (v_[i] >= ' ' && v_[i] <= '~') {")
    f.writeln(" printf(\"%c\", v_[i]);")
    f.writeln(" } else {")
    f.writeln(" printf(\"\\\\x%02x\", v_[i]);")
    f.writeln(" }")
    f.writeln(" }")
    f.writeln(" if (size > %d) {" % limit)
    f.writeln(" printf(\"...\");")
    f.writeln(" }")
    f.writeln(" printf(\"\\\"\");")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_str(")
    f.writeln(" const void *v, size_t size) {")
    f.writeln(" __pretty_assert_mem(v, size);")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_print(")
    f.writeln(" const char *file, int line,")
    f.writeln(" void (*type_print_cb)(const void*, size_t),")
    f.writeln(" const char *cmp,")
    f.writeln(" const void *lh, size_t lsize,")
    f.writeln(" const void *rh, size_t rsize) {")
    f.writeln(" printf(\"%s:%d:assert: assert failed with \", file, line);")
    f.writeln(" type_print_cb(lh, lsize);")
    f.writeln(" printf(\", expected %s \", cmp);")
    f.writeln(" type_print_cb(rh, rsize);")
    f.writeln(" printf(\"\\n\");")
    f.writeln(" fflush(NULL);")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_print_unreachable(")
    f.writeln(" const char *file, int line) {")
    f.writeln(" printf(\"%s:%d:unreachable: \"")
    f.writeln(" \"unreachable statement reached\\n\", file, line);")
    f.writeln(" fflush(NULL);")
    f.writeln("}")
    f.writeln()

    # write assert macros
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_BOOL_%s(lh, rh) do { \\" % (
            cmp.upper()))
        f.writeln(" bool _lh = !!(lh); \\")
        f.writeln(" bool _rh = !!(rh); \\")
        f.writeln(" if (!(_lh %s _rh)) { \\" % op)
        f.writeln(" __pretty_assert_print( \\")
        f.writeln(" __FILE__, __LINE__, \\")
        f.writeln(" __pretty_assert_bool, \"%s\", \\" % cmp)
        f.writeln(" &_lh, 0, \\")
        f.writeln(" &_rh, 0); \\")
        f.writeln(" __builtin_trap(); \\")
        f.writeln(" } \\")
        f.writeln("} while (0)")
    f.writeln()
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_INT_%s(lh, rh) do { \\" % (
            cmp.upper()))
        f.writeln(" __typeof__(lh) _lh = lh; \\")
        f.writeln(" __typeof__(lh) _rh = rh; \\")
        f.writeln(" if (!(_lh %s _rh)) { \\" % op)
        f.writeln(" __pretty_assert_print( \\")
        f.writeln(" __FILE__, __LINE__, \\")
        f.writeln(" __pretty_assert_int, \"%s\", \\" % cmp)
        f.writeln(" &(intmax_t){(intmax_t)_lh}, 0, \\")
        f.writeln(" &(intmax_t){(intmax_t)_rh}, 0); \\")
        f.writeln(" __builtin_trap(); \\")
        f.writeln(" } \\")
        f.writeln("} while (0)")
    f.writeln()
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_MEM_%s(lh, rh, size) do { \\" % (
            cmp.upper()))
        f.writeln(" const void *_lh = lh; \\")
        f.writeln(" const void *_rh = rh; \\")
        f.writeln(" if (!(memcmp(_lh, _rh, size) %s 0)) { \\" % op)
        f.writeln(" __pretty_assert_print( \\")
        f.writeln(" __FILE__, __LINE__, \\")
        f.writeln(" __pretty_assert_mem, \"%s\", \\" % cmp)
        f.writeln(" _lh, size, \\")
        f.writeln(" _rh, size); \\")
        f.writeln(" __builtin_trap(); \\")
        f.writeln(" } \\")
        f.writeln("} while (0)")
    f.writeln()
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_STR_%s(lh, rh) do { \\" % (
            cmp.upper()))
        f.writeln(" const char *_lh = lh; \\")
        f.writeln(" const char *_rh = rh; \\")
        f.writeln(" if (!(strcmp(_lh, _rh) %s 0)) { \\" % op)
        f.writeln(" __pretty_assert_print( \\")
        f.writeln(" __FILE__, __LINE__, \\")
        f.writeln(" __pretty_assert_str, \"%s\", \\" % cmp)
        f.writeln(" _lh, strlen(_lh), \\")
        f.writeln(" _rh, strlen(_rh)); \\")
        f.writeln(" __builtin_trap(); \\")
        f.writeln(" } \\")
        f.writeln("} while (0)")
    f.writeln()
    f.writeln("#define __PRETTY_ASSERT_UNREACHABLE() do { \\")
    f.writeln(" __pretty_assert_print_unreachable( \\")
    f.writeln(" __FILE__, __LINE__); \\")
    f.writeln(" __builtin_trap(); \\")
    f.writeln("} while (0)")
    f.writeln()
    f.writeln()


def mkassert(f, type, cmp, lh, rh, size=None):
    """Write a __PRETTY_ASSERT_<TYPE>_<CMP>(lh, rh[, size]) invocation."""
    if size is not None:
        f.write("__PRETTY_ASSERT_%s_%s(%s, %s, %s)" % (
            type.upper(), cmp.upper(), lh, rh, size))
    else:
        f.write("__PRETTY_ASSERT_%s_%s(%s, %s)" % (
            type.upper(), cmp.upper(), lh, rh))


def mkunreachable(f):
    """Write a __PRETTY_ASSERT_UNREACHABLE() invocation."""
    f.write("__PRETTY_ASSERT_UNREACHABLE()")


# a simple general-purpose parser class
#
# basically just because memoryview doesn't support strs
class Parser:
    """Cursor over a string with regex matching and backtracking.

    `ws` is a regex skipped after every chomp (and once at construction).
    """

    def __init__(self, data, ws=r'\s*', ws_flags=0):
        self.data = data
        self.i = 0
        self.m = None
        # saved (i, m) states pushed by lookahead()
        self.stack = []
        # also consume whitespace
        self.ws = re.compile(ws, ws_flags)
        self.i = self.ws.match(self.data, self.i).end()

    def __repr__(self):
        if len(self.data) - self.i <= 32:
            return repr(self.data[self.i:])
        else:
            return "%s..." % repr(self.data[self.i:self.i+32])[:32]

    def __str__(self):
        return self.data[self.i:]

    def __len__(self):
        return len(self.data) - self.i

    def __bool__(self):
        return self.i != len(self.data)

    def match(self, pattern, flags=0):
        """Try `pattern` at the cursor; remember and return the match."""
        # compile so we can use the pos arg, this is still cached
        self.m = re.compile(pattern, flags).match(self.data, self.i)
        return self.m

    def group(self, *groups):
        return self.m.group(*groups)

    def chomp(self, *groups):
        """Consume the last match (plus `ws`); return the requested groups."""
        g = self.group(*groups)
        self.i = self.m.end()
        # also consume whitespace
        self.i = self.ws.match(self.data, self.i).end()
        return g

    class Error(Exception):
        pass

    def chompmatch(self, pattern, flags=0, *groups):
        """match() + chomp(), raising Parser.Error if there is no match."""
        if not self.match(pattern, flags):
            raise Parser.Error("expected %r, found %r" % (pattern, self))
        return self.chomp(*groups)

    def unexpected(self):
        raise Parser.Error("unexpected %r" % self)

    def lookahead(self):
        """Save the cursor; use as `with p.lookahead():` to backtrack on
        exceptions and keep progress otherwise."""
        # push state on the stack
        self.stack.append((self.i, self.m))
        return self

    def consume(self):
        # pop and use new state
        self.stack.pop()

    def discard(self):
        # pop and discard new state
        self.i, self.m = self.stack.pop()

    def __enter__(self):
        return self

    def __exit__(self, et, ev, tb):
        # keep new state if no exception occurred; the exception itself
        # still propagates (we return None)
        if et is None:
            self.consume()
        else:
            self.discard()


# parse rules

def _capture(rule, p):
    """Run parse rule `rule` on `p` into a scratch buffer; return its text."""
    buf = io.StringIO()
    rule(p, buf)
    return buf.getvalue()


def p_assert(p, f):
    """Rewrite one assert(...) call into the matching pretty-assert macro.

    Tries memcmp, strcmp, and binary-comparison forms in turn (each under
    lookahead so a failed attempt backtracks), then falls back to a plain
    boolean assert. Raises Parser.Error if even that does not parse.
    """
    # assert(memcmp(a,b,size) cmp 0)?
    try:
        with p.lookahead():
            p.chompmatch(L.ASSERT)
            p.chompmatch(L.LPAREN)
            p.chompmatch(L.MEMCMP)
            p.chompmatch(L.LPAREN)
            lh = _capture(p_expr, p)
            p.chompmatch(L.COMMA)
            rh = _capture(p_expr, p)
            p.chompmatch(L.COMMA)
            size = _capture(p_expr, p)
            p.chompmatch(L.RPAREN)
            cmp = p.chompmatch(L.CMP, 0, 'token')
            p.chompmatch(L.ZERO)
            ws = p.chompmatch(L.RPAREN, 0, 'ws')
            mkassert(f, 'mem', CMP[cmp], lh, rh, size)
            f.write(ws)
            return
    except Parser.Error:
        pass

    # assert(strcmp(a,b) cmp 0)?
    try:
        with p.lookahead():
            p.chompmatch(L.ASSERT)
            p.chompmatch(L.LPAREN)
            p.chompmatch(L.STRCMP)
            p.chompmatch(L.LPAREN)
            lh = _capture(p_expr, p)
            p.chompmatch(L.COMMA)
            rh = _capture(p_expr, p)
            p.chompmatch(L.RPAREN)
            cmp = p.chompmatch(L.CMP, 0, 'token')
            p.chompmatch(L.ZERO)
            ws = p.chompmatch(L.RPAREN, 0, 'ws')
            mkassert(f, 'str', CMP[cmp], lh, rh)
            f.write(ws)
            return
    except Parser.Error:
        pass

    # assert(a cmp b)?
    try:
        with p.lookahead():
            p.chompmatch(L.ASSERT)
            p.chompmatch(L.LPAREN)
            lh = _capture(p_simpleexpr, p)
            cmp = p.chompmatch(L.CMP, 0, 'token')
            rh = _capture(p_simpleexpr, p)
            ws = p.chompmatch(L.RPAREN, 0, 'ws')
            mkassert(f, 'int', CMP[cmp], lh, rh)
            f.write(ws)
            return
    except Parser.Error:
        pass

    # assert(a)?
    p.chompmatch(L.ASSERT)
    p.chompmatch(L.LPAREN)
    lh = _capture(p_exprs, p)
    ws = p.chompmatch(L.RPAREN, 0, 'ws')
    mkassert(f, 'bool', 'eq', lh, 'true')
    f.write(ws)


def p_unreachable(p, f):
    """Rewrite unreachable() into __PRETTY_ASSERT_UNREACHABLE()."""
    # unreachable()?
    p.chompmatch(L.UNREACHABLE)
    p.chompmatch(L.LPAREN)
    ws = p.chompmatch(L.RPAREN, 0, 'ws')
    mkunreachable(f)
    f.write(ws)


def p_simpleexpr(p, f):
    """Copy an operand (no top-level cmp/logic/ternary/comma) to `f`,
    rewriting any nested asserts/unreachables."""
    while True:
        # parens
        if p.match(L.LPAREN):
            f.write(p.chomp())

            # allow terms in parens
            while True:
                p_exprs(p, f)
                if p.match(L.TERM):
                    f.write(p.chomp())
                else:
                    break
            f.write(p.chompmatch(L.RPAREN))

        # asserts
        elif p.match(L.ASSERT):
            try:
                with p.lookahead():
                    p_assert(p, f)
            except Parser.Error:
                # not a call we understand, pass the symbol through
                f.write(p.chomp())

        # unreachables
        elif p.match(L.UNREACHABLE):
            try:
                with p.lookahead():
                    p_unreachable(p, f)
            except Parser.Error:
                f.write(p.chomp())

        # anything else
        elif p.match(L.STR) or p.match(L.STUFF):
            f.write(p.chomp())

        else:
            break


def p_expr(p, f):
    """Copy operands joined by comparison/logic/ternary operators."""
    while True:
        p_simpleexpr(p, f)

        # continue if we hit a complex expr
        if p.match(L.CMP) or p.match(L.LOGIC) or p.match(L.TERN):
            f.write(p.chomp())
        else:
            break


def p_exprs(p, f):
    """Copy a comma-separated list of expressions."""
    while True:
        p_expr(p, f)

        # continue if we hit a comma
        if p.match(L.COMMA):
            f.write(p.chomp())
        else:
            break


def p_stmt(p, f):
    """Copy one statement (up to, not including, its terminator) to `f`.

    Statements of the form `memcmp(a,b,n) => 0`, `strcmp(a,b) => 0` and
    `lh => rh` are rewritten into equality pretty-asserts; anything else
    is passed through with nested asserts rewritten.
    """
    # leading whitespace?
    if p.match(L.WS):
        f.write(p.chomp())

    # memcmp(lh,rh,size) => 0?
    if p.match(L.MEMCMP):
        try:
            with p.lookahead():
                p.chompmatch(L.MEMCMP)
                p.chompmatch(L.LPAREN)
                lh = _capture(p_expr, p)
                p.chompmatch(L.COMMA)
                rh = _capture(p_expr, p)
                p.chompmatch(L.COMMA)
                size = _capture(p_expr, p)
                p.chompmatch(L.RPAREN)
                p.chompmatch(L.ARROW)
                ws = p.chompmatch(L.ZERO, 0, 'ws')
                mkassert(f, 'mem', 'eq', lh, rh, size)
                f.write(ws)
                return
        except Parser.Error:
            pass

    # strcmp(lh,rh) => 0?
    if p.match(L.STRCMP):
        try:
            with p.lookahead():
                p.chompmatch(L.STRCMP)
                p.chompmatch(L.LPAREN)
                lh = _capture(p_expr, p)
                p.chompmatch(L.COMMA)
                rh = _capture(p_expr, p)
                p.chompmatch(L.RPAREN)
                p.chompmatch(L.ARROW)
                ws = p.chompmatch(L.ZERO, 0, 'ws')
                mkassert(f, 'str', 'eq', lh, rh)
                f.write(ws)
                return
        except Parser.Error:
            pass

    # lh => rh?
    lh = _capture(p_exprs, p)
    if p.match(L.ARROW):
        p.chomp()
        rh = _capture(p_exprs, p)
        mkassert(f, 'int', 'eq', lh, rh)
    else:
        f.write(lh)


def main(input=None, output=None, *,
        prefix=(),
        prefix_insensitive=(),
        assert_=(),
        unreachable=(),
        memcmp=(),
        strcmp=(),
        no_defaults=False,
        no_upper=False,
        no_arrows=False,
        limit=LIMIT):
    """Rewrite asserts in C file `input` (default '-', stdin) into
    pretty-assert macros, writing the result to `output` (default '-').

    prefix/prefix_insensitive/assert_/unreachable/memcmp/strcmp add
    symbols (regex fragments) to the lexer; no_defaults drops the default
    assert/unreachable/memcmp/strcmp symbols; no_arrows disables `=>`
    statements. no_upper is accepted for compatibility but has no effect.

    NOTE: this mutates the module-level lexer `L`, so rule changes persist
    across calls in the same process.
    """
    # modify lexer rules?
    if no_defaults:
        # no patterns => rule can never match
        L.lex('ASSERT')
        L.lex('UNREACHABLE')
        L.lex('MEMCMP')
        L.lex('STRCMP')
    if no_arrows:
        L.lex('ARROW')
    for pre in [*prefix, *prefix_insensitive]:
        L.extend('ASSERT', r'\b%sassert\b' % pre)
        L.extend('UNREACHABLE', r'\b%sunreachable\b' % pre)
        L.extend('MEMCMP', r'\b%smemcmp\b' % pre)
        L.extend('STRCMP', r'\b%sstrcmp\b' % pre)
    for pre in prefix_insensitive:
        L.extend('ASSERT', r'\b%sassert\b' % pre.lower())
        L.extend('UNREACHABLE', r'\b%sunreachable\b' % pre.lower())
        L.extend('MEMCMP', r'\b%smemcmp\b' % pre.lower())
        L.extend('STRCMP', r'\b%sstrcmp\b' % pre.lower())
        L.extend('ASSERT', r'\b%sASSERT\b' % pre.upper())
        L.extend('UNREACHABLE', r'\b%sUNREACHABLE\b' % pre.upper())
        L.extend('MEMCMP', r'\b%sMEMCMP\b' % pre.upper())
        L.extend('STRCMP', r'\b%sSTRCMP\b' % pre.upper())
    if assert_:
        L.extend('ASSERT', *[r'\b%s\b' % r for r in assert_])
    if unreachable:
        L.extend('UNREACHABLE', *[r'\b%s\b' % r for r in unreachable])
    if memcmp:
        L.extend('MEMCMP', *[r'\b%s\b' % r for r in memcmp])
    if strcmp:
        L.extend('STRCMP', *[r'\b%s\b' % r for r in strcmp])

    # start parsing
    with openio(input or '-', 'r') as in_f:
        p = Parser(in_f.read(), '')

    with openio(output or '-', 'w') as f:
        def writeln(s=''):
            f.write(s)
            f.write('\n')
        f.writeln = writeln

        # write extra verbose asserts
        mkheader(f, limit=limit)
        if input is not None:
            f.writeln("#line %d \"%s\"" % (1, input))

        # parse and write out stmt at a time
        try:
            while True:
                p_stmt(p, f)
                if p.match(L.TERM):
                    f.write(p.chomp())
                else:
                    break

            # trailing junk?
            if p:
                p.unexpected()

        except Parser.Error as e:
            # warn on error; stderr so it can't corrupt `-o -` output
            print('warning: %s' % e, file=sys.stderr)
            # still write out the rest of the file so compiler
            # errors can be reported, these are usually more useful
            f.write(str(p))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description="Preprocessor that makes asserts easier to debug.",
        allow_abbrev=False)
    parser.add_argument(
        'input',
        help="Input C file.")
    parser.add_argument(
        '-o', '--output',
        required=True,
        help="Output C file.")
    parser.add_argument(
        '-p', '--prefix',
        action='append',
        help="Additional prefixes for symbols.")
    parser.add_argument(
        '-P', '--prefix-insensitive',
        action='append',
        help="Additional prefixes for lower/upper case symbol variants.")
    parser.add_argument(
        '--assert',
        dest='assert_',
        action='append',
        help="Additional symbols for assert statements.")
    parser.add_argument(
        '--unreachable',
        action='append',
        help="Additional symbols for unreachable statements.")
    parser.add_argument(
        '--memcmp',
        action='append',
        help="Additional symbols for memcmp expressions.")
    parser.add_argument(
        '--strcmp',
        action='append',
        help="Additional symbols for strcmp expressions.")
    parser.add_argument(
        '-n', '--no-defaults',
        action='store_true',
        help="Disable default symbols.")
    parser.add_argument(
        '--no-arrows',
        action='store_true',
        help="Disable arrow (=>) expressions.")
    parser.add_argument(
        '-l', '--limit',
        type=lambda x: int(x, 0),
        default=LIMIT,
        help="Maximum number of characters to display in strcmp and "
            "memcmp. Defaults to %r." % LIMIT)
    sys.exit(main(**{k: v
        for k, v in vars(parser.parse_intermixed_args()).items()
        if v is not None}))