https://github.com/mit-plv/fiat-crypto
Raw File
Tip revision: 6fea4500019df2d320e2c62fc25c4b03da72659c authored by Jason Gross on 03 July 2022, 02:52:06 UTC
Merge pull request #1310 from JasonGross/sp2019latest+strict-hints
Tip revision: 6fea450
generate_parameters.py

'''
EXAMPLES (handwritten):


# p256 - amd128
{
    "modulus"              : "2^256-2^224+2^192+2^96-1",
    "base"                 : "128",
    "sz"                   : "2",
    "bitwidth"             : "128",
    "montgomery"           : "true",
    "operations"           : ["fenz", "feadd", "femul", "feopp", "fesub"],
    "compiler"             : "gcc -fno-peephole2 `#GCC BUG 81300` -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes -Wno-incompatible-pointer-types -fno-strict-aliasing"
}

# p256 - amd64
{
    "modulus"              : "2^256-2^224+2^192+2^96-1",
    "base"                 : "64",
    "sz"                   : "4",
    "bitwidth"             : "64",
    "montgomery"           : "true",
    "operations"           : ["fenz", "feadd", "femul", "feopp", "fesub"],
    "compiler"             : "gcc -fno-peephole2 `#GCC BUG 81300` -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes -Wno-incompatible-pointer-types -fno-strict-aliasing"
}


# p448 - c64
{
    "modulus"          : "2^448-2^224-1",
    "base"             : "56",
    "goldilocks"       : "true",
    "sz"               : "8",
    "bitwidth"         : "64",
    "carry_chains"     : [[3, 7],
			  [0, 4, 1, 5, 2, 6, 3, 7],
			  [4, 0]],
    "coef_div_modulus" : "2",
    "operations"       : ["femul"]
}

# curve25519 - c64
{
    "modulus"          : "2^255-19",
    "base"             : "51",
    "sz"               : "5",
    "bitwidth"         : "64",
    "carry_chains"     : "default",
    "coef_div_modulus" : "2",
    "operations"       : ["femul", "fesquare", "freeze"],
    "compiler"         : "gcc -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes",
}

# curve25519 - c32
{
    "modulus"          : "2^255-19",
    "base"             : "25.5",
    "sz"               : "10",
    "bitwidth"         : "32",
    "carry_chains"     : "default",
    "coef_div_modulus" : "2",
    "operations"       : ["femul", "fesquare", "freeze"],
    "compiler"         : "gcc -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fomit-frame-pointer -fwrapv -Wno-attributes",
}

'''

import math,json,sys,os,traceback,re,textwrap
from fractions import Fraction

# Base command lines for the C and C++ compilers; they are embedded verbatim
# (plus per-curve -D definitions) in the "compiler"/"compilerxx" entries of the
# generated parameter files.
CC = "clang -fbracket-depth=999999 -march=native -mbmi2 -mtune=native -std=gnu11 -O3 -flto -fuse-ld=lld -fomit-frame-pointer -fwrapv -Wno-attributes -fno-strict-aliasing"
CCX = "clang++ -fbracket-depth=999999 -march=native -mbmi2 -mtune=native -std=gnu++11 -O3 -flto -fuse-ld=lld -fomit-frame-pointer -fwrapv -Wno-attributes -fno-strict-aliasing"

# for montgomery
COMPILER_MONT = CC
COMPILERXX_MONT = CCX
# for solinas
COMPILER_SOLI = CC
COMPILERXX_SOLI = CCX
# Directory containing this script; the paths below are resolved relative to it.
CUR_PATH = os.path.dirname(os.path.realpath(__file__))
# Directory the generated .json parameter files are written to.
JSON_DIRECTORY = os.path.join(CUR_PATH, "src/Specific/CurveParameters")
# Shell script that gets one '${MAKE} ...' line per generated parameter file.
REMAKE_CURVES = os.path.join(JSON_DIRECTORY, 'remake_curves.sh')

class LimbPickingException(Exception):
    """Raised when no workable limb layout can be found for a prime."""
class NonBase2Exception(Exception):
    """Raised when a term of a prime contains a power whose base is not 2."""
class UnexpectedPrimeException(Exception):
    """Raised when a parsed prime does not have the expected shape."""

