https://github.com/sbrisard/janus
Tip revision: a6196a025fee6bf0f3eb5e636a6b2f895ca6fbc9 authored by Sébastien Brisard on 11 September 2019, 09:52:41 UTC
Typo
Typo
Tip revision: a6196a0
setup.py
# -*- coding: utf-8 -*-
import configparser
import distutils.command.clean
import os
import re
import setuptools
from distutils import log
from distutils.core import setup
from distutils.dir_util import remove_tree
from distutils.extension import Extension
from distutils.util import split_quoted
from Cython.Build import cythonize
# Distribution metadata, forwarded to ``setup()`` in the ``__main__`` section.
NAME = 'Janus'
# NOTE(review): both descriptions are empty; consider filling them in before
# publishing the package.
DESCRIPTION = ''
LONG_DESCRIPTION = ''
AUTHOR = 'S. Brisard'
AUTHOR_EMAIL = 'sebastien.brisard@ifsttar.fr'
URL = 'https://github.com/sbrisard/janus'
# Currently identical to URL.
DOWNLOAD_URL = 'https://github.com/sbrisard/janus'
LICENSE = 'BSD-3'
class clean(distutils.command.clean.clean):
    """``clean`` command that also purges build by-products from the sources.

    After running the standard distutils ``clean`` command, this removes
    every ``__pycache__`` directory and every ``*.c``, ``*.so``, ``*.pyc``,
    ``*.pyd`` and ``*.pyo`` file found under ``./janus``. The
    ``--dry-run`` option is honoured.
    """

    description = (distutils.command.clean.clean.description +
                   ', including *.c, *.pyc, *.pyd, *.pyo and *.so files')

    def find_directories_to_remove(self, root):
        """Return the paths of all ``__pycache__`` directories under *root*."""
        directories = []
        for dirpath, dirnames, _ in os.walk(root):
            directories.extend(os.path.join(dirpath, dirname)
                               for dirname in dirnames
                               if dirname == '__pycache__')
        return directories

    def find_files_to_remove(self, root):
        """Return the paths of generated files found under *root*.

        Directories whose path ends with ``__pycache__`` are skipped, since
        they are removed as a whole by :meth:`remove_directories`.
        """
        # Raw string: '\.' in a plain literal is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning since Python 3.12). ``.+``
        # requires a non-empty stem, so a file named e.g. ``.c`` is kept.
        pattern = re.compile(r'.+\.(c|so|pyc|pyd|pyo)$')
        files = []
        for dirpath, _, filenames in os.walk(root):
            if dirpath.endswith('__pycache__'):
                continue
            files.extend(os.path.join(dirpath, filename)
                         for filename in filenames
                         if pattern.match(filename))
        return files

    def remove_directories(self, directories):
        """Delete each directory tree in *directories* (unless dry-run)."""
        for d in directories:
            remove_tree(d, dry_run=self.dry_run)

    def remove_files(self, files):
        """Delete each file in *files*; only log them under ``--dry-run``."""
        for f in files:
            log.info('removing %s', f)
            if not self.dry_run:
                os.remove(f)

    def run(self):
        """Run the standard ``clean`` command, then purge ``./janus``."""
        out = super().run()
        root = os.path.join('.', 'janus')
        # Collect both lists before deleting anything.
        directories = self.find_directories_to_remove(root)
        files = self.find_files_to_remove(root)
        self.remove_directories(directories)
        self.remove_files(files)
        return out
def extensions_and_packages():
    """Return the serial extension modules and the packages to install.

    FFTW settings are read from the optional ``[fftw]`` section of
    ``setup.cfg`` (looked up in the current working directory). The keys
    ``include_dirs``, ``library_dirs`` and ``libraries`` hold comma-separated
    lists. Empty items are ignored. The settings are passed to the
    ``janus.green`` and ``janus.fft.serial._serial_fft`` extensions.

    Returns a ``(extensions, packages)`` tuple: a list of ``Extension``
    objects and a list of package names.
    """
    utils = Extension('janus.utils.checkarray',
                      sources=['janus/utils/checkarray.pyx'])
    operators = Extension('janus.operators',
                          sources=['janus/operators.pyx'])
    materials = Extension('janus.material.elastic.linear.isotropic',
                          sources=['janus/material/elastic/linear/isotropic.pyx'])

    config = configparser.ConfigParser()
    config.read('setup.cfg')
    kwargs = {}
    if config.has_section('fftw'):
        fftw = config['fftw']
        for key in ['include_dirs', 'library_dirs', 'libraries']:
            tokens = [token.strip() for token in fftw.get(key, '').split(',')]
            # Drop empty items so that a trailing comma (``fftw3,``) does not
            # yield an empty library name or directory.
            tokens = [token for token in tokens if token]
            if tokens:
                kwargs[key] = tokens
    serial_fft = Extension('janus.fft.serial._serial_fft',
                           sources=['janus/fft/serial/_serial_fft.pyx'],
                           **kwargs)
    green = Extension('janus.green', sources=['janus/green.pyx'], **kwargs)

    extensions = [utils, operators, materials, green, serial_fft]
    # NOTE(review): no ``janus.material*`` package is listed although the
    # ``materials`` extension lives inside it; confirm whether Python files
    # of those packages need to be installed.
    packages = ['janus', 'janus.fft', 'janus.fft.serial', 'janus.utils']
    return extensions, packages
