#! /usr/bin/env python3
# (c) Pierre-Yves Strub - 2014--2020
# --------------------------------------------------------------------
import sys, os, re, yaml, yaml.loader as yloader, schema
import tempfile, fnmatch
# --------------------------------------------------------------------
# Validation schema for the parsed YAML configuration:
#   entities   -- mapping whose contents are not checked (extra keys allowed)
#   copyrights -- list of rules; each rule carries its filename patterns,
#                 a comment style name, a license name and the list of
#                 (who, date) copyright holders
META = schema.Schema({
    'entities': schema.Schema({}, ignore_extra_keys=True),
    'copyrights': [{
        'pattern': [str],
        'style': str,
        'license': str,
        'copyrights': [{'who': str, 'date': str}],
    }],
})
# --------------------------------------------------------------------
class AttrDict(dict):
    """Dictionary whose keys can also be read as attributes.

    ``d.key`` returns the same value as ``d['key']`` for any key present in
    the mapping.
    """

    def __getattr__(self, name):
        """Return the value stored under key *name*.

        Python only calls this hook after normal attribute lookup has failed.

        Raises:
            AttributeError: if *name* is not a key of the dictionary.  The
                attribute protocol (``hasattr``, ``getattr`` with a default,
                ``copy.deepcopy`` probing for ``__deepcopy__``) relies on
                this exception type; a ``KeyError`` would escape from those
                calls instead of being treated as "no such attribute".
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
class AttrDictSafeConstructor(yloader.SafeConstructor):
    """Safe YAML constructor that builds mappings as AttrDict instances."""

    def construct_yaml_map(self, node):
        """Construct the mapping *node* as an AttrDict, in two steps.

        The empty AttrDict is yielded first; its contents are filled in
        when the generator is resumed.
        """
        mapping = AttrDict()
        # Hand out the (still empty) object before building its contents.
        yield mapping
        mapping.update(self.construct_mapping(node))


# Route the standard YAML map tag to the AttrDict-producing constructor.
AttrDictSafeConstructor.add_constructor(
    u'tag:yaml.org,2002:map',
    AttrDictSafeConstructor.construct_yaml_map,
)
class AttrDictSafeLoader(
    yloader.Reader,
    yloader.Scanner,
    yloader.Parser,
    yloader.Composer,
    AttrDictSafeConstructor,
    yloader.Resolver,
):
    """YAML loader assembled from the yaml.loader processing classes, with
    AttrDictSafeConstructor as its constructor."""

    def __init__(self, stream):
        """Initialise every base class in turn; only the reader takes *stream*."""
        yloader.Reader.__init__(self, stream)
        # The remaining bases take no argument; keep the declaration order.
        for base in (
            yloader.Scanner,
            yloader.Parser,
            yloader.Composer,
            AttrDictSafeConstructor,
            yloader.Resolver,
        ):
            base.__init__(self)
# --------------------------------------------------------------------
class Object:
    """Plain record: every keyword argument becomes an instance attribute."""

    def __init__(self, **kw):
        vars(self).update(kw)
# --------------------------------------------------------------------
class CopyrightError(Exception):
    """Raised when the copyright rules cannot be applied to a file
    (conflicting styles or licenses, unknown style name)."""
# --------------------------------------------------------------------
class Copyright:
    """Copyright rules taken from the parsed configuration.

    *ini* is a mapping with an ``entities`` mapping and a ``copyrights``
    list of rules (each with ``pattern``, ``style``, ``copyrights`` and,
    optionally as far as this class is concerned, ``license``).
    """

    COPYRIGHT = 'Copyright (c) - %s - %s'
    LICENSE = 'Distributed under the terms of the %s license'

    # ----------------------------------------------------------------
    def __init__(self, ini):
        self.ini = ini

    # ----------------------------------------------------------------
    def get_copyright(self, filename):
        """Collect the copyright information that applies to *filename*.

        Every rule having at least one ``fnmatch`` pattern matching
        *filename* contributes all of its holders, once.

        Returns:
            None when no rule matches.  Otherwise an Object with:
              * ``copyrights``: list of ``(who, date)`` tuples, where
                ``who`` is replaced by its ``entities`` entry when there
                is one;
              * ``license``: the license of the matching rules, or None
                when none of them gives one;
              * ``style``: the comment style of the matching rules.

        Raises:
            CopyrightError: if the matching rules name more than one
                style, or more than one license.
        """
        matched = []
        # Item access throughout (as for 'entities' below), so that any
        # mapping can be used as configuration, not only an AttrDict.
        for rule in self.ini['copyrights']:
            patterns = rule['pattern']
            if not any(fnmatch.fnmatch(filename, pat) for pat in patterns):
                continue
            for holder in rule['copyrights']:
                matched.append(Object(
                    who=holder['who'],
                    date=holder['date'],
                    style=rule['style'],
                    license=rule.get('license')))

        if not matched:
            return None

        styles = sorted({x.style for x in matched})
        if len(styles) > 1:
            raise CopyrightError(
                '%s: multiple styles: %s' % (filename, ', '.join(styles)))

        licenses = sorted(
            {x.license for x in matched if x.license is not None})
        if len(licenses) > 1:
            raise CopyrightError(
                '%s: multiple licenses: %s' % (filename, ', '.join(licenses)))

        entities = self.ini['entities']
        return Object(
            copyrights=[(entities.get(x.who, x.who), x.date) for x in matched],
            license=licenses[0] if licenses else None,
            # `matched` is non-empty here, hence so is `styles`.
            style=styles[0])

    # ----------------------------------------------------------------
    def format(self, copyrights, license=None):
        """Return the lines of the copyright notice.

        *copyrights* is an iterable of ``(who, date)`` pairs; each one
        gives a COPYRIGHT line, with the date printed first.  When
        *license* is not None, an empty line and the LICENSE line follow.
        """
        lines = [self.COPYRIGHT % (entry[1], entry[0]) for entry in copyrights]
        if license is not None:
            lines.append('')
            lines.append(self.LICENSE % (license,))
        return lines
# --------------------------------------------------------------------
class CopyrightStyle:
    """Base class of the comment styles used to write copyright headers.

    Subclasses implement `strip` and `format`; they are looked up by name
    through the `STYLES` registry.
    """

    # Registry: style name -> class implementing that style.
    STYLES = {}

    def strip(self, contents):
        """Return *contents* without its copyright header.

        Must be overridden.  NotImplementedError is a RuntimeError
        subclass, so handlers written for RuntimeError still catch it.
        """
        raise NotImplementedError

    def format(self, contents):
        """Return the header text for the notice lines *contents*.

        Must be overridden (see `strip` for the exception raised).
        """
        raise NotImplementedError

    @classmethod
    def factory(cls, name):
        """Return a new instance of the style registered under *name*.

        Raises:
            CopyrightError: if no style is registered under *name*.
        """
        if name not in cls.STYLES:
            raise CopyrightError('unknown style: %s' % (name,))
        return cls.STYLES[name]()
# --------------------------------------------------------------------
class CoqCopyrightStyle(CopyrightStyle):
    """Copyright header written as a ``(* ... *)`` block comment."""

    # A ``(* ... *)`` comment at the very start of the text (leading
    # whitespace allowed) in which only non-alphanumeric characters
    # precede the word "copyright", together with the whitespace after it.
    _re = r'^\s*\(\*[^A-Za-z0-9]*copyright.*?\*\)\s*'

    def strip(self, contents):
        """Return *contents* with its leading copyright comment removed.

        The text is returned unchanged when no such comment is found.
        """
        found = re.search(self._re, contents, re.S | re.I)
        return contents if found is None else contents[found.end():]

    def format(self, contents):
        """Return the header for the notice lines *contents*.

        The lines are framed by two rules of 68 dashes, stripped of
        trailing whitespace, and followed by one empty line.
        """
        rule = '-' * 68
        lines = ['(* %s' % (rule,)]
        lines.extend(' * %s' % (text,) for text in contents)
        lines.append(' * %s *)' % (rule,))
        return '\n'.join(line.rstrip() for line in lines) + '\n\n'
# --------------------------------------------------------------------
class ELispCopyrightStyle(CopyrightStyle):
    """Copyright header written as a run of ``;;`` line comments."""

    # A header at the very start of the text (leading whitespace allowed):
    # ``;;``, then only whitespace, dashes and other non-alphanumeric
    # characters before the word "copyright", up to the first ``;;``
    # followed by a run of dashes, together with the whitespace after it.
    _re = r'^\s*;;(?:\s|-)*[^A-Za-z0-9]*copyright.*?;;\s*[-]+\s*'

    def strip(self, contents):
        """Return *contents* with its leading copyright header removed.

        The text is returned unchanged when no such header is found.
        """
        m = re.search(self._re, contents, re.S | re.I)
        if m is not None:
            contents = contents[m.end():]
        return contents

    def format(self, contents):
        """Return the header for the notice lines *contents*.

        The lines are framed by two rules of 68 dashes, stripped of
        trailing whitespace, and followed by one empty line.
        """
        rule = ';; %s' % ('-' * 68,)
        lines = [rule] + [';; %s' % (x,) for x in contents] + [rule]
        # rstrip() already removes trailing '\r' and '\n' along with any
        # other trailing whitespace.
        return '\n'.join(x.rstrip() for x in lines) + '\n\n'
# --------------------------------------------------------------------
# Register the available comment styles under their configuration names;
# 'coq', 'ocaml' and 'ec' all share the (* ... *) comment style.
CopyrightStyle.STYLES.update({
    'coq': CoqCopyrightStyle,
    'ocaml': CoqCopyrightStyle,
    'ec': CoqCopyrightStyle,
    'elisp': ELispCopyrightStyle,
})
# --------------------------------------------------------------------
def _main():
    """Command-line entry point: ``PROG CONFIG FILES...``.

    Loads the YAML file CONFIG and validates it against META.  Each file
    for which the configuration yields copyright information is then
    rewritten: its leading copyright header (if any) is stripped and a
    freshly formatted header is written in front of the remaining text.
    The previous version of a rewritten file is kept as ``<file>~``.
    Files without copyright information are left untouched.

    Exits with status 1, after printing a usage line on stderr, when
    fewer than two arguments are given.
    """
    # ----------------------------------------------------------------
    if len(sys.argv) - 1 < 2:
        print("Usage: %s [CONFIG] [FILES...]" % (sys.argv[0],), file=sys.stderr)
        # sys.exit, not the bare exit(): the latter is a helper installed
        # by the site module and is not guaranteed to exist.
        sys.exit(1)

    inifile = sys.argv[1]
    srcfiles = sys.argv[2:]

    # ----------------------------------------------------------------
    with open(inifile, 'r') as stream:
        # AttrDictSafeLoader is built on yaml's SafeConstructor.
        ini = yaml.load(stream, Loader=AttrDictSafeLoader)
    META.validate(ini)
    ini = Copyright(ini)

    # ----------------------------------------------------------------
    for filename in srcfiles:
        infos = ini.get_copyright(filename)
        if infos is None:
            continue

        style = CopyrightStyle.factory(infos.style)
        cp = ini.format(infos.copyrights, infos.license)
        cp = style.format(cp)

        with open(filename, 'r') as stream:
            contents = stream.read()
        contents = style.strip(contents)

        # Keep the previous version as a backup.  os.replace overwrites an
        # existing backup on every platform, so there is no need to unlink
        # it first, and the old backup is not lost if the move fails.
        os.replace(filename, filename + '~')

        with open(filename, 'w') as stream:
            stream.write(cp)
            stream.write(contents)
# --------------------------------------------------------------------
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    _main()