# given a string representing one term or "tap" in a prime, returns a pair of
# integers representing the weight and coefficient of that tap
#    "2 ^ y" -> [1, y]
#    "x * 2 ^ y" -> [x, y]
#    "x * y" -> [x*y,0]
#    "x" -> [x,0]
def parse_term(t):
    """Parse one term ("tap") of a prime into [coefficient, exponent].

    Accepted shapes are "x", "x * y", "2 ^ y", "x * 2 ^ y" and the
    three-factor form "w * x * 2 ^ y" (or "w * x * y") that appears once a
    subtraction has been rewritten as a multiplication by -1.

    Returns a two-element list [coefficient, exponent of 2].
    Raises NonBase2Exception when the power has a base other than 2; int()
    and unpacking errors propagate as ValueError for malformed input.
    """
    has_product = "*" in t
    has_power = "^" in t

    # plain integer constant
    if not (has_product or has_power):
        return [int(t), 0]

    if has_product:
        factors = t.split("*")
        if len(factors) > 2:
            # e.g. [w - x * y] has been turned into [w + -1 * x * y]
            first, second, power = factors
            coefficient = int(first) * int(second)
        else:
            coefficient, power = factors
        if "^" not in power:
            # pure product of constants: weight 2^0
            return [int(coefficient) * int(power), 0]
    else:
        # bare power, implicit coefficient of 1
        coefficient, power = 1, t

    radix, exponent = power.split("^")
    if int(radix) != 2:
        raise NonBase2Exception("Could not parse term, power with base other than 2: %s" %t)
    return [int(coefficient), int(exponent)]


# expects prime to be a string expressed as a sum/difference of powers of
# two with small coefficients (e.g. '2^448 - 2^224 - 1', '2^255 - 19')
def parse_prime(prime):
    """Split a prime expression into a list of [coefficient, exponent] terms.

    Subtractions are rewritten as additions of negated terms, all spaces are
    dropped, and a negated bare power of two gets an explicit -1 coefficient
    before each '+'-separated piece is handed to parse_term.
    """
    normalized = prime.replace("-", "+ -")
    normalized = normalized.replace(' ', '')
    normalized = normalized.replace('+-2^', '+-1*2^')
    return [parse_term(term) for term in normalized.split("+")]

# check that the parsed prime makes sense
def sanity_check(p):
    """Validate the shape of a parsed prime.

    p is a list of [coefficient, exponent] terms. The checks run lazily and in
    order, so a later check (which indexes into p and its terms) is only
    evaluated once the earlier ones have established that the indexing is
    safe; malformed input is therefore always reported as
    UnexpectedPrimeException rather than as an IndexError.

    Returns None when p is acceptable.
    Raises UnexpectedPrimeException otherwise.
    """
    checks = (
        # are there at least 2 terms?
        lambda: len(p) > 1,
        # do all terms have 2 elements?
        lambda: all(len(t) == 2 for t in p),
        # are the terms in order (most to least significant)?
        lambda: p == sorted(p, reverse=True, key=lambda t: t[1]),
        # does the least significant term have weight 2^0=1?
        lambda: p[-1][1] == 0,
        # are all the exponents non-negative and the coefficients nonzero?
        lambda: all(t[0] != 0 and t[1] >= 0 for t in p),
        # is the second-most-significant term negative?
        lambda: p[1][0] < 0,
        # is every exponent distinct?
        lambda: len(set(t[1] for t in p)) == len(p),
    )
    # all() over a generator stops at the first failing check
    if not all(check() for check in checks):
        raise UnexpectedPrimeException("Parsed prime %s has unexpected format" %p)


def eval_numexpr(numexpr):
    """Evaluate a numeric expression string and return its value.

    Every '.' that is not immediately followed by a digit is removed first,
    then the text is evaluated with builtins disabled.
    """
    # copying from https://stackoverflow.com/a/25437733/377022
    # NOTE(review): this still relies on eval(); only feed it trusted input.
    sanitized = re.sub(r"\.(?![0-9])", "", numexpr)
    restricted_globals = {'__builtins__':None}
    return eval(sanitized, restricted_globals)