def mpicc_showme():
    """Use ``mpicc --showme`` to retrieve the mpicc arguments.

    Works with openmpi, not mpich.

    Returns a dictionary that can be passed to Extension(). The include
    directory of mpi4py is appended to the include directories reported by
    mpicc.

    Raises ImportError if mpi4py is not installed. Errors from running mpicc
    (e.g. ``subprocess.CalledProcessError``) propagate.
    """
    import mpi4py
    from subprocess import check_output

    mpicc = mpi4py.get_config()['mpicc']

    def call_mpicc_showme(arg):
        # Decode as UTF-8 (the bytes.decode default) instead of ASCII, so
        # install prefixes containing non-ASCII characters do not fail.
        out = check_output([mpicc, '--showme:' + arg])
        return out.decode().split()

    incdirs = call_mpicc_showme('incdirs')
    incdirs.append(mpi4py.get_include())
    return {'include_dirs': incdirs,
            'library_dirs': call_mpicc_showme('libdirs'),
            'extra_compile_args': call_mpicc_showme('compile'),
            'extra_link_args': call_mpicc_showme('link')}
def _parse_mpicc_show(command_line):
    """Split the output of ``mpicc -show`` into Extension() keyword arguments.

    The first token (the underlying compiler) is discarded.

    - ``-I<dir>`` options give ``include_dirs``.
    - ``-L<dir>`` options give ``library_dirs``.
    - ``-l...`` and ``-Wl...`` options go to ``extra_link_args``.
    - All remaining options become ``extra_compile_args`` and are also
      appended to ``extra_link_args``.
    """
    def partition(regex, items, group=0):
        """Return (``group`` of each matching item, non-matching items)."""
        matching = []
        non_matching = []
        for item in items:
            m = re.search(regex, item)
            if m is None:
                non_matching.append(item)
            else:
                matching.append(m.group(group))
        return matching, non_matching

    # Tokenize first, then drop the compiler. Stripping r'\S+\s' from the
    # raw string instead would leave the compiler in place when mpicc prints
    # it alone, and would break a quoted compiler path containing spaces.
    flags = split_quoted(command_line.strip())[1:]
    incdirs, flags = partition(r'^-I(.*)', flags, 1)
    libdirs, flags = partition(r'^-L(.*)', flags, 1)
    ldflags, cflags = partition(r'^-W?l.*', flags)
    # Options that are neither paths nor linker options are needed both when
    # compiling and when linking.
    ldflags += cflags
    return {'include_dirs': incdirs,
            'library_dirs': libdirs,
            'extra_compile_args': cflags,
            'extra_link_args': ldflags}


def mpicc_show():
    """Use ``mpicc -show`` to retrieve the mpicc arguments.

    Works with both openmpi and mpich.

    Returns a dictionary that can be passed to Extension(). The include
    directory of mpi4py is appended to ``include_dirs``.

    Raises ImportError if mpi4py is not installed. Errors from running mpicc
    (e.g. ``subprocess.CalledProcessError``) propagate.
    """
    import mpi4py
    import subprocess

    mpicc = mpi4py.get_config()['mpicc']
    output = subprocess.check_output([mpicc, '-show']).decode()
    kwargs = _parse_mpicc_show(output)
    kwargs['include_dirs'].append(mpi4py.get_include())
    return kwargs
def extensions_and_packages_with_mpi():
    """Return the MPI-dependent extension modules and packages.

    Compiler and linker settings come from ``mpicc_show()``. Optional FFTW
    settings are read from the ``[fftw_mpi]`` section of ``setup.cfg``:
    comma-separated ``include_dirs``, ``library_dirs`` and ``libraries``,
    with empty items ignored. They are appended to the mpicc settings.

    Returns a ``(extensions, packages)`` tuple. Both lists are empty when
    ``mpicc_show()`` raises ImportError (it imports mpi4py), so the serial
    part of the package can still be built.
    """
    # Keep the try block minimal: mpicc_show() is the only call expected to
    # raise ImportError.
    try:
        kwargs = mpicc_show()
    except ImportError:
        log.info('mpi4py not available: skipping janus.fft.parallel')
        return [], []

    config = configparser.ConfigParser()
    config.read('setup.cfg')
    if config.has_section('fftw_mpi'):
        fftw_mpi = config['fftw_mpi']
        for key in ['include_dirs', 'library_dirs', 'libraries']:
            tokens = [token.strip()
                      for token in fftw_mpi.get(key, '').split(',')]
            # A trailing comma must not yield an empty library name/path.
            tokens = [token for token in tokens if token]
            if tokens:
                kwargs[key] = kwargs.get(key, []) + tokens
    parallel_fft = Extension('janus.fft.parallel._parallel_fft',
                             sources=['janus/fft/parallel/_parallel_fft.pyx'],
                             **kwargs)
    return [parallel_fft], ['janus.fft.parallel']
if __name__ == '__main__':
    # The serial extensions are always built. The MPI lists may be empty
    # (see extensions_and_packages_with_mpi).
    serial_ext, serial_pkg = extensions_and_packages()
    mpi_ext, mpi_pkg = extensions_and_packages_with_mpi()
    cython_directives = {'embedsignature': True, 'language_level': 3}
    setup(name=NAME,
          version='0.1',
          description=DESCRIPTION,
          long_description=LONG_DESCRIPTION,
          author=AUTHOR,
          author_email=AUTHOR_EMAIL,
          url=URL,
          download_url=DOWNLOAD_URL,
          license=LICENSE,
          packages=serial_pkg + mpi_pkg,
          ext_modules=cythonize(serial_ext + mpi_ext,
                                compiler_directives=cython_directives),
          cmdclass={'clean': clean})