Raw File
license
#! /usr/bin/env python3

# (c) Pierre-Yves Strub - 2014--2020

# --------------------------------------------------------------------
import sys, os, re, yaml, yaml.loader as yloader, schema
import tempfile, fnmatch

# --------------------------------------------------------------------
# Shape of the YAML configuration accepted by this script.
#   entities   : free-form mapping; `Copyright.get_copyright` uses it to map
#                a `who` value to the name actually printed.
#   copyrights : list of rules; each applies to files matching any glob in
#                `pattern` and contributes its `copyrights` holders.
META = schema.Schema({
    'entities'  : schema.Schema({}, ignore_extra_keys = True),
    'copyrights': [{
        'pattern'                 : [str],
        'style'                   : str,
        # The license is optional: `Copyright.get_copyright` reads it with
        # `.get('license', None)` and rules without one simply add no
        # license line. The schema must not reject such rules.
        schema.Optional('license'): str,
        'copyrights'              : [{'who': str, 'date': str}],
    }],
})

# --------------------------------------------------------------------
class AttrDict(dict):
    """A ``dict`` whose keys can also be read as attributes (``d.key``).

    A missing key raises ``AttributeError`` rather than ``KeyError``. As a
    result, ``hasattr``, ``getattr(obj, name, default)`` and protocols that
    probe optional hooks by attribute lookup (``copy``, ``pickle``) behave as
    they do for ordinary objects.
    """

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so real dict
        # methods and attributes are never shadowed by keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

class AttrDictSafeConstructor(yloader.SafeConstructor):
    """Safe YAML constructor that builds every mapping as an `AttrDict`."""

    def construct_yaml_map(self, node):
        # Generator-style constructor, following the same pattern as PyYAML's
        # own construct_yaml_map. The (still empty) result is handed out
        # first and only populated afterwards; presumably this lets aliases
        # refer back to a mapping that is still under construction.
        result = AttrDict()
        yield result
        result.update(self.construct_mapping(node))


# Route the standard YAML map tag to the AttrDict-building constructor.
AttrDictSafeConstructor.add_constructor(
    'tag:yaml.org,2002:map', AttrDictSafeConstructor.construct_yaml_map)

class AttrDictSafeLoader(yloader.Reader, yloader.Scanner, yloader.Parser,
                         yloader.Composer, AttrDictSafeConstructor,
                         yloader.Resolver):
    """YAML loader built from PyYAML's safe-loading pipeline components.

    The constructor stage is `AttrDictSafeConstructor`, so mappings come out
    as `AttrDict` instances.
    """

    def __init__(self, stream):
        # Only the reader consumes the input stream; every other stage takes
        # no arguments. They are initialised in base-list order.
        yloader.Reader.__init__(self, stream)
        for component in (yloader.Scanner, yloader.Parser, yloader.Composer,
                          AttrDictSafeConstructor, yloader.Resolver):
            component.__init__(self)

# --------------------------------------------------------------------
class Object:
    """Plain attribute bag: every keyword argument becomes an attribute.

    Example: ``Object(who='X', date='2020').who == 'X'``.
    """

    def __init__(self, **kw):
        vars(self).update(kw)

# --------------------------------------------------------------------
class CopyrightError(Exception):
    """Raised when a file's copyright header cannot be determined.

    Examples: the matching rules disagree on style or license, or a style
    name is unknown.
    """