def get_extra_compiler_params(q, base, bitwidth, sz):
    """Build the per-curve '-Dname=value' definitions appended to a compiler command.

    q is the prime expression (with '^' for powers), base is the number of
    bits per limb (an int, a float, or a string such as '42 + 2/3'), and sz is
    the number of limbs. Every value is passed through repr() so that it is
    quoted in the resulting command line.

    Returns a string that starts with a space and lists the definitions sorted
    by name.
    """
    def log_wt(i):
        # base may be a sum such as '42 + 2/3'; add the pieces up exactly
        bits_per_limb = sum(Fraction(piece.strip()) for piece in str(base).split('+'))
        return int(math.ceil(bits_per_limb * i))

    def hex_bytes(value, count):
        # brace-enclosed list of `count` bytes, most significant byte first
        little_endian = ['0x%02x' % ((value >> 8*i)&0xff) for i in range(count)]
        return '{%s}' % ','.join(reversed(little_endian))

    q_int = eval_numexpr(q.replace('^', '**'))
    a24 = 12345 # TODO
    modulus_bytes = (q_int.bit_length()+7)//8
    # width of each limb = difference between consecutive limb weights
    gaps = [str(int(log_wt(i + 1) - log_wt(i))) for i in range(sz)]
    defs = {
        'q_mpz' : repr(re.sub(r'2(\s*)\^(\s*)([0-9]+)', r'(1_mpz\1<<\2\3)', str(q))),
        'modulus_bytes_val' : repr(str(modulus_bytes)),
        'modulus_array' : repr(hex_bytes(q_int, modulus_bytes)),
        'a_minus_two_over_four_array' : repr(hex_bytes(a24, modulus_bytes)),
        'a24_val' : repr(str(a24)),
        'a24_hex' : repr(hex(a24)),
        'bitwidth' : repr(str(bitwidth)),
        'modulus_limbs' : repr(str(sz)),
        'limb_weight_gaps_array' : repr('{%s}' % ','.join(gaps)),
    }
    flags = ['-D%s=%s' % (k, v) for k, v in sorted(defs.items())]
    return ' ' + ' '.join(flags)

def num_bits(p):
    """Return the exponent of the first (most significant) term of parsed prime p."""
    leading_term = p[0]
    return leading_term[1]

def get_params_montgomery(prime, bitwidth):
    """Return the list of parameter dicts (always exactly one) for a Montgomery implementation.

    The prime is parsed and sanity-checked, and the number of limbs is the
    number of bitwidth-sized words needed to hold num_bits of the prime.
    Exceptions from parsing and validation propagate to the caller.
    """
    p = parse_prime(prime)
    sanity_check(p)
    sz = int(math.ceil(num_bits(p) / float(bitwidth)))
    # the same -D definitions are shared by the C and C++ command lines
    extra = get_extra_compiler_params(prime, bitwidth, bitwidth, sz)
    params = {
        "modulus" : prime,
        "base" : str(bitwidth),
        "sz" : str(sz),
        "montgomery" : True,
        "operations" : ["fenz", "feadd", "femul", "feopp", "fesub"],
        "extra_files" : ["montgomery%s/fesquare.c" % str(bitwidth)],
        "compiler" : COMPILER_MONT + extra,
        "compilerxx" : COMPILERXX_MONT + extra,
    }
    return [params]

def place(weight, nlimbs, wt):
    """Return the first limb index i in range(nlimbs) with weight(i) <= wt < weight(i+1).

    weight is a callable mapping a limb index to its weight. Returns None when
    no limb covers wt.
    """
    candidates = (i for i in range(nlimbs) if weight(i) <= wt < weight(i + 1))
    return next(candidates, None)

def solinas_reduce(p, pprods):
    """Apply one reduction step to a list of (weight, value) partial products.

    Entries whose weight is below num_bits(p) are kept unchanged. Every other
    entry is replaced by one entry per non-leading term (coef, exp) of p, with
    weight lowered by num_bits(p) and raised by exp, and value multiplied by
    -coef.
    """
    reduced = []
    for wt, x in pprods:
        if wt < num_bits(p):
            reduced.append((wt, x))
            continue
        excess = wt - num_bits(p)
        reduced.extend((excess + exp, -coef * x) for coef, exp in p[1:])
    return reduced

