https://github.com/EasyCrypt/easycrypt
Raw File
Tip revision: 30bfa950afa3806948c073d3c9ec4468d33ea940 authored by Pierre-Yves Strub on 11 December 2023, 10:58:49 UTC
New tactic: "proc change"
Tip revision: 30bfa95
runtest
#! /usr/bin/env python3

# --------------------------------------------------------------------
import abc
import asyncio
import collections
import contextlib as clib
import curses
from   datetime import datetime, timezone
import fnmatch
import glob
import itertools
import math
import multiprocessing as mp
import os
import pathlib
import re
import signal
import socket
import sys
import termios
import time

# --------------------------------------------------------------------
try:
    import yaml

except ModuleNotFoundError:
    # PyYAML is optional: `yaml is None` marks it as unavailable.
    yaml = None

else:
    class folded_unicode(str):
        """String that is dumped with the YAML folded ('>') scalar style."""

    class literal_unicode(str):
        """String that is dumped with the YAML literal ('|') scalar style."""

    def _styled_str_representer(style):
        # Build a representer emitting a plain YAML string with `style`.
        def represent(dumper, data):
            return dumper.represent_scalar(
                u'tag:yaml.org,2002:str', data, style=style)
        return represent

    folded_unicode_representer  = _styled_str_representer('>')
    literal_unicode_representer = _styled_str_representer('|')

    yaml.add_representer(folded_unicode , folded_unicode_representer)
    yaml.add_representer(literal_unicode, literal_unicode_representer)

# --------------------------------------------------------------------
class awaitable:
    """Async context manager yielding `the` and awaiting `the.wait()` on exit.

    `the.wait()` is awaited whether the body of the `async with`
    completed normally or raised; an exception raised by the body is
    not suppressed.
    """

    def __init__(self, the):
        self._the = the

    async def __aenter__(self):
        return self._the

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self._the.wait()
        return False

# Expose the helper next to the other contextlib utilities.
clib.awaitable = awaitable

# --------------------------------------------------------------------
def is_tuple_prefix(prefix, subject):
    """Tell whether `subject` starts with the elements of `prefix`."""
    head = subject[0:len(prefix)]
    return head == prefix

# --------------------------------------------------------------------
class Object:
    """Plain attribute bag: every keyword argument becomes an attribute."""

    def __init__(self, **kw):
        vars(self).update(kw)

