#!/usr/bin/env python3 # # Plot CSV files with matplotlib. # # Example: # ./scripts/plotmpl.py bench.csv -xSIZE -ybench_read -obench.svg # # Copyright (c) 2022, The littlefs authors. # SPDX-License-Identifier: BSD-3-Clause # import codecs import collections as co import csv import io import itertools as it import logging import math as m import numpy as np import os import shlex import shutil import time import matplotlib as mpl import matplotlib.pyplot as plt # some nicer colors borrowed from Seaborn # note these include a non-opaque alpha COLORS = [ '#4c72b0bf', # blue '#dd8452bf', # orange '#55a868bf', # green '#c44e52bf', # red '#8172b3bf', # purple '#937860bf', # brown '#da8bc3bf', # pink '#8c8c8cbf', # gray '#ccb974bf', # yellow '#64b5cdbf', # cyan ] COLORS_DARK = [ '#a1c9f4bf', # blue '#ffb482bf', # orange '#8de5a1bf', # green '#ff9f9bbf', # red '#d0bbffbf', # purple '#debb9bbf', # brown '#fab0e4bf', # pink '#cfcfcfbf', # gray '#fffea3bf', # yellow '#b9f2f0bf', # cyan ] ALPHAS = [0.75] FORMATS = ['-'] FORMATS_POINTS = ['.'] FORMATS_POINTS_AND_LINES = ['.-'] WIDTH = 750 HEIGHT = 350 FONT_SIZE = 11 SI_PREFIXES = { 18: 'E', 15: 'P', 12: 'T', 9: 'G', 6: 'M', 3: 'K', 0: '', -3: 'm', -6: 'u', -9: 'n', -12: 'p', -15: 'f', -18: 'a', } SI2_PREFIXES = { 60: 'Ei', 50: 'Pi', 40: 'Ti', 30: 'Gi', 20: 'Mi', 10: 'Ki', 0: '', -10: 'mi', -20: 'ui', -30: 'ni', -40: 'pi', -50: 'fi', -60: 'ai', } # formatter for matplotlib def si(x): if x == 0: return '0' # figure out prefix and scale p = 3*int(m.log(abs(x), 10**3)) p = min(18, max(-18, p)) # format with 3 digits of precision s = '%.3f' % (abs(x) / (10.0**p)) s = s[:3+1] # truncate but only digits that follow the dot if '.' in s: s = s.rstrip('0') s = s.rstrip('.') return '%s%s%s' % ('-' if x < 0 else '', s, SI_PREFIXES[p]) # formatter for matplotlib def si2(x): if x == 0: return '0' # figure out prefix and scale p = 10*int(m.log(abs(x), 2**10)) p = min(30, max(-30, p)) # format with 3 digits of precision s = '%.3f' % (abs(x) / (2.0**p)) s = s[:3+1] # truncate but only digits that follow the dot if '.' in s: s = s.rstrip('0') s = s.rstrip('.') return '%s%s%s' % ('-' if x < 0 else '', s, SI2_PREFIXES[p]) # parse escape strings def escape(s): return codecs.escape_decode(s.encode('utf8'))[0].decode('utf8') # we want to use MaxNLocator, but since MaxNLocator forces multiples of 10 # to be an option, we can't really... class AutoMultipleLocator(mpl.ticker.MultipleLocator): def __init__(self, base, nbins=None): # note base needs to be floats to avoid integer pow issues self.base = float(base) self.nbins = nbins super().__init__(self.base) def __call__(self): # find best tick count, conveniently matplotlib has a function for this vmin, vmax = self.axis.get_view_interval() vmin, vmax = mpl.transforms.nonsingular(vmin, vmax, 1e-12, 1e-13) if self.nbins is not None: nbins = self.nbins else: nbins = np.clip(self.axis.get_tick_space(), 1, 9) # find the best power, use this as our locator's actual base scale = self.base ** (m.ceil(m.log((vmax-vmin) / (nbins+1), self.base))) self.set_params(scale) return super().__call__() def openio(path, mode='r', buffering=-1): # allow '-' for stdin/stdout if path == '-': if 'r' in mode: return os.fdopen(os.dup(sys.stdin.fileno()), mode, buffering) else: return os.fdopen(os.dup(sys.stdout.fileno()), mode, buffering) else: return open(path, mode, buffering) # parse different data representations def dat(x): # allow the first part of an a/b fraction if '/' in x: x, _ = x.split('/', 1) # first try as int try: return int(x, 0) except ValueError: pass # then try as float try: return float(x) # just don't allow infinity or nan if m.isinf(x) or m.isnan(x): raise ValueError("invalid dat %r" % x) except ValueError: pass # else give up raise ValueError("invalid dat %r" % x) def collect(csv_paths, renames=[], defines=[]): # collect results from CSV files fields = [] results = [] for path in csv_paths: try: with openio(path) as f: reader = csv.DictReader(f, restval='') fields.extend( k for k in reader.fieldnames if k not in fields) for r in reader: # apply any renames if renames: # make a copy so renames can overlap r_ = {} for new_k, old_k in renames: if old_k in r: r_[new_k] = r[old_k] r.update(r_) # filter by matching defines if not all(k in r and r[k] in vs for k, vs in defines): continue results.append(r) except FileNotFoundError: pass return fields, results def fold(results, by=None, x=None, y=None, defines=[]): # filter by matching defines if defines: results_ = [] for r in results: if all(k in r and r[k] in vs for k, vs in defines): results_.append(r) results = results_ if by: # find all 'by' values keys = set() for r in results: keys.add(tuple(r.get(k, '') for k in by)) keys = sorted(keys) # collect all datasets datasets = co.OrderedDict() for key in (keys if by else [()]): for x_ in (x if x else [None]): for y_ in y: # organize by 'by', x, and y dataset = [] i = 0 for r in results: # filter by 'by' if by and not all( k in r and r[k] == v for k, v in zip(by, key)): continue # find xs if x_ is not None: if x_ not in r: continue try: x__ = dat(r[x_]) except ValueError: continue else: # fallback to enumeration x__ = i i += 1 # find ys if y_ is not None: if y_ not in r: continue try: y__ = dat(r[y_]) except ValueError: continue else: y__ = None dataset.append((x__, y__)) # hide x/y if there is only one field k_x = x_ if len(x or []) > 1 else '' k_y = y_ if len(y or []) > 1 or (not key and not k_x) else '' datasets[key + (k_x, k_y)] = dataset return datasets # some classes for organizing subplots into a grid class Subplot: def __init__(self, **args): self.x = 0 self.y = 0 self.xspan = 1 self.yspan = 1 self.args = args class Grid: def __init__(self, subplot, width=1.0, height=1.0): self.xweights = [width] self.yweights = [height] self.map = {(0,0): subplot} self.subplots = [subplot] def __repr__(self): return 'Grid(%r, %r)' % (self.xweights, self.yweights) @property def width(self): return len(self.xweights) @property def height(self): return len(self.yweights) def __iter__(self): return iter(self.subplots) def __getitem__(self, i): x, y = i if x < 0: x += len(self.xweights) if y < 0: y += len(self.yweights) return self.map[(x,y)] def merge(self, other, dir): if dir in ['above', 'below']: # first scale the two grids so they line up self_xweights = self.xweights other_xweights = other.xweights self_w = sum(self_xweights) other_w = sum(other_xweights) ratio = self_w / other_w other_xweights = [s*ratio for s in other_xweights] # now interleave xweights as needed new_xweights = [] self_map = {} other_map = {} self_i = 0 other_i = 0 self_xweight = (self_xweights[self_i] if self_i < len(self_xweights) else m.inf) other_xweight = (other_xweights[other_i] if other_i < len(other_xweights) else m.inf) while self_i < len(self_xweights) and other_i < len(other_xweights): if other_xweight - self_xweight > 0.0000001: new_xweights.append(self_xweight) other_xweight -= self_xweight new_i = len(new_xweights)-1 for j in range(len(self.yweights)): self_map[(new_i, j)] = self.map[(self_i, j)] for j in range(len(other.yweights)): other_map[(new_i, j)] = other.map[(other_i, j)] for s in other.subplots: if s.x+s.xspan-1 == new_i: s.xspan += 1 elif s.x > new_i: s.x += 1 self_i += 1 self_xweight = (self_xweights[self_i] if self_i < len(self_xweights) else m.inf) elif self_xweight - other_xweight > 0.0000001: new_xweights.append(other_xweight) self_xweight -= other_xweight new_i = len(new_xweights)-1 for j in range(len(other.yweights)): other_map[(new_i, j)] = other.map[(other_i, j)] for j in range(len(self.yweights)): self_map[(new_i, j)] = self.map[(self_i, j)] for s in self.subplots: if s.x+s.xspan-1 == new_i: s.xspan += 1 elif s.x > new_i: s.x += 1 other_i += 1 other_xweight = (other_xweights[other_i] if other_i < len(other_xweights) else m.inf) else: new_xweights.append(self_xweight) new_i = len(new_xweights)-1 for j in range(len(self.yweights)): self_map[(new_i, j)] = self.map[(self_i, j)] for j in range(len(other.yweights)): other_map[(new_i, j)] = other.map[(other_i, j)] self_i += 1 self_xweight = (self_xweights[self_i] if self_i < len(self_xweights) else m.inf) other_i += 1 other_xweight = (other_xweights[other_i] if other_i < len(other_xweights) else m.inf) # squish so ratios are preserved self_h = sum(self.yweights) other_h = sum(other.yweights) ratio = (self_h-other_h) / self_h self_yweights = [s*ratio for s in self.yweights] # finally concatenate the two grids if dir == 'above': for s in other.subplots: s.y += len(self_yweights) self.subplots.extend(other.subplots) self.xweights = new_xweights self.yweights = self_yweights + other.yweights self.map = self_map | {(x, y+len(self_yweights)): s for (x, y), s in other_map.items()} else: for s in self.subplots: s.y += len(other.yweights) self.subplots.extend(other.subplots) self.xweights = new_xweights self.yweights = other.yweights + self_yweights self.map = other_map | {(x, y+len(other.yweights)): s for (x, y), s in self_map.items()} if dir in ['right', 'left']: # first scale the two grids so they line up self_yweights = self.yweights other_yweights = other.yweights self_h = sum(self_yweights) other_h = sum(other_yweights) ratio = self_h / other_h other_yweights = [s*ratio for s in other_yweights] # now interleave yweights as needed new_yweights = [] self_map = {} other_map = {} self_i = 0 other_i = 0 self_yweight = (self_yweights[self_i] if self_i < len(self_yweights) else m.inf) other_yweight = (other_yweights[other_i] if other_i < len(other_yweights) else m.inf) while self_i < len(self_yweights) and other_i < len(other_yweights): if other_yweight - self_yweight > 0.0000001: new_yweights.append(self_yweight) other_yweight -= self_yweight new_i = len(new_yweights)-1 for j in range(len(self.xweights)): self_map[(j, new_i)] = self.map[(j, self_i)] for j in range(len(other.xweights)): other_map[(j, new_i)] = other.map[(j, other_i)] for s in other.subplots: if s.y+s.yspan-1 == new_i: s.yspan += 1 elif s.y > new_i: s.y += 1 self_i += 1 self_yweight = (self_yweights[self_i] if self_i < len(self_yweights) else m.inf) elif self_yweight - other_yweight > 0.0000001: new_yweights.append(other_yweight) self_yweight -= other_yweight new_i = len(new_yweights)-1 for j in range(len(other.xweights)): other_map[(j, new_i)] = other.map[(j, other_i)] for j in range(len(self.xweights)): self_map[(j, new_i)] = self.map[(j, self_i)] for s in self.subplots: if s.y+s.yspan-1 == new_i: s.yspan += 1 elif s.y > new_i: s.y += 1 other_i += 1 other_yweight = (other_yweights[other_i] if other_i < len(other_yweights) else m.inf) else: new_yweights.append(self_yweight) new_i = len(new_yweights)-1 for j in range(len(self.xweights)): self_map[(j, new_i)] = self.map[(j, self_i)] for j in range(len(other.xweights)): other_map[(j, new_i)] = other.map[(j, other_i)] self_i += 1 self_yweight = (self_yweights[self_i] if self_i < len(self_yweights) else m.inf) other_i += 1 other_yweight = (other_yweights[other_i] if other_i < len(other_yweights) else m.inf) # squish so ratios are preserved self_w = sum(self.xweights) other_w = sum(other.xweights) ratio = (self_w-other_w) / self_w self_xweights = [s*ratio for s in self.xweights] # finally concatenate the two grids if dir == 'right': for s in other.subplots: s.x += len(self_xweights) self.subplots.extend(other.subplots) self.xweights = self_xweights + other.xweights self.yweights = new_yweights self.map = self_map | {(x+len(self_xweights), y): s for (x, y), s in other_map.items()} else: for s in self.subplots: s.x += len(other.xweights) self.subplots.extend(other.subplots) self.xweights = other.xweights + self_xweights self.yweights = new_yweights self.map = other_map | {(x+len(other.xweights), y): s for (x, y), s in self_map.items()} def scale(self, width, height): self.xweights = [s*width for s in self.xweights] self.yweights = [s*height for s in self.yweights] @classmethod def fromargs(cls, width=1.0, height=1.0, *, subplots=[], **args): grid = cls(Subplot(**args)) for dir, subargs in subplots: subgrid = cls.fromargs( width=subargs.pop('width', 0.5 if dir in ['right', 'left'] else width), height=subargs.pop('height', 0.5 if dir in ['above', 'below'] else height), **subargs) grid.merge(subgrid, dir) grid.scale(width, height) return grid def main(csv_paths, output, *, svg=False, png=False, quiet=False, by=None, x=None, y=None, define=[], points=False, points_and_lines=False, colors=None, formats=None, labels=None, width=WIDTH, height=HEIGHT, xlim=(None,None), ylim=(None,None), xlog=False, ylog=False, x2=False, y2=False, xticks=None, yticks=None, xunits=None, yunits=None, xlabel=None, ylabel=None, xticklabels=None, yticklabels=None, title=None, legend_right=False, legend_above=False, legend_below=False, dark=False, ggplot=False, xkcd=False, github=False, font=None, font_size=FONT_SIZE, font_color=None, foreground=None, background=None, subplot={}, subplots=[], **args): # guess the output format if not png and not svg: if output.endswith('.png'): png = True else: svg = True # some shortcuts for color schemes if github: ggplot = True if font_color is None: if dark: font_color = '#c9d1d9' else: font_color = '#24292f' if foreground is None: if dark: foreground = '#343942' else: foreground = '#eff1f3' if background is None: if dark: background = '#0d1117' else: background = '#ffffff' # what colors/alphas/formats to use? if colors is not None: colors_ = colors elif dark: colors_ = COLORS_DARK else: colors_ = COLORS if formats is not None: formats_ = formats elif points_and_lines: formats_ = FORMATS_POINTS_AND_LINES elif points: formats_ = FORMATS_POINTS else: formats_ = FORMATS if labels is not None: labels_ = labels else: labels_ = [None] if font_color is not None: font_color_ = font_color elif dark: font_color_ = '#ffffff' else: font_color_ = '#000000' if foreground is not None: foreground_ = foreground elif dark: foreground_ = '#333333' else: foreground_ = '#e5e5e5' if background is not None: background_ = background elif dark: background_ = '#000000' else: background_ = '#ffffff' # configure some matplotlib settings if xkcd: # the font search here prints a bunch of unhelpful warnings logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR) plt.xkcd() # turn off the white outline, this breaks some things plt.rc('path', effects=[]) if ggplot: plt.style.use('ggplot') plt.rc('patch', linewidth=0) plt.rc('axes', facecolor=foreground_, edgecolor=background_) plt.rc('grid', color=background_) # fix the the gridlines when ggplot+xkcd if xkcd: plt.rc('grid', linewidth=1) plt.rc('axes.spines', bottom=False, left=False) if dark: plt.style.use('dark_background') plt.rc('savefig', facecolor='auto', edgecolor='auto') # fix ggplot when dark if ggplot: plt.rc('axes', facecolor=foreground_, edgecolor=background_) plt.rc('grid', color=background_) if font is not None: plt.rc('font', family=font) plt.rc('font', size=font_size) plt.rc('text', color=font_color_) plt.rc('figure', titlesize='medium', labelsize='small') plt.rc('axes', titlesize='small', labelsize='small', labelcolor=font_color_) if not ggplot: plt.rc('axes', edgecolor=font_color_) plt.rc('xtick', labelsize='small', color=font_color_) plt.rc('ytick', labelsize='small', color=font_color_) plt.rc('legend', fontsize='small', fancybox=False, framealpha=None, edgecolor=foreground_, borderaxespad=0) plt.rc('axes.spines', top=False, right=False) plt.rc('figure', facecolor=background_, edgecolor=background_) if not ggplot: plt.rc('axes', facecolor='#00000000') # I think the svg backend just ignores DPI, but seems to use something # equivalent to 96, maybe this is the default for SVG rendering? plt.rc('figure', dpi=96) # subplot can also contribute to subplots, resolve this here or things # become a mess... subplots += subplot.pop('subplots', []) # allow any subplots to contribute to by/x/y/defines def subplots_get(k, *, subplots=[], **args): v = args.get(k, []).copy() for _, subargs in subplots: v.extend(subplots_get(k, **subargs)) return v all_by = (by or []) + subplots_get('by', **subplot, subplots=subplots) all_x = (x or []) + subplots_get('x', **subplot, subplots=subplots) all_y = (y or []) + subplots_get('y', **subplot, subplots=subplots) all_defines = co.defaultdict(lambda: set()) for k, vs in it.chain(define or [], subplots_get('define', **subplot, subplots=subplots)): all_defines[k] |= vs all_defines = sorted(all_defines.items()) # separate out renames all_renames = list(it.chain.from_iterable( ((k, v) for v in vs) for k, vs in it.chain(all_by, all_x, all_y))) all_by = [k for k, _ in all_by] all_x = [k for k, _ in all_x] all_y = [k for k, _ in all_y] if not all_by and not all_y: print("error: needs --by or -y to figure out fields") sys.exit(-1) # first collect results from CSV files fields_, results = collect(csv_paths, all_renames, all_defines) # if y not specified, guess it's anything not in by/defines/x/renames if not all_y: all_y = [ k for k in fields_ if k not in all_by and not any(k == k_ for k_, _ in all_defines) and not any(k == old_k for _, old_k in all_renames)] # then extract the requested datasets datasets_ = fold(results, all_by, all_x, all_y) # figure out formats/colors/labels here so that subplot defines # don't change them later, that'd be bad dataformats_ = { name: formats_[i % len(formats_)] for i, name in enumerate(datasets_.keys())} datacolors_ = { name: colors_[i % len(colors_)] for i, name in enumerate(datasets_.keys())} datalabels_ = { name: labels_[i % len(labels_)] for i, name in enumerate(datasets_.keys())} # create a grid of subplots grid = Grid.fromargs(**subplot, subplots=subplots) # create a matplotlib plot fig = plt.figure(figsize=( width/plt.rcParams['figure.dpi'], height/plt.rcParams['figure.dpi']), layout='constrained', # we need a linewidth to keep xkcd mode happy linewidth=8 if xkcd else 0) gs = fig.add_gridspec( grid.height + (1 if legend_above else 0) + (1 if legend_below else 0), grid.width + (1 if legend_right else 0), height_ratios=([0.001] if legend_above else []) + [max(s, 0.01) for s in reversed(grid.yweights)] + ([0.001] if legend_below else []), width_ratios=[max(s, 0.01) for s in grid.xweights] + ([0.001] if legend_right else [])) # first create axes so that plots can interact with each other for s in grid: s.ax = fig.add_subplot(gs[ grid.height-(s.y+s.yspan) + (1 if legend_above else 0) : grid.height-s.y + (1 if legend_above else 0), s.x : s.x+s.xspan]) # now plot each subplot for s in grid: # allow subplot params to override global params x_ = {k for k,_ in (x or []) + s.args.get('x', [])} y_ = {k for k,_ in (y or []) + s.args.get('y', [])} define_ = define + s.args.get('define', []) xlim_ = s.args.get('xlim', xlim) ylim_ = s.args.get('ylim', ylim) xlog_ = s.args.get('xlog', False) or xlog ylog_ = s.args.get('ylog', False) or ylog x2_ = s.args.get('x2', False) or x2 y2_ = s.args.get('y2', False) or y2 xticks_ = s.args.get('xticks', xticks) yticks_ = s.args.get('yticks', yticks) xunits_ = s.args.get('xunits', xunits) yunits_ = s.args.get('yunits', yunits) xticklabels_ = s.args.get('xticklabels', xticklabels) yticklabels_ = s.args.get('yticklabels', yticklabels) # label/titles are handled a bit differently in subplots subtitle = s.args.get('title') xsublabel = s.args.get('xlabel') ysublabel = s.args.get('ylabel') # allow shortened ranges if len(xlim_) == 1: xlim_ = (0, xlim_[0]) if len(ylim_) == 1: ylim_ = (0, ylim_[0]) # data can be constrained by subplot-specific defines, # so re-extract for each plot subdatasets = fold(results, all_by, all_x, all_y, define_) # filter by subplot x/y subdatasets = co.OrderedDict([(name, dataset) for name, dataset in subdatasets.items() if not name[-2] or name[-2] in x_ if not name[-1] or name[-1] in y_]) # plot! ax = s.ax for name, dataset in subdatasets.items(): dats = sorted((x,y) for x,y in dataset) ax.plot([x for x,_ in dats], [y for _,y in dats], dataformats_[name], color=datacolors_[name], label=','.join(k for k in name if k)) # axes scaling if xlog_: ax.set_xscale('symlog') ax.xaxis.set_minor_locator(mpl.ticker.NullLocator()) if ylog_: ax.set_yscale('symlog') ax.yaxis.set_minor_locator(mpl.ticker.NullLocator()) # axes limits ax.set_xlim( xlim_[0] if xlim_[0] is not None else min(it.chain([0], (x for dataset in subdatasets.values() for x, y in dataset if y is not None))), xlim_[1] if xlim_[1] is not None else max(it.chain([0], (x for r in subdatasets.values() for x, y in dataset if y is not None)))) ax.set_ylim( ylim_[0] if ylim_[0] is not None else min(it.chain([0], (y for dataset in subdatasets.values() for _, y in dataset if y is not None))), ylim_[1] if ylim_[1] is not None else max(it.chain([0], (y for dataset in subdatasets.values() for _, y in dataset if y is not None)))) # axes ticks if x2_: ax.xaxis.set_major_formatter(lambda x, pos: si2(x)+(xunits_ if xunits_ else '')) if xticklabels_ is not None: ax.xaxis.set_ticklabels(xticklabels_) if xticks_ is None: ax.xaxis.set_major_locator(AutoMultipleLocator(2)) elif isinstance(xticks_, list): ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks_)) elif xticks_ != 0: ax.xaxis.set_major_locator(AutoMultipleLocator(2, xticks_-1)) else: ax.xaxis.set_major_locator(mpl.ticker.NullLocator()) else: ax.xaxis.set_major_formatter(lambda x, pos: si(x)+(xunits_ if xunits_ else '')) if xticklabels_ is not None: ax.xaxis.set_ticklabels(xticklabels_) if xticks_ is None: ax.xaxis.set_major_locator(mpl.ticker.AutoLocator()) elif isinstance(xticks_, list): ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks_)) elif xticks_ != 0: ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(xticks_-1)) else: ax.xaxis.set_major_locator(mpl.ticker.NullLocator()) if y2_: ax.yaxis.set_major_formatter(lambda x, pos: si2(x)+(yunits_ if yunits_ else '')) if yticklabels_ is not None: ax.yaxis.set_ticklabels(yticklabels_) if yticks_ is None: ax.yaxis.set_major_locator(AutoMultipleLocator(2)) elif isinstance(yticks_, list): ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks_)) elif yticks_ != 0: ax.yaxis.set_major_locator(AutoMultipleLocator(2, yticks_-1)) else: ax.yaxis.set_major_locator(mpl.ticker.NullLocator()) else: ax.yaxis.set_major_formatter(lambda x, pos: si(x)+(yunits_ if yunits_ else '')) if yticklabels_ is not None: ax.yaxis.set_ticklabels(yticklabels_) if yticks_ is None: ax.yaxis.set_major_locator(mpl.ticker.AutoLocator()) elif isinstance(yticks_, list): ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks_)) elif yticks_ != 0: ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(yticks_-1)) else: ax.yaxis.set_major_locator(mpl.ticker.NullLocator()) if ggplot: ax.grid(sketch_params=None) # axes subplot labels if xsublabel is not None: ax.set_xlabel(escape(xsublabel)) if ysublabel is not None: ax.set_ylabel(escape(ysublabel)) if subtitle is not None: ax.set_title(escape(subtitle)) # add a legend? a bit tricky with matplotlib # # the best solution I've found is a dedicated, invisible axes for the # legend, hacky, but it works. # # note this was written before constrained_layout supported legend # collisions, hopefully this is added in the future legend = {} for s in grid: for h, l in zip(*s.ax.get_legend_handles_labels()): legend[l] = h # sort in dataset order legend_ = [] for name in datasets_.keys(): name_ = ','.join(k for k in name if k) if name_ in legend: if datalabels_[name] is None: legend_.append((name_, legend[name_])) elif datalabels_[name]: legend_.append((datalabels_[name], legend[name_])) legend = legend_ if legend_right: ax = fig.add_subplot(gs[(1 if legend_above else 0):,-1]) ax.set_axis_off() ax.legend( [h for _,h in legend], [l for l,_ in legend], loc='upper left', fancybox=False, borderaxespad=0) if legend_above: ax = fig.add_subplot(gs[0, :grid.width]) ax.set_axis_off() # try different column counts until we fit in the axes for ncol in reversed(range(1, len(legend)+1)): # permute the labels, mpl wants to order these column first nrow = m.ceil(len(legend)/ncol) legend_ = ncol*nrow*[None] for x in range(ncol): for y in range(nrow): if x+ncol*y < len(legend): legend_[x*nrow+y] = legend[x+ncol*y] legend_ = [l for l in legend_ if l is not None] legend_ = ax.legend( [h for _,h in legend_], [l for l,_ in legend_], loc='upper center', ncol=ncol, fancybox=False, borderaxespad=0) if (legend_.get_window_extent().width <= ax.get_window_extent().width): break if legend_below: ax = fig.add_subplot(gs[-1, :grid.width]) ax.set_axis_off() # big hack to get xlabel above the legend! but hey this # works really well actually if xlabel: ax.set_title(escape(xlabel), size=plt.rcParams['axes.labelsize'], weight=plt.rcParams['axes.labelweight']) # try different column counts until we fit in the axes for ncol in reversed(range(1, len(legend)+1)): # permute the labels, mpl wants to order these column first nrow = m.ceil(len(legend)/ncol) legend_ = ncol*nrow*[None] for x in range(ncol): for y in range(nrow): if x+ncol*y < len(legend): legend_[x*nrow+y] = legend[x+ncol*y] legend_ = [l for l in legend_ if l is not None] legend_ = ax.legend( [h for _,h in legend_], [l for l,_ in legend_], loc='upper center', ncol=ncol, fancybox=False, borderaxespad=0) if (legend_.get_window_extent().width <= ax.get_window_extent().width): break # axes labels, NOTE we reposition these below if xlabel is not None and not legend_below: fig.supxlabel(escape(xlabel)) if ylabel is not None: fig.supylabel(escape(ylabel)) if title is not None: fig.suptitle(escape(title)) # precompute constrained layout and find midpoints to adjust things # that should be centered so they are actually centered fig.canvas.draw() xmid = (grid[0,0].ax.get_position().x0 + grid[-1,0].ax.get_position().x1)/2 ymid = (grid[0,0].ax.get_position().y0 + grid[0,-1].ax.get_position().y1)/2 if xlabel is not None and not legend_below: fig.supxlabel(escape(xlabel), x=xmid) if ylabel is not None: fig.supylabel(escape(ylabel), y=ymid) if title is not None: fig.suptitle(escape(title), x=xmid) # write the figure! plt.savefig(output, format='png' if png else 'svg') # some stats if not quiet: print('updated %s, %s datasets, %s points' % ( output, len(datasets_), sum(len(dataset) for dataset in datasets_.values()))) if __name__ == "__main__": import sys import argparse import re parser = argparse.ArgumentParser( description="Plot CSV files with matplotlib.", allow_abbrev=False) parser.add_argument( 'csv_paths', nargs='*', help="Input *.csv files.") output_rule = parser.add_argument( '-o', '--output', required=True, help="Output *.svg/*.png file.") parser.add_argument( '--svg', action='store_true', help="Output an svg file. By default this is infered.") parser.add_argument( '--png', action='store_true', help="Output a png file. By default this is infered.") parser.add_argument( '-q', '--quiet', action='store_true', help="Don't print info.") parser.add_argument( '-b', '--by', action='append', type=lambda x: ( lambda k, vs=None: ( k.strip(), tuple(v.strip() for v in vs.split(',')) if vs is not None else ()) )(*x.split('=', 1)), help="Group by this field. Can rename fields with new_name=old_name.") parser.add_argument( '-x', action='append', type=lambda x: ( lambda k, vs=None: ( k.strip(), tuple(v.strip() for v in vs.split(',')) if vs is not None else ()) )(*x.split('=', 1)), help="Field to use for the x-axis. Can rename fields with " "new_name=old_name.") parser.add_argument( '-y', action='append', type=lambda x: ( lambda k, vs=None: ( k.strip(), tuple(v.strip() for v in vs.split(',')) if vs is not None else ()) )(*x.split('=', 1)), help="Field to use for the y-axis. Can rename fields with " "new_name=old_name.") parser.add_argument( '-D', '--define', type=lambda x: ( lambda k, vs: ( k.strip(), {v.strip() for v in vs.split(',')}) )(*x.split('=', 1)), action='append', help="Only include results where this field is this value. May include " "comma-separated options.") parser.add_argument( '-.', '--points', action='store_true', help="Only draw data points.") parser.add_argument( '-!', '--points-and-lines', action='store_true', help="Draw data points and lines.") parser.add_argument( '--colors', type=lambda x: [x.strip() for x in x.split(',')], help="Comma-separated hex colors to use.") parser.add_argument( '--formats', type=lambda x: [x.strip().replace('\,',',') for x in re.split(r'(?