https://gitlab.com/januseriksen/pymbe
Raw File
Tip revision: f9ef68e11c1218281025a2f307d627cf3a239838 authored by Janus Juul Eriksen on 18 November 2019, 22:38:48 UTC
Merge branch 'master' into dev
Tip revision: f9ef68e
tools.py
#!/usr/bin/env python
# -*- coding: utf-8 -*

"""
tools module containing all helper functions used in pymbe
"""

__author__ = 'Dr. Janus Juul Eriksen, University of Bristol, UK'
__license__ = 'MIT'
__version__ = '0.8'
__maintainer__ = 'Dr. Janus Juul Eriksen'
__email__ = 'janus.eriksen@bristol.ac.uk'
__status__ = 'Development'

import os
import re
import sys
import traceback
import subprocess
import numpy as np
import scipy.special
import functools
import itertools
import math
from typing import Tuple, List, Union

import parallel

# restart folder
RST = os.getcwd()+'/rst'
# pi-orbitals
PI_SYMM_D2H = np.array([2, 3, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 26, 27])
PI_SYMM_C2V = np.array([10, 11, 12, 13, 20, 21, 22, 23])


class Logger:
        """
        this class pipes all write statements to both stdout and output_file
        """
        def __init__(self, output_file: str, both: bool = True) -> None:
            """
            init Logger
            """
            self.terminal = sys.stdout
            self.log = open(output_file, 'a')
            self.both = both

        def write(self, message: str) -> None:
            """
            define write
            """
            self.log.write(message)
            if self.both:
                self.terminal.write(message)

        def flush(self):
            """
            define flush
            """
            pass


def git_version() -> str:
        """
        this function returns the git revision as a string
        see: https://github.com/numpy/numpy/blob/master/setup.py#L70-L92
        """
        def _minimal_ext_cmd(cmd: List[str]) -> bytes:
            env = {}
            for k in ['SYSTEMROOT', 'PATH', 'HOME']:
                v = os.environ.get(k)
                if v is not None:
                    env[k] = v
            # LANGUAGE is used on win32
            env['LANGUAGE'] = 'C'
            env['LANG'] = 'C'
            env['LC_ALL'] = 'C'
            out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env, cwd=get_pymbe_path()).communicate()[0]
            return out

        try:
            out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
            GIT_REVISION = out.strip().decode('ascii')
        except OSError:
            GIT_REVISION = "Unknown"

        return GIT_REVISION


def get_pymbe_path() -> str:
        """
        this function returns the path to pymbe
        """
        return os.path.dirname(os.path.realpath(sys.argv[0]))


def assertion(cond: bool, reason: str) -> None:
        """
        this function returns an assertion of a given condition
        """
        if not cond:

            # get stack
            stack = ''.join(traceback.format_stack()[:-1])

            # print stack
            print('\n\n'+stack)
            print('\n\n*** PyMBE assertion error: '+reason+' ***\n\n')

            # abort mpi
            parallel.abort()