# --------------------------------------------------------------------
class Singleton(type):
    """Metaclass creating at most one instance per class.

    The first call builds the instance with the given arguments; later
    calls return that same instance and ignore their arguments.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]

# --------------------------------------------------------------------
def _options():
    """Build the run options from the command line and the configuration file.

    The first positional argument is the path of an INI configuration
    file; the remaining ones name the scenarios to run (sections named
    `test-NAME' in the configuration file). A scenario name prefixed
    with `!' removes that scenario from the ones selected so far.

    Returns an `Object` with the attributes `bin`, `args`, `timeout`,
    `timing`, `why3`, `jobs`, `provers`, `report`, `targets` (list of
    scenario names) and `scenarios` (scenario name -> description).

    Invalid usage is reported through `parser.error`, which terminates
    the program.
    """
    import configparser as cp
    from optparse import OptionParser

    parser = OptionParser()

    parser.add_option(
        '', '--bin',
        action  = 'store',
        metavar = 'BIN',
        default = None,
        help    = 'EasyCrypt binary to use')

    parser.add_option(
        '', '--bin-args',
        action  = 'append',
        metavar = 'ARGS',
        default = [],
        help    = 'append ARGS to EasyCrypt command (cumulative)')

    parser.add_option(
        '', '--timeout',
        action  = 'store',
        default = None,
        metavar = 'TIMEOUT',
        type    = 'int',
        help    = 'set the timeout option to pass to EasyCrypt')

    parser.add_option(
        '', '--jobs',
        action  = 'store',
        default = None,
        metavar = 'JOBS',
        type    = 'int',
        help    = 'number of maximum parallel test jobs')

    parser.add_option(
        '', '--timing',
        action  = 'store_true',
        default = False,
        help    = 'add timing statistics')

    parser.add_option(
        '-p', '--provers',
        action  = 'append',
        metavar = 'PROVERS',
        default = None,
        help    = 'Add a prover to the set of activated provers')

    parser.add_option(
        '', '--why3',
        action  = 'store',
        metavar = 'WHY3',
        default = None,
        help    = 'Why3 configuration file')

    parser.add_option(
        '', '--report',
        action  = 'store',
        default = None,
        metavar = 'FILE',
        help    = 'dump result to FILE')

    (cmdopt, args) = parser.parse_args()

    if len(args) < 1:
        parser.error('this program takes at least one argument')

    if cmdopt.timeout is not None:
        if cmdopt.timeout <= 0:
            parser.error('timeout must be positive')

    if cmdopt.jobs is not None:
        if cmdopt.jobs < 0:
            parser.error('jobs must be non-negative')

    options = Object(scenarios = dict())
    options.timeout  = cmdopt.timeout
    options.timing   = cmdopt.timing
    options.why3     = cmdopt.why3
    options.jobs     = cmdopt.jobs
    options.provers  = cmdopt.provers
    options.report   = cmdopt.report

    # Per-section defaults, so that every `test-*' section can be read
    # without checking which keys it actually defines.
    defaults = dict(
        args         = '',
        exclude      = '',
        file_exclude = '',
        okdirs       = '',
        kodirs       = '',
    )

    config = cp.ConfigParser(defaults)
    config.read(args[0])

    def resolve_targets(names):
        # Always keep a real list: `filter` returns a one-shot iterator
        # that supports neither `append` nor repeated iteration.
        targets = []
        for name in names:
            if name.startswith('!'):
                excluded = name[1:]
                targets = [x for x in targets if x != excluded]
            elif name not in targets:
                targets.append(name)
        return targets

    # The command line takes precedence over the configuration file.
    options.bin = cmdopt.bin

    if options.bin is None:
        options.bin = config.get('default', 'bin', fallback = 'easycrypt')

    options.args    = config.get('default', 'args', fallback = '').split()
    options.targets = resolve_targets(args[1:])

    if options.report is None:
        options.report = config.get('default', 'report', fallback = None)

    if cmdopt.bin_args:
        options.args.extend(itertools.chain.from_iterable(
            x.split() for x in cmdopt.bin_args))

    prefix = 'test-'

    for test in [x for x in config.sections() if x.startswith(prefix)]:
        scenario = Object()
        scenario.args     = config.get(test, 'args').split()
        scenario.okdirs   = config.get(test, 'okdirs')
        scenario.kodirs   = config.get(test, 'kodirs')
        scenario.dexclude = config.get(test, 'exclude').split()
        scenario.fexclude = config.get(test, 'file_exclude').split()
        options.scenarios[test[len(prefix):]] = scenario

    for x in options.targets:
        if x not in options.scenarios:
            parser.error('unknown scenario: %s' % (x,))

    return options

# --------------------------------------------------------------------
class Listener(abc.ABC):
    """Interface of the objects notified about the progress of a test run.

    The signatures below match how the runner calls the listener and
    how the concrete listeners of this file implement it.
    """

    # Timeout (in seconds) between two calls to `tick`; with None the
    # runner waits for a job completion without timeout.
    REFRESH = None

    @abc.abstractmethod
    def file_check_start(self, filename):
        """Called when the check of `filename` starts.

        Returns the handle to pass to the other notification methods.
        """
        pass

    @abc.abstractmethod
    def file_check_progress(self, handle, progress):
        """Called with the latest progress value reported for `handle`."""
        pass

    @abc.abstractmethod
    def file_check_done(self, handle, status, success):
        """Called when the check behind `handle` is over.

        `status` is the exit status of the checker process and `success`
        tells whether the outcome is the expected one.
        """
        pass

    @abc.abstractmethod
    def error(self, handle, level, msg):
        """Called for each message `msg` of severity `level` about `handle`."""
        pass

    @abc.abstractmethod
    def tick(self):
        """Called periodically while checks are running."""
        pass

    @abc.abstractmethod
    def start(self):
        """Return a context manager that wraps the whole run."""
        pass

# --------------------------------------------------------------------
class CurseWrapper(metaclass = Singleton):
    """Thin wrapper around terminfo for cursor moves and colored output.

    All output goes to `sys.stdout`. Writes are flushed immediately
    unless they happen between `acquire` and `release`, in which case
    the flush is deferred until the outermost `release`.

    Raises `curses.error` on construction when the terminal cannot be
    set up or lacks one of the required capabilities.
    """

    COLOR_RED    = curses.COLOR_RED
    COLOR_GREEN  = curses.COLOR_GREEN
    COLOR_YELLOW = curses.COLOR_YELLOW
    COLOR_BLUE   = curses.COLOR_BLUE

    COLOR_CODES = (
        COLOR_RED   ,
        COLOR_GREEN ,
        COLOR_YELLOW,
        COLOR_BLUE  ,
    )

    def __init__(self):
        curses.setupterm()

        # Internal name -> terminfo capability name
        capnames = dict(
            cup   = 'cuu1' ,
            cdw   = 'cud1' ,
            sv    = 'sc'   ,
            cr    = 'rc'   ,
            el    = 'el'   ,
            setaf = 'setaf',
            sgr0  = 'sgr0' ,
        )

        self._codes = {
            name: curses.tigetstr(cap) for name, cap in capnames.items()
        }

        # `tigetstr` returns None for a capability the terminal does not
        # have. Fail now with `curses.error` instead of failing later
        # with a TypeError when the missing code is first used.
        missing = [
            cap for name, cap in capnames.items() if self._codes[name] is None
        ]
        if missing:
            raise curses.error(
                'terminal lacks required capabilities: %s' % (', '.join(missing),))

        self._colors = {
            x: curses.tparm(self.codes['setaf'], x) for x in self.COLOR_CODES
        }

        # Nesting depth of acquire(); auto-flush is active when it is 0
        self._autoflush = 0
        self._stream    = sys.stdout

    codes  = property(lambda self : self._codes)
    colors = property(lambda self : self._colors)

    def acquire(self):
        """Enter a section during which writes are not flushed one by one."""
        self._autoflush += 1

    def release(self):
        """Leave a section opened by `acquire`; flush when it was the outermost."""
        assert (self._autoflush > 0)

        self._autoflush -= 1
        if self._autoflush == 0:
            self._stream.flush()

    def rawwrite(self, data, flush = None):
        """Write the bytes `data`, flushing if asked to or if auto-flushing."""
        self._stream.buffer.write(data)
        if flush or self._autoflush == 0:
            self._stream.flush()

    def write(self, data, flush = None):
        """Write the string `data`, encoded as UTF-8."""
        self.rawwrite(data.encode('utf-8'), flush)

    def cup(self, n = 1):
        """Move the cursor up by `n` lines."""
        self.rawwrite(self.codes['cup'] * n)

    def cdown(self, n = 1):
        """Move the cursor down by `n` lines."""
        self.rawwrite(self.codes['cdw'] * n)

    def csave(self):
        """Save the cursor position."""
        self.rawwrite(self.codes['sv'])

    def crestore(self):
        """Restore the last saved cursor position."""
        self.rawwrite(self.codes['cr'])

    def clearline(self):
        """Clear from the cursor to the end of the line."""
        self.rawwrite(self.codes['el'])

    def colored(self, txt, color):
        """Return `txt` surrounded by the codes setting then resetting `color`."""
        icolor = self.colors[color].decode('ascii')
        ocolor = self.codes['sgr0'].decode('ascii')
        return f'{icolor}{txt}{ocolor}'

    @clib.contextmanager
    def cursor(self):
        """Context manager saving the cursor position and restoring it on exit."""
        self.csave()
        try:
            yield
        finally:
            self.crestore()

    def echooff(self):
        """Disable the echo of typed characters on the output terminal."""
        fdesc = self._stream.fileno()
        attrs = termios.tcgetattr(fdesc)
        attrs[3] &= ~termios.ECHO
        termios.tcsetattr(fdesc, termios.TCSADRAIN, attrs)

    def echoon(self):
        """Enable the echo of typed characters on the output terminal."""
        fdesc = self._stream.fileno()
        attrs = termios.tcgetattr(fdesc)
        attrs[3] |= termios.ECHO
        termios.tcsetattr(fdesc, termios.TCSADRAIN, attrs)

# --------------------------------------------------------------------
class Gather:
    """Collect the scripts to check for a set of scenarios."""

    @staticmethod
    def is_dir_excluded(src, excludes):
        """Tell whether the directory `src` is excluded by `excludes`.

        A plain entry excludes exactly that directory; an entry prefixed
        with `!' also excludes every directory below it.
        """
        src = pathlib.Path(src).parts
        for path in excludes:
            rec = False
            if path.startswith('!'):
                rec, path = True, path[1:]
            path = pathlib.Path(path).parts
            if (is_tuple_prefix(path, src) if rec else path == src):
                return True
        return False

    @staticmethod
    def is_file_excluded(src, excludes):
        """Tell whether the path `src` matches one of the `excludes` patterns.

        Patterns use the `fnmatch` syntax and are matched against the
        whole path, not only against its last component.
        """
        return any(fnmatch.fnmatch(src, exclude) for exclude in excludes)

    @classmethod
    def gather(cls, obj, scenario):
        """List the `.ec` / `.eca` scripts found directly in `obj.src`.

        Returns one configuration object per script, in sorted order,
        carrying the expected validity, the group (source directory),
        the extra arguments and the script path. A directory that
        cannot be listed is reported as a warning and yields no script.
        """
        # Local import: the module is not imported at file level.
        import logging

        try:
            scripts = os.listdir(obj.src)
        except OSError as e:
            logging.warning("cannot scan `%s': %s", obj.src, e)
            return []
        scripts = sorted([x for x in scripts if re.search(r'\.eca?$', x)])

        def config(filename):
            fullname = os.path.join(obj.src, filename)
            return Object(isvalid  = obj.valid,
                          group    = obj.src,
                          args     = obj.args,
                          filename = fullname)

        return [config(x) for x in scripts]

    @classmethod
    def gather_for_scenario(cls, scenario):
        """List the scripts of `scenario`, after applying its exclusions.

        `scenario.okdirs` and `scenario.kodirs` are whitespace-separated
        lists of glob patterns; an entry prefixed with `!' stands for
        that directory and all its sub-directories.
        """
        def expand(dirs):
            def for1(x):
                aout = []
                if x.startswith('!'):
                    top = x[1:]
                    aout.append(top)
                    for root, dnames, _ in os.walk(top):
                        aout.extend([os.path.join(root, d) for d in dnames])
                else:
                    aout.extend(glob.glob(x))
                return aout

            dirs = [for1(x) for x in re.split(r'\s+', dirs)]
            return list(itertools.chain.from_iterable(dirs))

        dirs = []
        dirs.extend([Object(src = x, valid = True , args = scenario.args)
                         for x in expand(scenario.okdirs)])
        dirs.extend([Object(src = x, valid = False, args = scenario.args)
                         for x in expand(scenario.kodirs)])
        dirs = [x for x in dirs if not cls.is_dir_excluded(x.src, scenario.dexclude)]
        dirs = map(lambda x : cls.gather(x, scenario), dirs)

        files = list(itertools.chain.from_iterable(dirs))
        files = [x for x in files if not cls.is_file_excluded(x.filename, scenario.fexclude)]

        return files

    @classmethod
    def gatherall(cls, scenarios, targets):
        """List the scripts of all the scenarios named in `targets`."""
        dirs = [scenarios[x] for x in targets]
        dirs = map(lambda x : cls.gather_for_scenario(x), dirs)
        return list(itertools.chain.from_iterable(dirs))

# ---------------------------------------------------------------------
def format_delta(delta, hprec=2, prec=0, subprec=0):
    """Format a duration in seconds as `minutes:seconds[.fraction]`.

    `hprec` is the minimal width of the minutes field. When `prec` is
    positive, `prec` (truncated) decimal digits of the seconds are
    appended; when `subprec` is also positive, these digits are rounded
    down to a multiple of `10 ** prec // subprec`.
    """
    if prec <= 0:
        minutes, seconds = divmod(int(delta), 60)
        return f'{minutes:{hprec}d}:{seconds:02d}'

    fraction, whole = math.modf(delta)
    whole    = int(whole)
    fraction = int(10 ** prec * fraction)

    if subprec > 0:
        step     = 10 ** prec // subprec
        fraction = fraction // step * step

    minutes, seconds = divmod(whole, 60)
    return f'{minutes:{hprec}d}:{seconds:02d}.{fraction:0{prec}d}'

# ---------------------------------------------------------------------
class MBar:
    """Terminal display made of a headline followed by one line per running task.

    The object is also a context manager: entering it defers the flush
    of the terminal output until the outermost exit.
    """

    def __init__(self, ntasks):
        self.bars      = []
        self.term      = CurseWrapper()
        self.ntasks    = ntasks
        self.stats     = dict(success = 0, failure = 0, running = 0, waiting = ntasks)
        self.timestamp = time.time()

        # There is no task line yet: this draws the headline only.
        self._create_bars_display()

    def headline(self):
        """Return the (encoded) summary line: elapsed time and task counters."""
        palette = [
            ('success', CurseWrapper.COLOR_GREEN ),
            ('failure', CurseWrapper.COLOR_RED   ),
            ('running', CurseWrapper.COLOR_YELLOW),
            ('waiting', CurseWrapper.COLOR_BLUE  ),
        ]
        elapsed  = format_delta(time.time() - self.timestamp)
        counters = [
            self.term.colored(str(self.stats[key]), color)
            for key, color in palette
        ]
        text = '⁕ [{}] Success: {}, Failure: {}, Running: {}, Waiting: {}'
        return text.format(elapsed, *counters).encode('utf-8')

    def create(self, name):
        """Move one task from waiting to running and return its new line."""
        with self:
            assert (self.stats['waiting'] > 0)

            self._clear_bars_display()
            self.stats['waiting'] -= 1
            self.stats['running'] += 1
            line = MBarLine(self, name)
            self.bars.append(line)
            self._create_bars_display()

            return line

    def write_for(self, bar):
        """Redraw the line of `bar` only, leaving the cursor where it was."""
        with self:
            position = self.bars.index(bar)
            rendered = self.render(bar)
            with self.term.cursor():
                self.term.cup(len(self.bars) - position)
                self.term.clearline()
                self.term.rawwrite(rendered)

    def write(self, txt, mark = True):
        """Print `txt` above the display, then redraw the display below it."""
        with self:
            self._clear_bars_display()
            prefix = '▶ ' if mark else '  '
            self.term.rawwrite((prefix + txt).encode('utf-8') + b'\n')
            self._create_bars_display()

    def log(self, mark, txt, color):
        """Print a message tagged with the colored `mark`.

        `txt` is a string or a list of strings; the items after the
        first are printed as indented continuation lines.
        """
        with self:
            if isinstance(txt, str):
                txt = [txt]
            tag = self.term.colored(mark, color)
            self.write('[{}] {}'.format(tag, txt[0]))
            for extra in txt[1:]:
                self.write('    {}'.format(extra), mark = False)

    def remove(self, bar, success):
        """Drop the line of a finished task and update the counters."""
        with self:
            assert (self.stats['running'] > 0)

            self._clear_bars_display()
            self.stats['running'] -= 1
            outcome = 'success' if success else 'failure'
            self.stats[outcome] += 1
            self.bars[:] = [x for x in self.bars if x is not bar]
            self._create_bars_display()

    def render(self, bar):
        """Return the (encoded) text of the line of `bar`."""
        return '⚙ '.encode('utf-8') + bar.render()

    def refresh(self):
        """Redraw the whole display."""
        with self:
            self._clear_bars_display()
            self._create_bars_display()

    def __enter__(self):
        self.term.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.term.release()

    def _clear_bars_display(self):
        # One line for the headline, plus one per task line
        nlines = 1 + len(self.bars)
        for _ in range(nlines):
            self.term.cup()
            self.term.clearline()

    def _emit_line(self, produce):
        # Write the bytes returned by `produce()` on the current line,
        # then move the cursor to the next line.
        with self.term.cursor():
            self.term.rawwrite(produce())
        self.term.cdown()

    def _create_bars_display(self):
        self._emit_line(self.headline)
        for bar in self.bars:
            self._emit_line(lambda: self.render(bar))

# --------------------------------------------------------------------
class MBarLine:
    """Progress line of a single task, displayed by its owning `mbar`."""

    BLACK = '█'
    BLANK = '-'
    WIDTH = 50

    def __init__(self, mbar, name):
        self.mbar  = mbar
        self.value = 0.0
        self.name  = name
        self.now   = time.time()    # creation time
        self.last  = time.time()    # time of the last update

    def update(self, value):
        """Record the progress `value` (clamped to [0, 1]) and redraw the line."""
        clamped = float(value)
        if clamped > 1.0:
            clamped = 1.0
        if clamped < 0.0:
            clamped = 0.0

        self.value = clamped
        self.last  = time.time()
        self.mbar.write_for(self)

    def render(self):
        """Return the (encoded) text of the line.

        The line shows the percentage, the gauge, the time elapsed up to
        the last update, the time since that update (once it reaches
        one second) and the task name.
        """
        idle = time.time() - self.last
        if idle < 1.0:
            stalled = ' ' * 7
        else:
            stalled = '+' + format_delta(idle, hprec=0, prec=1)

        filled  = round(self.WIDTH * self.value)
        gauge   = (self.BLACK * filled) + (self.BLANK * (self.WIDTH - filled))
        elapsed = format_delta(self.last - self.now)

        text = '[{:6.2f}%] {} [{} {}] [{}]'.format(
            100 * self.value, gauge, elapsed, stalled, self.name)

        return text.encode('utf-8')

    def finish(self, success):
        """Tell the owning display that the task is over."""
        self.mbar.remove(self, success)

# --------------------------------------------------------------------
# (mark, color) shown for a finished check, indexed by the exit status of
# the checker process. The `None` entry is the one used for any status
# that has no entry of its own.
MARKS = {
    0               : ('✓', CurseWrapper.COLOR_GREEN ),
    -signal.SIGINT  : ('ϟ', CurseWrapper.COLOR_YELLOW),
    -signal.SIGTERM : ('ϟ', CurseWrapper.COLOR_YELLOW),
    None            : ('✗', CurseWrapper.COLOR_RED   ),
}

# --------------------------------------------------------------------
# (mark, color) shown in front of a reported message, indexed by its
# level. The `None` entry is the one used for any other level.
LOGMARKS = {
    'critical' : ('✗', CurseWrapper.COLOR_RED   ),
    'warning'  : ('ϟ', CurseWrapper.COLOR_YELLOW),
    'info'     : ('ⓘ', CurseWrapper.COLOR_BLUE  ),
    'debug'    : ('ⓘ', CurseWrapper.COLOR_BLUE  ),
    None       : ('✗', CurseWrapper.COLOR_RED   ),
}

# --------------------------------------------------------------------
class TermListener(Listener):
    """Listener displaying live progress bars on the terminal."""

    # Delay (in seconds) between two refreshes of the display
    REFRESH = 0.1

    def __init__(self, allscripts):
        self._mbar = MBar(len(allscripts))
        self._handles = {}

    def file_check_start(self, filename):
        """Create a progress bar for `filename` and return its handle."""
        handle = Object(
            filename  = filename,
            timestamp = time.time(),
            bar       = self._mbar.create(filename),
        )

        self._handles[id(handle)] = handle
        return id(handle)

    def file_check_progress(self, handle, progress):
        """Update the progress bar attached to `handle`."""
        handle = self._handles[handle]
        handle.bar.update(progress)

    def file_check_done(self, handle, status, success):
        """Log the outcome of the check and remove its progress bar."""
        handle = self._handles[handle]

        mark, color = MARKS.get(status, MARKS[None])
        self._mbar.log(mark, '[{}] {}'.format(
            format_delta(time.time() - handle.timestamp, prec=1), handle.filename
        ), color)
        handle.bar.finish(success = success)

        # The key of a record is the id of the record itself
        del self._handles[id(handle)]

    def tick(self):
        """Redraw the display."""
        self._mbar.refresh()

    def error(self, handle, level, msg):
        """Log `msg` under the name of the file attached to `handle`."""
        handle = self._handles[handle]
        mark, color = LOGMARKS.get(level, LOGMARKS[None])
        # Escaped newlines (backslash followed by `n') are expanded too
        msg = msg.replace('\\n', '\n').split('\n')
        self._mbar.log(mark, [handle.filename] + msg, color)

    @clib.contextmanager
    def start(self):
        """Context manager disabling the terminal echo for the whole run."""
        self._mbar.term.echooff()
        try:
            yield
        finally:
            # Restore the echo even when the run is interrupted or
            # fails, so the terminal is not left in a silent state.
            self._mbar.term.echoon()


# --------------------------------------------------------------------
class RawListener(Listener):
    """Listener writing one plain line per event on the standard output."""

    def __init__(self, allscripts):
        self._handles = {}
        self._total   = len(allscripts)
        self._checked = 0
        self.log(f'[ ] number of files to check: {self._total}')

    def log(self, msg):
        """Write `msg` on its own line and flush immediately."""
        stream = sys.stdout
        stream.write(msg + '\n')
        stream.flush()

    def file_check_start(self, filename):
        """Record the start of the check of `filename`; return its handle."""
        entry = Object(filename = filename, timestamp = time.time())
        key   = id(entry)

        self._handles[key] = entry
        return key

    def file_check_progress(self, handle, progress):
        """Progress is not reported by this listener."""
        pass

    def file_check_done(self, handle, status, success):
        """Log the outcome of the check, with a running count and its duration."""
        entry = self._handles[handle]
        self._checked += 1
        mark, _ = MARKS.get(status, MARKS[None])
        delta = format_delta(time.time() - entry.timestamp, prec=1)
        self.log(f'[{mark}] [{self._checked:04d}/{self._total:04d}] [{delta}] {entry.filename}')
        # The key of a record is the id of the record itself
        del self._handles[id(entry)]

    def error(self, handle, level, msg):
        """Log `msg` for the file attached to `handle`.

        A one-line message is printed after the file name; a multi-line
        one is printed below it, each line indented.
        """
        entry   = self._handles[handle]
        mark, _ = LOGMARKS.get(level, LOGMARKS[None])
        lines   = msg.strip('\n').split('\n')

        if len(lines) <= 1:
            self.log('[{}] {}: {}'.format(mark, entry.filename, lines[0]))
            return

        self.log('[{}] {}:'.format(mark, entry.filename))
        for line in lines:
            self.log('    {}'.format(line))

    def tick(self):
        """Nothing to refresh."""
        pass

    @clib.contextmanager
    def start(self):
        """Context manager with no setup nor cleanup."""
        yield

# --------------------------------------------------------------------
def get_njobs():
    """Return the default number of parallel jobs (always at least 1).

    The value of the `ECRJOBS` environment variable is used when it is
    a positive integer. When it is unset, non-positive or not an
    integer, the number of CPUs is used instead.
    """
    try:
        njobs = max(int(os.environ.get('ECRJOBS', 0)), 0)
    except ValueError:
        # Never return 0 here: it would be used as the initial value of
        # a semaphore, on which every job would then wait forever.
        njobs = 0

    return njobs or mp.cpu_count()

# --------------------------------------------------------------------
async def _run_all(options, allscripts, listener : Listener):
    """Check every script of `allscripts`, notifying `listener` along the way.

    One checker process is spawned per script; at most
    `options.jobs or get_njobs()` of them hold the semaphore at once.

    Returns the list of the per-script results, each an `Object` with
    the attributes `success`, `config`, `errors` and `duration`.
    """
    semaphore = asyncio.Semaphore(options.jobs or get_njobs())

    async def runcheck(config):
        # Run the checker on one script and return its result.
        async with semaphore:
            command = [options.bin] + options.args + config.args
            if options.timeout:
                command.extend(['-timeout', str(options.timeout)])
            if options.provers:
                for prover in options.provers:
                    command.extend(['-p', prover])
            if options.why3:
                command.extend(['-why3', options.why3])
            if options.timing:
                # NOTE(review): TIMING_DIR is not defined in the visible
                # part of this file -- confirm where it comes from.
                tfilename = os.path.join(
                    TIMING_DIR, os.path.splitext(config.filename)[0] + '.stats')
                os.makedirs(os.path.dirname(tfilename), exist_ok = True)
                command.extend(['-tstats', tfilename])

            handle = listener.file_check_start(config.filename)

            command.extend(['-script', '-no-eco', config.filename])

            # Only the standard error of the process is read
            args = dict(
                stdout = asyncio.subprocess.DEVNULL,
                stderr = asyncio.subprocess.PIPE,
            )

            time0  = time.time()
            proc   = await asyncio.create_subprocess_exec(*command, **args)
            errors = []

            try:
                async with clib.awaitable(proc) as proc:
                    async for line in proc.stderr:
                        if re.match(rb'P\b', line):
                            # `P' line: its 4th field is forwarded as the
                            # progress value
                            listener.file_check_progress(handle, float(line.split()[3]))
                        elif re.match(rb'E\b', line):
                            # `E' line: a message, whose first word (when
                            # followed by more text) is passed as its level
                            error = line[1:].decode('utf-8').strip()
                            if not error:
                                continue
                            if (m := re.search(r'^(\S+)\s+(.+)$', error)):
                                code, error = m.group(1), m.group(2)
                            else:
                                code = None
                            listener.error(handle, code, error)
                            errors.append(error)

            finally:
                try:
                    # Give the process 2 seconds to exit before killing it
                    status = await asyncio.wait_for(proc.wait(), 2.0)

                except asyncio.TimeoutError:
                    proc.kill(); status = await proc.wait()

                # Success: a zero status for a script expected to be
                # valid, or a non-zero one for a script expected to fail
                success = (bool(status) != bool(config.isvalid))

                listener.file_check_done(handle, status, success)

            return Object(success  = success,
                          config   = config ,
                          errors   = errors ,
                          duration = time.time() - time0)

    with listener.start():
        log  = []
        jobs = [asyncio.ensure_future(runcheck(cfg)) for cfg in allscripts]

        while jobs:
            # Wake up when a job completes or, at the latest, after
            # `listener.REFRESH` seconds (no timeout if it is None)
            done, jobs = await asyncio.wait(
                jobs,
                timeout = listener.REFRESH,
                return_when = asyncio.FIRST_COMPLETED)

            listener.tick()

            for aout in done:
                try:
                    log.append(aout.result())
                except KeyboardInterrupt:
                    pass

    return log

# --------------------------------------------------------------------
def _dump_report(config, results, stream):
    """Write the YAML report of `results` to `stream`.

    The report is a list with one entry per group of results, in order
    of first appearance, each holding the details of its tests. The
    `time` field of every group is the sum of the durations of all the
    results, whatever their group.

    `config` provides the `hostname` and `timestamp` of the run. The
    document is dumped with the UTF-8 encoding before being written.
    """
    total = sum(x.duration for x in results)

    by_group = {}
    for result in results:
        by_group.setdefault(result.config.group, []).append(result)

    def describe(result):
        # Report entry of a single test
        path = result.config.filename
        stem = os.path.splitext(os.path.basename(path))[0]

        return {
            'name'       : '%s (%s)' % (stem, path),
            'time'       : '%.3f' % (result.duration,),
            'success'    : result.success,
            'shouldpass' : result.config.isvalid,
            'errors'     : [literal_unicode(x) for x in result.errors],
        }

    report = [
        {
            'name'      : gname,
            'hostname'  : config.hostname,
            'timestamp' : config.timestamp.isoformat(),
            'tests'     : len(group),
            'failures'  : sum(1 for x in group if not x.success),
            'time'      : '%.3f' % total,
            'details'   : [describe(x) for x in group],
        }
        for gname, group in by_group.items()
    ]

    stream.write(yaml.dump(
        report,
        default_flow_style = None,
        encoding           = 'utf-8',
        sort_keys          = False,
    ))

# --------------------------------------------------------------------
async def _main():
    """Entry point: run the selected scenarios and exit with the outcome.

    Exit status: 0 when every check has the expected outcome, 2 when at
    least one does not, 1 when a report is requested while the yaml
    module is not available.
    """
    mainconfig = Object()

    mainconfig.hostname  = socket.gethostname()
    mainconfig.timestamp = datetime.now(timezone.utc)

    options = _options()

    if options.report is not None and yaml is None:
        print(
            'reporting needs the yaml Python module',
            file = sys.stderr
        )
        sys.exit(1)

    allscripts = Gather.gatherall(options.scenarios, options.targets)

    # Use the interactive display only when writing to a terminal that
    # curses can drive; otherwise fall back to plain line output.
    listener = None
    if sys.stdout.isatty():
        try:
            listener = TermListener(allscripts)
        except curses.error:
            pass
    if listener is None:
        listener = RawListener(allscripts)

    log = await _run_all(options, allscripts, listener)

    if options.report is not None:
        with open(options.report, 'wb') as output:
            _dump_report(mainconfig, log, output)

    haserrors = any(not x.success for x in log)

    sys.exit(2 if haserrors else 0)

# --------------------------------------------------------------------
if __name__ == '__main__':
    # A keyboard interrupt ends the program without a traceback.
    with clib.suppress(KeyboardInterrupt):
        asyncio.run(_main())
back to top