# check if the suggested number of limbs will overflow when adding partial
# products after a multiplication and then doing solinas reduction
def overflow_free(p, bitwidth, nlimbs):
    """Tell whether a multiplication with nlimbs limbs stays within 2*bitwidth bits.

    Models the partial products of a multiplication followed by repeated
    solinas reduction, accumulates them per limb, and returns True when the
    log2 of every accumulated column is below 2*bitwidth.

    Raises LimbPickingException when a reduced partial product cannot be
    assigned to any limb.
    """
    def weight(n):
        # weight (exponent only)
        return math.ceil(n * (num_bits(p) / nlimbs))

    def width(i):
        # bit width of limb i in canonical form
        return weight(i + 1) - weight(i)

    # bound on each limb after 1 addition of things with bounds at 1.125 * width
    start = [(2**width(i))*1.125*2-1 for i in range(nlimbs)]

    # partial products as (weight, bound) pairs
    indices = range(nlimbs)
    ppr = [(weight(i) + weight(j), start[i] * start[j]) for i in indices for j in indices]

    # keep reducing until no partial product sits at or above the top weight
    while max(entry[0] for entry in ppr) >= num_bits(p):
        ppr = solinas_reduce(p, ppr)

    # sort every partial product into the column of the limb that covers it
    cols = [[] for _ in range(nlimbs)]
    for wt, x in ppr:
        i = place(weight, nlimbs, wt)
        if i is None:
            raise LimbPickingException("Could not place weight %s (%s limbs, p=%s)" %(wt, nlimbs, p))
        cols[i].append(x * (2**(wt - weight(i))))

    # size in bits of the sum of each column (0 for non-positive sums)
    final = []
    for column in cols:
        total = sum(column)
        final.append(math.log2(total) if total > 0 else 0)

    return all(bits < 2*bitwidth for bits in final)

# given a parsed prime, pick out all plausible numbers of (unsaturated) limbs
def get_possible_limbs(p, bitwidth):
    """Return every candidate limb count for parsed prime p that passes overflow_free.

    Candidates range from the minimum count (inclusive) up to twice that
    minimum (exclusive), in increasing order.
    """
    # we want to leave enough bits unused to do a full solinas reduction
    # without carrying; the number of bits necessary is the sum of the bits in
    # the negative coefficients of p (other than the most significant digit)
    unused_bits = 0
    for term in p[1:]:
        if term[0] < 0:
            unused_bits += math.ceil(math.log(-term[0], 2))
    min_limbs = int(math.ceil(num_bits(p) / (bitwidth - unused_bits)))

    # don't search past 2x the minimum number of limbs; that's just wasteful
    candidates = range(min_limbs, 2*min_limbs)
    return [n for n in candidates if overflow_free(p, bitwidth, n)]

def is_goldilocks(p):
    """Return True when the leading exponent of p is exactly twice the second term's exponent.

    For example 2^448 - 2^224 - 1 qualifies, 2^255 - 19 does not.
    """
    top_exponent = p[0][1]
    tap_exponent = p[1][1]
    return top_exponent == 2 * tap_exponent

