Files
littlefs/scripts/prettyasserts.py
Christopher Haster 71930a5c01 scripts: Tweaked openio comment
Dang, this touched like every single script.
2025-04-16 15:23:06 -05:00

672 lines
22 KiB
Python
Executable File

#!/usr/bin/env python3
#
# Preprocessor that makes asserts easier to debug.
#
# Example:
# ./scripts/prettyasserts.py -P LFS_ lfs.c -o lfs.a.c
#
# Copyright (c) 2022, The littlefs authors.
# Copyright (c) 2020, Arm Limited. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# prevent local imports
# when executed directly, sys.path[0] is the script's own directory;
# removing it keeps sibling files from shadowing stdlib modules
if __name__ == "__main__":
    del __import__('sys').path[0]
import io
import re
import sys
# default prettyassert limit
LIMIT = 16
# comparison ops
CMP = {
'==': 'eq',
'!=': 'ne',
'<=': 'le',
'>=': 'ge',
'<': 'lt',
'>': 'gt',
}
# helper class for lexical regexes
class Lexer:
    """Registry of named token regexes.

    Each rule registered with lex() is compiled (with re.DOTALL) into a
    pattern with two named groups: 'token', an alternation of the given
    patterns, and 'ws', zero or more repetitions of the 'WS' rule's
    patterns as registered at the moment of compilation. The compiled
    pattern is stored as an attribute named after the rule.
    """

    def __init__(self):
        # rule name -> tuple of raw pattern strings
        self.patterns = {}

    def lex(self, k, *patterns):
        """(Re)define rule `k`; with no patterns the rule never matches."""
        token = '|'.join(patterns) if patterns else '(?!)'
        trailing = '|'.join(self.patterns.get('WS', ''))
        compiled = re.compile(
            '(?P<token>%s)(?P<ws>(?:%s)*)' % (token, trailing),
            re.DOTALL)
        setattr(self, k, compiled)
        self.patterns[k] = patterns

    def extend(self, k, *patterns):
        """Append patterns to rule `k`, creating it if needed."""
        previous = self.patterns.get(k, [])
        self.lex(k, *previous, *patterns)
# module-level lexer; main() may redefine or extend rules from the command
# line. All regexes are raw strings: several previously relied on invalid
# escape sequences such as '\(' which Python 3.12+ reports as SyntaxWarning.
L = Lexer()
# whitespace, preprocessor lines, and C/C++ comments
L.lex('WS', r'(?:\s|\n|#.*?(?<!\\)\n|//.*?(?<!\\)\n|/\*.*?\*/)+')
L.lex('ASSERT', r'\bassert\b', r'\b__builtin_assert\b')
L.lex('UNREACHABLE', r'\bunreachable\b', r'\b__builtin_unreachable\b')
L.lex('MEMCMP', r'\bmemcmp\b', r'\b__builtin_memcmp\b')
L.lex('STRCMP', r'\bstrcmp\b', r'\b__builtin_strcmp\b')
L.lex('ARROW', '=>')
# string and character literals
L.lex('STR', r'"(?:\\.|[^"])*"', r"'(?:\\.|[^'])\'")
L.lex('LPAREN', r'\(')
L.lex('RPAREN', r'\)')
L.lex('ZERO', r'\b0\b')
# CMP lists two-character operators first, so the alternation tries
# '<=' before '<'
L.lex('CMP', *CMP.keys())
L.lex('LOGIC', r'\&\&', r'\|\|')
L.lex('TERN', r'\?', ':')
L.lex('COMMA', ',')
L.lex('TERM', ';', r'\{', r'\}')
# catch-all for everything else: characters that can start another token
# are excluded from the class, then the alternatives below re-admit the
# safe multi-character forms and the lone characters when they don't begin
# another token ('=' not in '==' or '=>', '&' not in '&&', ...)
L.lex('STUFF', r'''[^;{}?:,()"'=!<>\-&|/#]+''',
    # these need special handling because we're only
    # using regex
    '->', '>>', '<<', '-(?!>)',
    '=(?![=>])', '!(?!=)', '&(?!&)', r'\|(?!\|)',
    # NOTE(review): as alternatives these two together match any '/',
    # including the start of '//' or '/*'; comments are normally consumed
    # as trailing 'ws' of the previous token so this rarely matters
    '/(?!/)', r'/(?!\*)')