def time_str(time: float) -> str:
        """
        this function returns time as a HH:MM:SS string

        example:
        >>> time_str(3742.4)
        '1h 2m 22.40s'
        """
        # hours, minutes, and seconds
        hours = time // 3600.
        minutes = (time - (time // 3600) * 3600.) // 60.
        seconds = time - hours * 3600. - minutes * 60.

        # init time string
        string: str = ''
        form: Tuple[float, ...] = ()

        # write time string
        if hours > 0:

            string += '{:.0f}h '
            form += (hours,)

        if minutes > 0:

            string += '{:.0f}m '
            form += (minutes,)

        string += '{:.2f}s'
        form += (seconds,)

        return string.format(*form)


def fsum(a: np.ndarray) -> Union[float, np.ndarray]:
        """
        this function uses math.fsum to safely sum 1d array or 2d array (column-wise)

        example:
        >>> np.isclose(fsum(np.arange(10.)), 45.)
        True
        >>> np.allclose(fsum(np.arange(4. ** 2).reshape(4, 4)), np.array([24., 28., 32., 36.]))
        True
        """
        if a.ndim == 1:
            return math.fsum(a)
        elif a.ndim == 2:
            return np.fromiter(map(math.fsum, a.T), dtype=a.dtype, count=a.shape[1])
        else:
            raise NotImplementedError('tools.py: _fsum()')


def hash_2d(a: np.ndarray) -> np.ndarray:
        """
        this function converts a 2d numpy array to a 1d array of hashes

        example:
        >>> hash_2d(np.arange(4 * 4, dtype=np.int16).reshape(4, 4))
        array([-2930228190932741801,  1142744019865853604, -8951855736587463849,
                4559082070288058232])
        """
        return np.fromiter(map(hash_1d, a), dtype=np.int64, count=a.shape[0])


def hash_1d(a: np.ndarray) -> int:
        """
        this function converts a 1d numpy array to a hash

        example:
        >>> hash_1d(np.arange(5, dtype=np.int16))
        1974765062269638978
        """
        return hash(a.astype(np.int64).tobytes())


def hash_compare(a: np.ndarray, b: np.ndarray) -> Union[np.ndarray, None]:
        """
        this function finds occurences of b in a through a binary search

        example:
        >>> a = np.arange(10, dtype=np.int16)
        >>> hash_compare(a, np.array([1, 3, 5, 7, 9], dtype=np.int16))
        array([1, 3, 5, 7, 9])
        >>> hash_compare(a, np.array([1, 3, 5, 7, 11], dtype=np.int16)) is None
        True
        """
        left = a.searchsorted(b, side='left')
        right = a.searchsorted(b, side='right')

        if ((right - left) > 0).all():
            return left
        else:
            return None


def cas(ref_space: np.ndarray, tup: np.ndarray) -> np.ndarray:
        """
        this function returns a cas space

        example:
        >>> cas(np.array([7, 13]), np.arange(5))
        array([ 0,  1,  2,  3,  4,  7, 13])
        """
        return np.sort(np.append(ref_space, tup))


def core_cas(nocc: int, ref_space: np.ndarray, tup: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        this function returns a core and a cas space

        example:
        >>> core_cas(8, np.arange(3, 5), np.array([9, 21]))
        (array([0, 1, 2, 5, 6, 7]), array([ 3,  4,  9, 21]))
        """
        cas_idx = cas(ref_space, tup)
        core_idx = np.setdiff1d(np.arange(nocc), cas_idx)
        return core_idx, cas_idx


def _cas_idx_cart(cas_idx: np.ndarray) -> np.ndarray:
        """
        this function returns a cartesian product of (cas_idx, cas_idx)

        example:
        >>> _cas_idx_cart(np.arange(0, 10, 3))
        array([[0, 0],
               [0, 3],
               [0, 6],
               [0, 9],
               [3, 0],
               [3, 3],
               [3, 6],
               [3, 9],
               [6, 0],
               [6, 3],
               [6, 6],
               [6, 9],
               [9, 0],
               [9, 3],
               [9, 6],
               [9, 9]])
        """
        return np.array(np.meshgrid(cas_idx, cas_idx)).T.reshape(-1, 2)


def _coor_to_idx(ij: Tuple[int, int]) -> int:
        """
        this function returns the lower triangular index corresponding to (i, j)

        example:
        >>> _coor_to_idx((4, 9))
        49
        """
        i = ij[0]; j = ij[1]
        if i >= j:
            return i * (i + 1) // 2 + j
        else:
            return j * (j + 1) // 2 + i


def cas_idx_tril(cas_idx: np.ndarray) -> np.ndarray:
        """
        this function returns lower triangular cas indices

        example:
        >>> cas_idx_tril(np.arange(2, 14, 3))
        array([ 5, 17, 20, 38, 41, 44, 68, 71, 74, 77])
        """
        cas_idx_cart = _cas_idx_cart(cas_idx)
        return np.unique(np.fromiter(map(_coor_to_idx, cas_idx_cart), \
                                        dtype=cas_idx_cart.dtype, count=cas_idx_cart.shape[0]))


def pi_space(group: str, orbsym: np.ndarray, exp_space: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        this function returns pi-orbitals and hashes from total expansion space

        example:
        >>> orbsym_dooh = np.array([14, 15, 5, 2, 3, 5, 0, 11, 10, 7, 6, 5, 3, 2, 0, 14, 15, 5])
        >>> exp_space = np.arange(18, dtype=np.int16)
        >>> pi_pairs, pi_hashes = pi_space('dooh', orbsym_dooh, exp_space)
        >>> pi_pairs_ref = np.array([12, 13,  7,  8,  3,  4,  0,  1,  9, 10, 15, 16], dtype=np.int16)
        >>> np.allclose(pi_pairs, pi_pairs_ref)
        True
        >>> pi_hashes_ref = np.array([-8471304755370577665, -7365615264797734692, -3932386661120954737,
        ...                           -3821038970866580488,  758718848004794914,   7528999078095043310])
        >>> np.allclose(pi_hashes, pi_hashes_ref)
        True
        """
        # all pi-orbitals
        if group == 'Dooh':
            pi_space_arr = exp_space[np.in1d(orbsym[exp_space], PI_SYMM_D2H)]
        else:
            pi_space_arr = exp_space[np.in1d(orbsym[exp_space], PI_SYMM_C2V)]

        # get all degenerate pi-pairs
        pi_pairs = pi_space_arr.reshape(-1, 2)

        # get hashes of all degenerate pi-pairs
        pi_hashes = hash_2d(pi_pairs)
        pi_pairs = pi_pairs[np.argsort(pi_hashes)]
        pi_hashes.sort()

        return pi_pairs.reshape(-1,), pi_hashes


def non_deg_orbs(pi_space: np.ndarray, tup: np.ndarray) -> np.ndarray:
        """
        this function returns non-degenerate orbitals from tuple of orbitals

        example:
        >>> non_deg_orbs(np.array([1, 2, 4, 5], dtype=np.int16), np.arange(8, dtype=np.int16))
        array([0, 3, 6, 7], dtype=int16)
        """
        return tup[np.invert(np.in1d(tup, pi_space))]


def _pi_orbs(pi_space: np.ndarray, tup: np.ndarray) -> np.ndarray:
        """
        this function returns pi-orbitals from tuple of orbitals

        example:
        >>> _pi_orbs(np.array([1, 2, 4, 5], dtype=np.int16), np.arange(8, dtype=np.int16))
        array([1, 2, 4, 5], dtype=int16)
        """
        return tup[np.in1d(tup, pi_space)]


def n_pi_orbs(pi_space: np.ndarray, tup: np.ndarray) -> np.ndarray:
        """
        this function returns number of pi-orbitals in tuple of orbitals

        example:
        >>> n_pi_orbs(np.array([1, 2, 4, 5], dtype=np.int16), np.arange(8, dtype=np.int16))
        4
        """
        return _pi_orbs(pi_space, tup).size


def pi_pairs_deg(pi_space: np.ndarray, tup: np.ndarray) -> np.ndarray:
        """
        this function returns pairs of degenerate pi-orbitals from tuple of orbitals

        example:
        >>> pi_pairs_deg(np.array([1, 2, 4, 5], dtype=np.int16), np.arange(8, dtype=np.int16))
        array([[1, 2],
               [4, 5]], dtype=int16)
        """
        # get all pi-orbitals in tup
        tup_pi_orbs = _pi_orbs(pi_space, tup)

        # return degenerate pairs
        if tup_pi_orbs.size % 2 > 0:
            return tup_pi_orbs[1:].reshape(-1, 2)
        else:
            return tup_pi_orbs.reshape(-1, 2)


def pi_prune(pi_space: np.ndarray, pi_hashes: np.ndarray, tup: np.ndarray) -> bool:
        """
        this function returns True for a tuple of orbitals allowed under pruning wrt degenerate pi-orbitals

        example:
        >>> pi_space = np.array([1, 2, 4, 5], dtype=np.int16)
        >>> pi_hashes = np.sort(np.array([-2163557957507198923, 1937934232745943291]))
        >>> pi_prune(pi_space, pi_hashes, np.array([0, 1, 2, 4, 5, 6, 7], dtype=np.int16))
        True
        >>> pi_prune(pi_space, pi_hashes, np.array([0, 1, 2], dtype=np.int16))
        True
        >>> pi_prune(pi_space, pi_hashes, np.array([0, 1, 2, 4], dtype=np.int16))
        False
        >>> pi_prune(pi_space, pi_hashes, np.array([0, 1, 2, 5, 6], dtype=np.int16))
        False
        """
        # get all pi-orbitals in tup
        tup_pi_orbs = _pi_orbs(pi_space, tup)

        if tup_pi_orbs.size == 0:

            # no pi-orbitals
            return True

        else:

            if tup_pi_orbs.size % 2 > 0:

                # always prune tuples with an odd number of pi-orbitals
                return False

            else:

                # get hashes of pi-pairs
                tup_pi_hashes = hash_2d(tup_pi_orbs.reshape(-1, 2))
                tup_pi_hashes.sort()

                # get indices of pi-pairs
                idx = hash_compare(pi_hashes, tup_pi_hashes)

                return idx is not None


def occ_prune(occup: np.ndarray, tup: np.ndarray) -> bool:
        """
        this function returns True for a tuple of orbitals allowed under pruning wrt occupied orbitals

        example:
        >>> occup = np.array([2.] * 3 + [0.] * 4)
        >>> occ_prune(occup, np.arange(2, 7, dtype=np.int16))
        True
        >>> occ_prune(occup, np.arange(3, 7, dtype=np.int16))
        False
        """
        return np.any(occup[tup] > 0.)


def virt_prune(occup: np.ndarray, tup: np.ndarray) -> bool:
        """
        this function returns True for a tuple of orbitals allowed under pruning wrt virtual orbitals

        example:
        >>> occup = np.array([2.] * 3 + [0.] * 4)
        >>> virt_prune(occup, np.arange(1, 4, dtype=np.int16))
        True
        >>> virt_prune(occup, np.arange(1, 3, dtype=np.int16))
        False
        """
        return np.any(occup[tup] == 0.)


def nelec(occup: np.ndarray, tup: np.ndarray) -> Tuple[int, int]:
        """
        this function returns the number of electrons in a given tuple of orbitals

        example:
        >>> occup = np.array([2.] * 3 + [0.] * 4)
        >>> nelec(occup, np.array([2, 4], dtype=np.int16))
        (1, 1)
        >>> nelec(occup, np.array([3, 4], dtype=np.int16))
        (0, 0)
        """
        occup_tup = occup[tup]
        return (np.count_nonzero(occup_tup > 0.), np.count_nonzero(occup_tup > 1.))


def ndets(occup: np.ndarray, cas_idx: np.ndarray, \
            ref_space: np.ndarray = None, n_elec: Tuple[int, ...] = None) -> int:
        """
        this function returns the number of determinants in given casci calculation (ignoring point group symmetry)

        example:
        >>> occup = np.array([2.] * 3 + [0.] * 4)
        >>> ndets(occup, np.arange(1, 5, dtype=np.int16))
        36
        >>> ndets(occup, np.arange(1, 7, dtype=np.int16),
        ...       ref_space=np.array([1, 2], dtype=np.int16))
        4900
        >>> ndets(occup, np.arange(1, 7, 2, dtype=np.int16),
        ...       ref_space=np.array([1, 3], dtype=np.int16),
        ...       n_elec=(1, 1))
        100
        """
        if n_elec is None:
            n_elec = nelec(occup, cas_idx)

        n_orbs = cas_idx.size

        if ref_space is not None:
            ref_n_elec = nelec(occup, ref_space)
            n_elec = tuple(map(sum, zip(n_elec, ref_n_elec)))
            n_orbs += ref_space.size

        return int(scipy.special.binom(n_orbs, n_elec[0]) * scipy.special.binom(n_orbs, n_elec[1]))


def mat_idx(site_idx: int, nx: int, ny: int) -> Tuple[int, int]:
        """
        this function returns x and y indices of a matrix

        example:
        >>> mat_idx(6, 4, 4)
        (1, 2)
        >>> mat_idx(9, 8, 2)
        (4, 1)
        """
        y = site_idx % nx
        x = int(math.floor(float(site_idx) / ny))
        return x, y


def near_nbrs(site_xy: Tuple[int, int], nx: int, ny: int) -> List[Tuple[int, int]]:
        """
        this function returns a list of nearest neighbour indices

        example:
        >>> near_nbrs((1, 2), 4, 4)
        [(0, 2), (2, 2), (1, 3), (1, 1)]
        >>> near_nbrs((4, 1), 8, 2)
        [(3, 1), (5, 1), (4, 0), (4, 0)]
        """
        up = ((site_xy[0] - 1) % nx, site_xy[1])
        down = ((site_xy[0] + 1) % nx, site_xy[1])
        left = (site_xy[0], (site_xy[1] + 1) % ny)
        right = (site_xy[0], (site_xy[1] - 1) % ny)
        return [up, down, left, right]


def write_file(order: Union[None, int], arr: np.ndarray, string: str) -> None:
        """
        this function writes a general restart file corresponding to input string
        """
        if order is None:
            np.save(os.path.join(RST, '{:}'.format(string)), arr)
        else:
            np.save(os.path.join(RST, '{:}_{:}'.format(string, order)), arr)


def read_file(order: int, string: str) -> np.ndarray:
        """
        this function reads a general restart file corresponding to input string
        """
        if order is None:
            return np.load(os.path.join(RST, '{:}.npy'.format(string)))
        else:
            return np.load(os.path.join(RST, '{:}_{:}.npy'.format(string, order)))


def natural_keys(txt: str) -> List[Union[int, str]]:
        """
        this function return keys to sort a string in human order (as alist.sort(key=natural_keys))
        see: http://nedbatchelder.com/blog/200712/human_sorting.html
        see: https://stackoverflow.com/questions/5967500/how-to-correctly-sort-a-string-with-a-number-inside

        example:
        >>> natural_keys('mbe_test_string')
        ['mbe_test_string']
        >>> natural_keys('mbe_test_string_1')
        ['mbe_test_string_', 1, '']
        """
        return [_convert(c) for c in re.split('(\d+)', txt)]


def _convert(txt: str) -> Union[int, str]:
        """
        this function converts strings with numbers in them

        example:
        >>> isinstance(_convert('string'), str)
        True
        >>> isinstance(_convert('1'), int)
        True
        """
        return int(txt) if txt.isdigit() else txt


if __name__ == "__main__":
    import doctest
    doctest.testmod()


back to top