def format_base(numerator, denominator):
    """Format the limb base numerator/denominator for the parameter files.

    Returns
      * an int when the division is exact,
      * a float when the reduced denominator is one of 1, 2, 4, 5, 8, 10,
      * otherwise a string of the form 'whole + fraction' (e.g. '42 + 2/3').
    """
    if numerator % denominator == 0:
        # floor division stays exact for arbitrarily large integers, whereas
        # going through float division loses precision beyond 2**53
        return int(numerator // denominator)

    base = Fraction(numerator=numerator, denominator=denominator)
    if base.denominator in (1, 2, 4, 5, 8, 10):
        return float(base)

    # split into integer part and remaining proper fraction
    base_int = int(base)
    return '%d + %s' % (base_int, str(base - base_int))

# removes later occurrences of repeated elements, preserving first-seen order
def remove_duplicates(l):
    """Return a new list holding the first occurrence of each element of l, in order.

    Elements are compared with ==, so they do not need to be hashable.
    """
    unique = []
    for item in l:
        if item in unique:
            continue
        unique.append(item)
    return unique

def get_params_solinas(prime, bitwidth):
    """Return parameter dicts for the (at most two) smallest workable limb counts.

    The prime is parsed and sanity-checked, candidate limb counts come from
    get_possible_limbs, and one dict is produced per chosen count.

    Raises LimbPickingException when no limb count works; exceptions from
    parsing and validation propagate to the caller.
    """
    p = parse_prime(prime)
    sanity_check(p)
    limb_counts = get_possible_limbs(p, bitwidth)
    if not limb_counts:
        raise LimbPickingException("Could not find a good number of limbs for prime %s and bitwidth %s" %(prime, bitwidth))

    results = []
    # only use the top 2 choices
    for sz in limb_counts[:2]:
        base = format_base(num_bits(p), sz)

        if len(p) > 2:
            # do interleaved carry chains, starting at where the taps are
            # NOTE(review): float division; an exponent that is an exact
            # multiple of the limb size could round down -- confirm acceptable
            bits_per_limb = num_bits(p) / sz
            starts = [(int(t[1] / bits_per_limb) - 1) % sz for t in p[1:]]
            chain2 = remove_duplicates([(j + n) % sz for n in range(1, sz) for j in starts])
            chain3 = [(j + 1) % sz for j in starts]
            carry_chains = [starts, chain2, chain3]
        else:
            carry_chains = "default"

        # the same -D definitions are shared by the C and C++ command lines
        extra = get_extra_compiler_params(prime, base, bitwidth, sz)
        params = {
            "modulus": prime,
            "base" : str(base),
            "sz" : str(sz),
            "bitwidth" : bitwidth,
            "carry_chains" : carry_chains,
            "coef_div_modulus" : str(2),
            "operations"       : ["femul", "feadd", "fesub", "fesquare", "fecarry", "freeze"],
            "compiler"         : COMPILER_SOLI + extra,
            "compilerxx"       : COMPILERXX_SOLI + extra,
        }
        if is_goldilocks(p):
            params["goldilocks"] = True
        results.append(params)
    return results

def write_if_changed(filename, contents):
    """Write contents to filename unless the file already holds exactly that text."""
    if os.path.isfile(filename):
        with open(filename, 'r') as existing:
            if existing.read() == contents:
                # nothing to do; leave the file untouched
                return
    with open(filename, 'w') as out:
        out.write(contents)

def update_remake_curves(filename):
    """Make sure REMAKE_CURVES contains the '${MAKE}' line for filename.

    If the exact line is already present nothing happens. Otherwise every line
    mentioning filename is replaced by the new line, or the new line is
    appended when no line mentions it; the script is then rewritten through
    write_if_changed.
    """
    with open(REMAKE_CURVES, 'r') as script:
        lines = script.readlines()

    stem = filename[:-len('.json')]
    new_line = '${MAKE} "$@" %s ../%s/\n' % (filename, stem)
    if new_line in lines:
        return

    mentions = [filename in line for line in lines]
    if any(mentions):
        lines = [new_line if hit else line for hit, line in zip(mentions, lines)]
    else:
        lines.append(new_line)
    write_if_changed(REMAKE_CURVES, ''.join(lines))

def format_json(params):
    """Serialize params as JSON with sorted keys, 4-space indent and a trailing newline."""
    serialized = json.dumps(params, sort_keys=True, indent=4, separators=(',', ': '))
    return serialized + '\n'


def write_output(name, params):
    """Write params as JSON into JSON_DIRECTORY and register the file in remake_curves.

    The file name is '<name>_<modulus>_<sz>limbs.json' with the characters
    '^', ' ', '-', '+' and '*' of the whole name mapped to 'e', '', 'm', 'p'
    and 'x' respectively.
    """
    prime = params["modulus"]
    nlimbs = params["sz"]
    filename = name + "_" + prime + "_" + nlimbs + "limbs" + ".json"
    # make the arithmetic expression safe to use as a file name
    for old, new in (("^", "e"), (" ", ""), ("-", "m"), ("+", "p"), ("*", "x")):
        filename = filename.replace(old, new)

    target = os.path.join(JSON_DIRECTORY, filename)
    write_if_changed(target, format_json(params))
    update_remake_curves(filename)

def try_write_output(name, get_params, prime, bitwidth):
    """Generate and write every parameter set for prime, reporting failures without raising.

    get_params is called as get_params(prime, bitwidth) and must return an
    iterable of parameter dicts. The three expected exception types are
    printed as a one-line message; any other Exception gets a full traceback.
    """
    try:
        for params in get_params(prime, bitwidth):
            write_output(name, params)
    except (LimbPickingException, NonBase2Exception, UnexpectedPrimeException) as expected:
        print(expected)
    except Exception:
        traceback.print_exc()

USAGE = "python generate_parameters.py input_file"
if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(USAGE)
        sys.exit()
    # The input file lists one prime expression per line; the context manager
    # guarantees the handle is closed even if processing is interrupted.
    with open(sys.argv[1]) as f:
        for line in f:
            stripped = line.strip()
            # skip comments and empty lines
            if not stripped or stripped.startswith("#"):
                continue
            prime = line.split("#")[0].strip() # remove trailing comments and trailing/leading whitespace
            try_write_output("montgomery32", get_params_montgomery, prime, 32)
            try_write_output("montgomery64", get_params_montgomery, prime, 64)
            try_write_output("solinas32", get_params_solinas, prime, 32)
            try_write_output("solinas64", get_params_solinas, prime, 64)
back to top