# open with '-' for stdin/stdout
def openio(path, mode='r', buffering=-1):
    """Open `path`, with '-' meaning stdin (read modes) or stdout (others).

    For '-', the standard stream's file descriptor is duplicated, so closing
    the returned file object leaves sys.stdin/sys.stdout open.
    """
    import os
    if path != '-':
        return open(path, mode, buffering)
    stream = sys.stdin if 'r' in mode else sys.stdout
    return os.fdopen(os.dup(stream.fileno()), mode, buffering)
def mkheader(f, limit=LIMIT):
    """Write the C prelude that the rewritten source depends on.

    Emits, through f.writeln, in order: a "Generated by" banner echoing
    sys.argv, standard includes, static print helpers
    (__pretty_assert_bool/int/mem/str/print/print_unreachable), one
    __PRETTY_ASSERT_{BOOL,INT,MEM,STR}_<cmp> macro per entry in CMP, and
    __PRETTY_ASSERT_UNREACHABLE. A failing macro prints a diagnostic and
    then calls __builtin_trap().

    f     -- writable object with a writeln(s='') method (main() attaches
             one to the output file object)
    limit -- maximum number of bytes __pretty_assert_mem prints before
             appending "..."
    """
    f.writeln("// Generated by %s:" % sys.argv[0])
    f.writeln("//")
    f.writeln("// %s" % ' '.join(sys.argv))
    f.writeln("//")
    f.writeln()
    f.writeln("#include <stdbool.h>")
    f.writeln("#include <stdint.h>")
    f.writeln("#include <inttypes.h>")
    f.writeln("#include <stdio.h>")
    f.writeln("#include <string.h>")
    # give source a chance to define feature macros
    f.writeln("#undef _FEATURES_H")
    f.writeln()

    # write print helpers; __attribute__((unused)) avoids unused-function
    # warnings in translation units that never fail an assert of that type
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_bool(")
    f.writeln(" const void *v, size_t size) {")
    f.writeln(" (void)size;")
    f.writeln(" printf(\"%s\", *(const bool*)v ? \"true\" : \"false\");")
    f.writeln("}")
    f.writeln()
    # ints are printed as intmax_t, the INT macros below widen both
    # operands to intmax_t before passing them here
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_int(")
    f.writeln(" const void *v, size_t size) {")
    f.writeln(" (void)size;")
    f.writeln(" printf(\"%\"PRIiMAX, *(const intmax_t*)v);")
    f.writeln("}")
    f.writeln()
    # bytes are printed quoted: printable ASCII (' '..'~') verbatim,
    # anything else as \xNN, at most `limit` bytes followed by "..."
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_mem(")
    f.writeln(" const void *v, size_t size) {")
    f.writeln(" const uint8_t *v_ = v;")
    f.writeln(" printf(\"\\\"\");")
    f.writeln(" for (size_t i = 0; i < size && i < %d; i++) {" % limit)
    f.writeln(" if (v_[i] >= ' ' && v_[i] <= '~') {")
    f.writeln(" printf(\"%c\", v_[i]);")
    f.writeln(" } else {")
    f.writeln(" printf(\"\\\\x%02x\", v_[i]);")
    f.writeln(" }")
    f.writeln(" }")
    f.writeln(" if (size > %d) {" % limit)
    f.writeln(" printf(\"...\");")
    f.writeln(" }")
    f.writeln(" printf(\"\\\"\");")
    f.writeln("}")
    f.writeln()
    # strings reuse the byte printer, the STR macros pass strlen() as size
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_str(")
    f.writeln(" const void *v, size_t size) {")
    f.writeln(" __pretty_assert_mem(v, size);")
    f.writeln("}")
    f.writeln()
    # common failure report: "<file>:<line>:assert: assert failed with
    # <lh>, expected <cmp> <rh>", operands printed by type_print_cb
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_print(")
    f.writeln(" const char *file, int line,")
    f.writeln(" void (*type_print_cb)(const void*, size_t),")
    f.writeln(" const char *cmp,")
    f.writeln(" const void *lh, size_t lsize,")
    f.writeln(" const void *rh, size_t rsize) {")
    f.writeln(" printf(\"%s:%d:assert: assert failed with \", file, line);")
    f.writeln(" type_print_cb(lh, lsize);")
    f.writeln(" printf(\", expected %s \", cmp);")
    f.writeln(" type_print_cb(rh, rsize);")
    f.writeln(" printf(\"\\n\");")
    f.writeln(" fflush(NULL);")
    f.writeln("}")
    f.writeln()
    f.writeln("__attribute__((unused))")
    f.writeln("static void __pretty_assert_print_unreachable(")
    f.writeln(" const char *file, int line) {")
    f.writeln(" printf(\"%s:%d:unreachable: \"")
    f.writeln(" \"unreachable statement reached\\n\", file, line);")
    f.writeln(" fflush(NULL);")
    f.writeln("}")
    f.writeln()

    # write assert macros, one per (type, comparison) pair, e.g.
    # __PRETTY_ASSERT_INT_EQ; sorted() gives a deterministic order
    for op, cmp in sorted(CMP.items()):
        # bools are normalized with !! so any truthy value compares equal
        f.writeln("#define __PRETTY_ASSERT_BOOL_%s(lh, rh) do { \\" % (
            cmp.upper()))
        f.writeln(" bool _lh = !!(lh); \\")
        f.writeln(" bool _rh = !!(rh); \\")
        f.writeln(" if (!(_lh %s _rh)) { \\" % op)
        f.writeln(" __pretty_assert_print( \\")
        f.writeln(" __FILE__, __LINE__, \\")
        f.writeln(" __pretty_assert_bool, \"%s\", \\" % cmp)
        f.writeln(" &_lh, 0, \\")
        f.writeln(" &_rh, 0); \\")
        f.writeln(" __builtin_trap(); \\")
        f.writeln(" } \\")
        f.writeln("} while (0)")
        f.writeln()
    for op, cmp in sorted(CMP.items()):
        # NOTE(review): rh is stored with lh's type (__typeof__(lh)),
        # presumably so e.g. `ptr == 0` compiles without pointer/integer
        # warnings; a wider rh is converted to lh's type before comparing
        f.writeln("#define __PRETTY_ASSERT_INT_%s(lh, rh) do { \\" % (
            cmp.upper()))
        f.writeln(" __typeof__(lh) _lh = lh; \\")
        f.writeln(" __typeof__(lh) _rh = rh; \\")
        f.writeln(" if (!(_lh %s _rh)) { \\" % op)
        f.writeln(" __pretty_assert_print( \\")
        f.writeln(" __FILE__, __LINE__, \\")
        f.writeln(" __pretty_assert_int, \"%s\", \\" % cmp)
        f.writeln(" &(intmax_t){(intmax_t)_lh}, 0, \\")
        f.writeln(" &(intmax_t){(intmax_t)_rh}, 0); \\")
        f.writeln(" __builtin_trap(); \\")
        f.writeln(" } \\")
        f.writeln("} while (0)")
        f.writeln()
    for op, cmp in sorted(CMP.items()):
        # NOTE(review): `size` is expanded up to three times (memcmp and
        # both print sizes), so it is re-evaluated on the failure path
        f.writeln("#define __PRETTY_ASSERT_MEM_%s(lh, rh, size) do { \\" % (
            cmp.upper()))
        f.writeln(" const void *_lh = lh; \\")
        f.writeln(" const void *_rh = rh; \\")
        f.writeln(" if (!(memcmp(_lh, _rh, size) %s 0)) { \\" % op)
        f.writeln(" __pretty_assert_print( \\")
        f.writeln(" __FILE__, __LINE__, \\")
        f.writeln(" __pretty_assert_mem, \"%s\", \\" % cmp)
        f.writeln(" _lh, size, \\")
        f.writeln(" _rh, size); \\")
        f.writeln(" __builtin_trap(); \\")
        f.writeln(" } \\")
        f.writeln("} while (0)")
        f.writeln()
    for op, cmp in sorted(CMP.items()):
        f.writeln("#define __PRETTY_ASSERT_STR_%s(lh, rh) do { \\" % (
            cmp.upper()))
        f.writeln(" const char *_lh = lh; \\")
        f.writeln(" const char *_rh = rh; \\")
        f.writeln(" if (!(strcmp(_lh, _rh) %s 0)) { \\" % op)
        f.writeln(" __pretty_assert_print( \\")
        f.writeln(" __FILE__, __LINE__, \\")
        f.writeln(" __pretty_assert_str, \"%s\", \\" % cmp)
        f.writeln(" _lh, strlen(_lh), \\")
        f.writeln(" _rh, strlen(_rh)); \\")
        f.writeln(" __builtin_trap(); \\")
        f.writeln(" } \\")
        f.writeln("} while (0)")
        f.writeln()
    # unreachable always reports and traps
    f.writeln("#define __PRETTY_ASSERT_UNREACHABLE() do { \\")
    f.writeln(" __pretty_assert_print_unreachable( \\")
    f.writeln(" __FILE__, __LINE__); \\")
    f.writeln(" __builtin_trap(); \\")
    f.writeln("} while (0)")
    f.writeln()
    f.writeln()
def mkassert(f, type, cmp, lh, rh, size=None):
    """Write a __PRETTY_ASSERT_<TYPE>_<CMP>(lh, rh[, size]) invocation.

    `size` is appended as a third argument only when it is not None (the
    MEM macros take one, the others do not).
    """
    args = [lh, rh] if size is None else [lh, rh, size]
    f.write("__PRETTY_ASSERT_%s_%s(%s)" % (
        type.upper(), cmp.upper(), ', '.join(map(str, args))))
def mkunreachable(f):
    """Write an invocation of the generated unreachable macro."""
    macro = "__PRETTY_ASSERT_UNREACHABLE()"
    f.write(macro)
# a simple general-purpose parser class
#
# basically just because memoryview doesn't support strs
class Parser:
    """Regex-driven cursor over an in-memory string, with backtracking.

    `i` is the current offset into `data` and `m` the most recent match.
    After construction and after every chomp(), whatever the `ws` regex
    matches at the cursor is skipped. lookahead() snapshots (i, m); used
    as a context manager the snapshot is dropped on normal exit or
    restored when an exception escapes, and the exception still propagates.
    """

    class Error(Exception):
        """Raised when the input does not match what a rule expected."""
        pass

    def __init__(self, data, ws=r'\s*', ws_flags=0):
        self.data = data
        self.i = 0
        self.m = None
        # saved (i, m) pairs, one per active lookahead
        self.stack = []
        # also consume whitespace
        self.ws = re.compile(ws, ws_flags)
        self._skip_ws()

    def _skip_ws(self):
        # advance past whitespace; a ws regex that cannot match here
        # (e.g. r'\s+' with no whitespace) skips nothing instead of
        # failing with AttributeError on None
        m = self.ws.match(self.data, self.i)
        if m is not None:
            self.i = m.end()

    def __repr__(self):
        # show at most ~32 characters of the remaining input
        if len(self.data) - self.i <= 32:
            return repr(self.data[self.i:])
        else:
            return "%s..." % repr(self.data[self.i:self.i+32])[:32]

    def __str__(self):
        return self.data[self.i:]

    def __len__(self):
        return len(self.data) - self.i

    def __bool__(self):
        return self.i != len(self.data)

    def match(self, pattern, flags=0):
        """Try `pattern` at the cursor without advancing; remember the match."""
        # compile so we can use the pos arg, this is still cached
        self.m = re.compile(pattern, flags).match(self.data, self.i)
        return self.m

    def group(self, *groups):
        return self.m.group(*groups)

    def chomp(self, *groups):
        """Return groups of the last match and advance past it (and ws)."""
        g = self.group(*groups)
        self.i = self.m.end()
        self._skip_ws()
        return g

    def chompmatch(self, pattern, flags=0, *groups):
        """match() then chomp(); raise Parser.Error if it does not match."""
        if not self.match(pattern, flags):
            raise Parser.Error("expected %r, found %r" % (pattern, self))
        return self.chomp(*groups)

    def unexpected(self):
        raise Parser.Error("unexpected %r" % self)

    def lookahead(self):
        # push state on the stack
        self.stack.append((self.i, self.m))
        return self

    def consume(self):
        # pop and use new state
        self.stack.pop()

    def discard(self):
        # pop and discard new state
        self.i, self.m = self.stack.pop()

    def __enter__(self):
        return self

    def __exit__(self, et, ev, tb):
        # keep new state if no exception occurred
        if et is None:
            self.consume()
        else:
            self.discard()
# parse rules
def p_assert(p, f):
    """Rewrite one assert(...) call at the cursor into a pretty-assert macro.

    Forms tried, in order:
        assert(memcmp(a, b, size) cmp 0)  -> __PRETTY_ASSERT_MEM_<CMP>
        assert(strcmp(a, b) cmp 0)        -> __PRETTY_ASSERT_STR_<CMP>
        assert(a cmp b)                   -> __PRETTY_ASSERT_INT_<CMP>
        assert(a)                         -> __PRETTY_ASSERT_BOOL_EQ(a, true)
    The first three are attempted under a lookahead, so a failed attempt
    rewinds the parser. The final form is not rewound here; if it fails,
    Parser.Error propagates to the caller. Nothing is written to f unless
    a form parses completely.
    """
    def capture(rule):
        # run a sub-rule into a scratch buffer and return its text
        buf = io.StringIO()
        rule(p, buf)
        return buf.getvalue()

    def mem_form():
        for tok in (L.ASSERT, L.LPAREN, L.MEMCMP, L.LPAREN):
            p.chompmatch(tok)
        lh = capture(p_expr)
        p.chompmatch(L.COMMA)
        rh = capture(p_expr)
        p.chompmatch(L.COMMA)
        size = capture(p_expr)
        p.chompmatch(L.RPAREN)
        op = p.chompmatch(L.CMP, 0, 'token')
        p.chompmatch(L.ZERO)
        trailing = p.chompmatch(L.RPAREN, 0, 'ws')
        return ('mem', CMP[op], lh, rh, size), trailing

    def str_form():
        for tok in (L.ASSERT, L.LPAREN, L.STRCMP, L.LPAREN):
            p.chompmatch(tok)
        lh = capture(p_expr)
        p.chompmatch(L.COMMA)
        rh = capture(p_expr)
        p.chompmatch(L.RPAREN)
        op = p.chompmatch(L.CMP, 0, 'token')
        p.chompmatch(L.ZERO)
        trailing = p.chompmatch(L.RPAREN, 0, 'ws')
        return ('str', CMP[op], lh, rh), trailing

    def cmp_form():
        for tok in (L.ASSERT, L.LPAREN):
            p.chompmatch(tok)
        lh = capture(p_simpleexpr)
        op = p.chompmatch(L.CMP, 0, 'token')
        rh = capture(p_simpleexpr)
        trailing = p.chompmatch(L.RPAREN, 0, 'ws')
        return ('int', CMP[op], lh, rh), trailing

    for form in (mem_form, str_form, cmp_form):
        try:
            with p.lookahead():
                args, trailing = form()
        except Parser.Error:
            continue
        mkassert(f, *args)
        f.write(trailing)
        return

    # assert(a): fall back to a plain truthiness check
    for tok in (L.ASSERT, L.LPAREN):
        p.chompmatch(tok)
    cond = capture(p_exprs)
    trailing = p.chompmatch(L.RPAREN, 0, 'ws')
    mkassert(f, 'bool', 'eq', cond, 'true')
    f.write(trailing)
def p_unreachable(p, f):
    """Rewrite a zero-argument `unreachable()` call into the unreachable macro.

    Raises Parser.Error (from chompmatch) if the input is not exactly that
    call; whitespace after the closing paren is copied through.
    """
    for tok in (L.UNREACHABLE, L.LPAREN):
        p.chompmatch(tok)
    trailing = p.chompmatch(L.RPAREN, 0, 'ws')
    mkunreachable(f)
    f.write(trailing)
def p_simpleexpr(p, f):
    """Copy a run of operands, stopping at the first unrecognized token.

    Parenthesized groups are copied recursively and may contain terms
    (';', '{', '}'), presumably for GNU statement expressions. Nested
    assert/unreachable calls are rewritten when possible. String/char
    literals and other STUFF tokens are copied verbatim. Comparison,
    logical, ternary and comma operators are left for the callers.
    """
    def attempt(rule):
        # try to rewrite with `rule`; on failure the lookahead restores
        # p.m to the keyword match, so the keyword alone is copied through
        try:
            with p.lookahead():
                rule(p, f)
        except Parser.Error:
            f.write(p.chomp())

    while True:
        if p.match(L.LPAREN):
            f.write(p.chomp())
            p_exprs(p, f)
            while p.match(L.TERM):
                f.write(p.chomp())
                p_exprs(p, f)
            f.write(p.chompmatch(L.RPAREN))
        elif p.match(L.ASSERT):
            attempt(p_assert)
        elif p.match(L.UNREACHABLE):
            attempt(p_unreachable)
        elif p.match(L.STR) or p.match(L.STUFF):
            f.write(p.chomp())
        else:
            return
def p_expr(p, f):
    """Copy operands joined by comparison, logical, or ternary operators."""
    p_simpleexpr(p, f)
    while p.match(L.CMP) or p.match(L.LOGIC) or p.match(L.TERN):
        f.write(p.chomp())
        p_simpleexpr(p, f)
def p_exprs(p, f):
    """Copy a comma-separated sequence of expressions."""
    p_expr(p, f)
    while p.match(L.COMMA):
        f.write(p.chomp())
        p_expr(p, f)
def p_stmt(p, f):
    """Copy one statement to f, expanding arrow assertions into macros.

    Forms recognized at the start of a statement:
        memcmp(lh, rh, size) => 0  -> __PRETTY_ASSERT_MEM_EQ(lh, rh, size)
        strcmp(lh, rh) => 0        -> __PRETTY_ASSERT_STR_EQ(lh, rh)
        lh => rh                   -> __PRETTY_ASSERT_INT_EQ(lh, rh)
    Anything else is copied through (nested asserts are still rewritten
    by p_simpleexpr). The statement terminator is left for the caller.
    """
    # leading whitespace?
    if p.match(L.WS):
        f.write(p.chomp())

    # memcmp(lh,rh,size) => 0?
    if p.match(L.MEMCMP):
        try:
            with p.lookahead():
                p.chompmatch(L.MEMCMP)
                p.chompmatch(L.LPAREN)
                lh = io.StringIO()
                p_expr(p, lh)
                lh = lh.getvalue()
                p.chompmatch(L.COMMA)
                rh = io.StringIO()
                p_expr(p, rh)
                rh = rh.getvalue()
                p.chompmatch(L.COMMA)
                size = io.StringIO()
                p_expr(p, size)
                size = size.getvalue()
                p.chompmatch(L.RPAREN)
                p.chompmatch(L.ARROW)
                ws = p.chompmatch(L.ZERO, 0, 'ws')
                mkassert(f, 'mem', 'eq', lh, rh, size)
                f.write(ws)
                return
        # was `Parse.Error`, an undefined name: any non-arrow memcmp at
        # statement start raised NameError instead of falling through
        except Parser.Error:
            pass

    # strcmp(lh,rh) => 0?
    if p.match(L.STRCMP):
        try:
            with p.lookahead():
                p.chompmatch(L.STRCMP)
                p.chompmatch(L.LPAREN)
                lh = io.StringIO()
                p_expr(p, lh)
                lh = lh.getvalue()
                p.chompmatch(L.COMMA)
                rh = io.StringIO()
                p_expr(p, rh)
                rh = rh.getvalue()
                p.chompmatch(L.RPAREN)
                p.chompmatch(L.ARROW)
                ws = p.chompmatch(L.ZERO, 0, 'ws')
                mkassert(f, 'str', 'eq', lh, rh)
                f.write(ws)
                return
        except Parser.Error:
            pass

    # lh => rh?
    lh = io.StringIO()
    p_exprs(p, lh)
    lh = lh.getvalue()
    if p.match(L.ARROW):
        # NOTE(review): the arrow's trailing whitespace/comments are
        # discarded here; newlines in them shift later line numbers
        p.chomp()
        rh = io.StringIO()
        p_exprs(p, rh)
        rh = rh.getvalue()
        mkassert(f, 'int', 'eq', lh, rh)
    else:
        f.write(lh)
def main(input=None, output=None, *,
        prefix=None,
        prefix_insensitive=None,
        assert_=None,
        unreachable=None,
        memcmp=None,
        strcmp=None,
        no_defaults=False,
        no_upper=False,
        no_arrows=False,
        limit=LIMIT):
    """Rewrite asserts in C source into pretty-assert macro invocations.

    input/output  -- paths; None or '-' mean stdin/stdout
    prefix        -- extra symbol prefixes, e.g. 'LFS_' also recognizes
                     LFS_assert, LFS_unreachable, LFS_memcmp, LFS_strcmp
    prefix_insensitive
                  -- like prefix, plus the all-lower and all-upper variants
                     (e.g. lfs_assert and LFS_ASSERT)
    assert_, unreachable, memcmp, strcmp
                  -- extra symbol names (regex fragments) for each rule
    no_defaults   -- drop the built-in assert/unreachable/memcmp/strcmp names
    no_upper      -- accepted for compatibility; currently unused
    no_arrows     -- disable the `lh => rh` shorthand
    limit         -- max bytes shown when printing memcmp/strcmp operands

    The output is the generated header (mkheader), a `#line 1` directive
    when reading a named file, then the rewritten source. On a parse error
    a warning goes to stderr and the rest of the input is copied through
    unchanged so the compiler can report it.

    NOTE: this mutates the module-level lexer L, so repeated calls in one
    process accumulate rules.
    """
    prefix = list(prefix or [])
    prefix_insensitive = list(prefix_insensitive or [])

    # modify lexer rules?
    if no_defaults:
        # lex() with no patterns compiles a never-matching rule; passing []
        # as a pattern would reach '|'.join and raise TypeError
        L.lex('ASSERT')
        L.lex('UNREACHABLE')
        L.lex('MEMCMP')
        L.lex('STRCMP')
    if no_arrows:
        # same trick: '=>' is then never treated as an assertion
        L.lex('ARROW')
    for pre in prefix + prefix_insensitive:
        L.extend('ASSERT', r'\b%sassert\b' % pre)
        L.extend('UNREACHABLE', r'\b%sunreachable\b' % pre)
        L.extend('MEMCMP', r'\b%smemcmp\b' % pre)
        L.extend('STRCMP', r'\b%sstrcmp\b' % pre)
    for pre in prefix_insensitive:
        L.extend('ASSERT', r'\b%sassert\b' % pre.lower())
        L.extend('UNREACHABLE', r'\b%sunreachable\b' % pre.lower())
        L.extend('MEMCMP', r'\b%smemcmp\b' % pre.lower())
        L.extend('STRCMP', r'\b%sstrcmp\b' % pre.lower())
        L.extend('ASSERT', r'\b%sASSERT\b' % pre.upper())
        L.extend('UNREACHABLE', r'\b%sUNREACHABLE\b' % pre.upper())
        L.extend('MEMCMP', r'\b%sMEMCMP\b' % pre.upper())
        L.extend('STRCMP', r'\b%sSTRCMP\b' % pre.upper())
    if assert_:
        L.extend('ASSERT', *[r'\b%s\b' % r for r in assert_])
    if unreachable:
        L.extend('UNREACHABLE', *[r'\b%s\b' % r for r in unreachable])
    if memcmp:
        L.extend('MEMCMP', *[r'\b%s\b' % r for r in memcmp])
    if strcmp:
        L.extend('STRCMP', *[r'\b%s\b' % r for r in strcmp])

    # start parsing
    with openio(input or '-', 'r') as in_f:
        p = Parser(in_f.read(), '')

    with openio(output or '-', 'w') as f:
        def writeln(s=''):
            f.write(s)
            f.write('\n')
        f.writeln = writeln

        # write extra verbose asserts
        mkheader(f, limit=limit)
        if input is not None:
            # map diagnostics back to the original file; the name must be
            # escaped for a C string literal (e.g. Windows backslashes)
            escaped = input.replace('\\', '\\\\').replace('"', '\\"')
            f.writeln("#line %d \"%s\"" % (1, escaped))

        # parse and write out stmt at a time
        try:
            while True:
                p_stmt(p, f)
                if p.match(L.TERM):
                    f.write(p.chomp())
                else:
                    break
            # trailing junk?
            if p:
                p.unexpected()
        except Parser.Error as e:
            # warn on stderr so the message can never end up inside the
            # generated source when writing to stdout
            print('warning: %s' % e, file=sys.stderr)
            # still write out the rest of the file so compiler
            # errors can be reported, these are usually more useful
            f.write(str(p))
if __name__ == "__main__":
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        description="Preprocessor that makes asserts easier to debug.",
        allow_abbrev=False)
    parser.add_argument('input', help="Input C file.")
    parser.add_argument('-o', '--output', required=True,
        help="Output C file.")
    parser.add_argument('-p', '--prefix', action='append',
        help="Additional prefixes for symbols.")
    parser.add_argument('-P', '--prefix-insensitive', action='append',
        help="Additional prefixes for lower/upper case symbol variants.")

    # repeatable options naming extra symbols for each rewritable construct
    symbol_options = [
        ('--assert', 'assert_',
            "Additional symbols for assert statements."),
        ('--unreachable', 'unreachable',
            "Additional symbols for unreachable statements."),
        ('--memcmp', 'memcmp',
            "Additional symbols for memcmp expressions."),
        ('--strcmp', 'strcmp',
            "Additional symbols for strcmp expressions."),
    ]
    for flag, dest, text in symbol_options:
        parser.add_argument(flag, dest=dest, action='append', help=text)

    parser.add_argument('-n', '--no-defaults', action='store_true',
        help="Disable default symbols.")
    parser.add_argument('--no-arrows', action='store_true',
        help="Disable arrow (=>) expressions.")
    parser.add_argument('-l', '--limit',
        type=lambda x: int(x, 0),
        default=LIMIT,
        help="Maximum number of characters to display in strcmp and "
            "memcmp. Defaults to %r." % LIMIT)

    # drop unset options so main()'s own defaults apply
    args = parser.parse_intermixed_args()
    kwargs = {k: v for k, v in vars(args).items() if v is not None}
    sys.exit(main(**kwargs))