# --------------------------------------------------------------------
class Copyright:
    """Resolve and render the copyright header that applies to a file.

    `ini` is the loaded configuration. It is a mapping with an `entities`
    table (mapping a holder's `who` to the name to print) and a list of
    `copyrights` rules. Each rule has glob `pattern`s, a comment `style`, an
    optional `license`, and a list of `{who, date}` holders. Both plain
    dicts and `AttrDict`s are accepted.
    """

    COPYRIGHT = 'Copyright (c) - %s - %s'
    LICENSE   = 'Distributed under the terms of the %s license'

    # ----------------------------------------------------------------
    def __init__(self, ini):
        self.ini = ini

    # ----------------------------------------------------------------
    def get_copyright(self, filename):
        """Collect the copyright holders whose rules match `filename`.

        A rule matches when any of its `pattern`s matches `filename` under
        `fnmatch` semantics. Every matching rule contributes its holders,
        in configuration order.

        Returns None when no rule matches. Otherwise returns an `Object`
        with these attributes:
          - `copyrights`: list of `(who, date)` pairs, with `who` replaced
            by its `entities` entry when one exists;
          - `license`: the single license named by the matching rules, or
            None when none of them names one;
          - `style`: the single comment style of the matching rules.

        Raises CopyrightError when the matching rules disagree on the
        style or on the license.
        """
        def _norm(x):
            return (self.ini['entities'].get(x.who, x.who), x.date)

        copyrights = []
        # Index with [] (not attribute access) so plain dicts work as well
        # as the AttrDict produced by the YAML loader.
        for pt in self.ini['copyrights']:
            if not any(fnmatch.fnmatch(filename, test) for test in pt['pattern']):
                continue
            for cp in pt['copyrights']:
                copyrights.append(Object(
                    who     = cp['who'  ],
                    date    = cp['date' ],
                    style   = pt['style'],
                    license = pt.get('license', None)))

        # With no holder there can be no conflict either.
        if not copyrights:
            return None

        styles = sorted({x.style for x in copyrights})
        if len(styles) > 1:
            raise CopyrightError(
                '%s: multiple styles: %s' % (filename, ', '.join(styles)))

        licenses = sorted({x.license for x in copyrights if x.license is not None})
        if len(licenses) > 1:
            raise CopyrightError(
                '%s: multiple licenses: %s' % (filename, ', '.join(licenses)))

        return Object(
            copyrights = [_norm(x) for x in copyrights],
            license    = licenses[0] if licenses else None,
            style      = styles[0])

    # ----------------------------------------------------------------
    def format(self, copyrights, license = None):
        """Render the header text lines, without comment markers.

        Each `(who, date)` pair in `copyrights` becomes one COPYRIGHT line.
        When `license` is not None, a blank line and a LICENSE line follow.
        Returns the list of lines.
        """
        lines = [self.COPYRIGHT % (date, who) for (who, date) in copyrights]
        if license is not None:
            lines.extend(['', self.LICENSE % (license,)])
        return lines

# --------------------------------------------------------------------
class CopyrightStyle:
    """Base class for the comment styles used to render and strip headers.

    Subclasses implement `strip` and `format`. `STYLES` is the shared
    registry mapping a configuration `style` name to its subclass.
    """
    STYLES = dict()

    def strip(self, contents):
        """Return `contents` with its leading copyright header removed."""
        # NotImplementedError subclasses RuntimeError, so existing handlers
        # for RuntimeError still catch it.
        raise NotImplementedError

    def format(self, contents):
        """Render the list of header lines `contents` as a comment string."""
        raise NotImplementedError

    @classmethod
    def factory(cls, name):
        """Instantiate the style registered under `name`.

        Raises CopyrightError if no style is registered under `name`.
        """
        try:
            style = cls.STYLES[name]
        except KeyError:
            raise CopyrightError('unknown style: %s' % (name,)) from None
        return style()

# --------------------------------------------------------------------
class CoqCopyrightStyle(CopyrightStyle):
    """Header style using ML-like ``(* ... *)`` block comments."""

    # The header is a `(*` comment at the very start of the text (after
    # optional whitespace) whose first alphanumeric text is "copyright".
    # It runs up to the first `*)` (nested comments are not recognised) and
    # includes the whitespace that follows.
    _re = r'^\s*\(\*[^A-Za-z0-9]*copyright.*?\*\)\s*'

    def strip(self, contents):
        """Drop the leading copyright comment, if any (case-insensitive)."""
        match = re.search(self._re, contents, re.S | re.I)
        return contents if match is None else contents[match.end():]

    def format(self, contents):
        """Wrap the header lines in a ruled ``(* ... *)`` comment.

        Each line is prefixed with ``' * '``, and trailing whitespace is
        removed. The result ends with a blank line.
        """
        rule  = '-' * 68
        lines = ['(* %s' % (rule,)]
        lines.extend(' * %s' % (line,) for line in contents)
        lines.append(' * %s *)' % (rule,))
        return '\n'.join(line.rstrip() for line in lines) + '\n\n'

# --------------------------------------------------------------------
class ELispCopyrightStyle(CopyrightStyle):
    """Header style using Emacs Lisp ``;;`` line comments."""

    # The header starts at the very beginning of the text (after optional
    # whitespace) with a `;;` comment whose first alphanumeric text is
    # "copyright". It extends up to the first later `;;` followed by a run of
    # dashes (the closing rule written by `format`), plus trailing whitespace.
    _re = r'^\s*;;(?:\s|-)*[^A-Za-z0-9]*copyright.*?;;\s*[-]+\s*'

    def strip(self, contents):
        """Drop the leading copyright comment, if any (case-insensitive)."""
        m = re.search(self._re, contents, re.S | re.I)
        if m is not None:
            contents = contents[m.end():]
        return contents

    def format(self, contents):
        """Wrap the header lines between two ``;;`` dash rules.

        Each line is prefixed with ``';; '``, and trailing whitespace is
        removed. The result ends with a blank line.
        """
        rule  = ';; %s' % ('-' * 68,)
        lines = [rule] + [';; %s' % (x,) for x in contents] + [rule]
        # rstrip() already removes any '\r' / '\n', so one pass is enough.
        return '\n'.join(x.rstrip() for x in lines) + '\n\n'

# --------------------------------------------------------------------
# Map configuration `style` names to their implementations. Coq, OCaml and
# `ec` sources all share the `(* ... *)` comment syntax.
CopyrightStyle.STYLES.update({
    'coq'  : CoqCopyrightStyle,
    'ocaml': CoqCopyrightStyle,
    'ec'   : CoqCopyrightStyle,
    'elisp': ELispCopyrightStyle,
})

# --------------------------------------------------------------------
def _update_file(resolver, filename):
    """Rewrite `filename` with the header computed by `resolver` (a Copyright).

    Does nothing when no configuration rule matches `filename`. Otherwise:
    the existing leading header (if any) is stripped; the original file is
    moved to ``filename + '~'``, replacing any older backup; and the new
    header plus the remaining contents are written to `filename` with the
    original permission bits.
    """
    import shutil

    infos = resolver.get_copyright(filename)
    if infos is None:
        return

    style  = CopyrightStyle.factory(infos.style)
    header = style.format(resolver.format(infos.copyrights, infos.license))

    # Explicit UTF-8 instead of the locale-dependent default encoding.
    with open(filename, 'r', encoding='utf-8') as stream:
        contents = style.strip(stream.read())

    # os.replace overwrites an existing backup atomically on every
    # platform, unlike os.rename on Windows.
    backup = filename + '~'
    os.replace(filename, backup)

    with open(filename, 'w', encoding='utf-8') as stream:
        stream.write(header)
        stream.write(contents)

    # The file was recreated from scratch with default permissions; restore
    # the original mode (e.g. an executable bit).
    shutil.copymode(backup, filename)


def _main():
    """Command-line entry point: ``CONFIG FILES...``.

    Loads and validates the YAML configuration CONFIG. Then, for every FILE
    matched by some rule, replaces its leading copyright header (if any)
    with the freshly rendered one. Exits with status 1 on a usage error or
    when a file's rules conflict.
    """
    if len(sys.argv) - 1 < 2:
        print("Usage: %s [CONFIG] [FILES...]" % (sys.argv[0],), file=sys.stderr)
        sys.exit(1)
    inifile  = sys.argv[1]
    srcfiles = sys.argv[2:]

    with open(inifile, 'r', encoding='utf-8') as stream:
        config = yaml.load(stream, Loader = AttrDictSafeLoader)
    META.validate(config)
    resolver = Copyright(config)

    try:
        for filename in srcfiles:
            _update_file(resolver, filename)
    except CopyrightError as e:
        # Configuration conflicts are user errors: report them without a
        # traceback.
        print('%s: %s' % (sys.argv[0], e), file=sys.stderr)
        sys.exit(1)

# --------------------------------------------------------------------
# Script entry point; _main() returns None, so the exit status is 0 unless
# _main exits earlier itself.
if __name__ == '__main__':
    sys.exit(_main())
back to top