# Revision 1e787978ca51b159fe3063548813fb3d41ec005f authored by mariana gama on 15 August 2023, 16:32:17 UTC, committed by mariana gama on 15 August 2023, 16:32:17 UTC
# 1 parent 2e676e8
# Raw File
# double_auction.mpc
# (C) 2017 University of Bristol. See License.txt
# (C) 2018 KU Leuven. See License.txt

#=======================================================================================#
import math

# Global configuration of the secret-shared floating point type.
sfloat.plen = 16   # Length of exponent in bits
sfloat.vlen = 40   # Length of mantissa in bits
sfloat.kappa = 40  # Statistical security parameter for floats
# Print the three float parameters set above.
print_ln("%s  %s  %s", sfloat.plen, sfloat.vlen, sfloat.kappa)

# Problem size: orders are stored as n x m matrices further down.
n = 100  # number of orders
m = 50  # number of exchange rates

# NOTE(review): liq0/liq1/rho0/rho1 and frozen0/frozen1 are not referenced
# anywhere in the visible part of this file -- confirm whether they are
# still needed.
liq0 = 100
liq1 = 200
rho0 = 0
rho1 = 0

frozen0 = MemValue(sint(0))
frozen1 = MemValue(sint(0))

# creating (not so) random bits: every entry of `rand` is set to the
# constant 1; the commented-out lines below would clear individual entries.
rand = sint.Array(n)
@for_range(n)
def loop_body(i):
    rand[i] = 1
#rand[5] = 0
#rand[6] = 0
#rand[9] = 0
#rand[8] = 0

# Load the n entries of `rand` from memory into one vector of size n.
# NOTE(review): `vrand` is not used in the visible part of this file.
vrand = sint(size=n)
vldms(n, vrand, rand.address)	

# Vector of m ones; used below to form "1 - x" on size-m vectors.
vone_m = sint(1, size=m)

print_ln("orders: %s, prices: %s", n, m)

#=======================================================================================#
# READ ORDERS FROM FILE + Input check
# Per-order storage: one row per order, one column per exchange rate.
ids = sint.Array(n)
orders_sell = sint.Matrix(n, m)
orders_buy = sint.Matrix(n, m)

# Scratch buffers, reused for every order that is read.
first_list = sint.Array(m+1)   # m per-price entries followed by one direction entry
first_list0 = sint.Array(m)    # the m per-price entries only
direction0 = sint.Array(m)     # the direction entry replicated m times
temp_sell = sint.Array(m)
temp_buy = sint.Array(m)
triples = sint.Array(m+1)      # random values used to mask the input check
result = sint.Array(m+1)       # masked check value for every input entry
test = MemValue(sint(0))       # running sum of the masked check values

#start_timer(1)
open_channel_with_return(12)
# The first two private inputs on the channel are read and never used.
# NOTE(review): presumably header fields of the input stream -- confirm.
none = sint.get_private_input_from(0,12)
none2 = sint.get_private_input_from(0,12)
@for_range(n)
def loop_body(i):
    # Reset the check accumulator for this order, then read its identifier.
    test.write(0)
    ids[i] = sint.get_private_input_from(0,12)

    # Fill triples[0..m] with random values, taking two per random triple.
    # `//` keeps the loop bound an int under both Python 2 and Python 3;
    # with plain `/` Python 3 produces a float and range() raises TypeError.
    for j in range(m//2 + m%2):
        triples[2*j], triples[j*2+1], s_btemp = sint.get_random_triple()
    # The last slot is (re)written separately so it is set for even m too.
    triples[m], s_atemp, s_btemp = sint.get_random_triple()

    # Read the m per-price entries and the trailing direction entry.
    @for_range(m)
    def loop_body(j):
        first_list[j] = sint.get_private_input_from(0,12)
    first_list[m] = sint.get_private_input_from(0,12)

    vtriples = sint(size=m+1)
    vvalues = sint(size=m+1)
    vldms(m+1, vtriples, triples.address)
    vldms(m+1, vvalues, first_list.address)

    # Input check: v*v - v is zero exactly when v is 0 or 1.  Each check
    # value is multiplied by a random value before being summed, so only
    # the masked sum is opened below.
    vresult = (vvalues*vvalues - vvalues)
    vresult = vtriples * vresult
    vresult.store_in_mem(result.address)
    @for_range(m+1)
    def loop_body(j):
        test.write(test + result[j])


    # if order is good, convert order format
    if_then(test.reveal() != 0)
    # A non-zero sum means at least one entry was not a bit: the order id is
    # revealed and the order's rows in orders_sell/orders_buy are not written.
    print_ln("Bad order - %s", ids[i].reveal())

    else_then()
    # Split the m+1 inputs into the m per-price entries and a vector that
    # repeats the direction entry m times.
    @for_range(m)
    def loop_body(j):
        first_list0[j] = first_list[j]
        direction0[j] = first_list[m]
    vlist0 = sint(size=m)
    vdir0 = sint(size=m)
    vldms(m, vlist0, first_list0.address)
    vldms(m, vdir0, direction0.address)

    # direction == 1 keeps the entries on the sell side, direction == 0 on
    # the buy side; the other side becomes all zeros.
    vsell = vlist0 * vdir0
    vbuy = vlist0 * (vone_m - vdir0)
    vsell.store_in_mem(temp_sell.address)
    vbuy.store_in_mem(temp_buy.address)

    # Copy the converted order into row i of the order matrices.
    @for_range(m)
    def loop_body(j):
       orders_sell[i][j] = temp_sell[j]
       orders_buy[i][j] = temp_buy[j]
    end_if()


close_channel(12)
#stop_timer(1)
"""
print_ln()
print_ln("Inputs sell / buy")
@for_range(n)
def loop(i):
    @for_range(m)
    def loop(j): 
        print_str("%s ", orders_sell[i][j].reveal())
        print_str("%s    ", orders_buy[i][j].reveal())
    print_ln()
print_ln()
print_ln()
"""
#=======================================================================================#
# 10k VECTOR OPS


#=======================================================================================#
# DOUBLE AUCTION
def find_price():
    """Pick one of the m price indices at random and print it.

    For every price the matched volume is min(total sell, total buy) over
    all stored orders.  Each price gets the weight sfloat(v=1.v,
    p=1.p + matched volume), the weights are turned into a running
    (cumulative) sum, a random secret float z is scaled by the total weight
    and a binary search over the cumulative weights locates the selected
    index.  Takes no arguments and returns nothing; the result is only
    printed.  Each binary-search comparison outcome is revealed.
    """

    # NOTE(review): total_sell/total_buy are accumulated into without an
    # explicit reset; this relies on freshly allocated memory reading as
    # zero -- TODO confirm.
    total_sell = sint.Array(m)
    total_buy = sint.Array(m)
    matched_vol = sint.Array(m)
    weights = sfloat.Array(m)

    # find matched volume for each price (utility)
    # First sum the sell and buy columns over all n orders.
    @for_range(n)
    def loop_body(i):
        @for_range(m)
        def loop_body(j):
            total_sell[j] = total_sell[j] + orders_sell[i][j]
            total_buy[j] = total_buy[j] + orders_buy[i][j]

    vtotal_sell = sint(size=m)
    vtotal_buy = sint(size=m)
    vldms(m, vtotal_sell, total_sell.address)	
    vldms(m, vtotal_buy, total_buy.address)	

    # vcomp is 1 where sell > buy, so vsmall_vol selects the smaller of the
    # two totals for every price.
    vcomp = (vtotal_sell > vtotal_buy)
    vsmall_vol = vcomp * vtotal_buy + (vone_m - vcomp) * vtotal_sell 
    vsmall_vol.store_in_mem(matched_vol.address)

    
    print_ln()
    #print_ln("# matched orders")
    #@for_range(m)
    #def loop(i):
    #    print_str("%s   ", matched_vol[i].reveal())
        #print_ln()
    #print_ln()
    #print_ln()
    

    # The string literals below are disabled alternative weightings; only
    # the EPS = 2*ln(2) variant is live code.

    # EPS = 4*ln(2)  define selection probability for each price according to utility and epsilon
    """
    float1 = sfloat(1)
    #print_ln("Cumulative weights eps=4*ln(2)")
    @for_range(m)
    def loop(i):
        weights[i] = sfloat(float1.v, float1.p + 2*matched_vol[i], 0, 0, err = 0, size = None)
        if_then(i>0)
        weights[i] = weights[i] + weights[i-1]
        end_if()
    """


    # EPS = 2*ln(2)  define selection probability for each price according to utility and epsilon
    #"""
    float1 = sfloat(1)
    #print_ln("Cumulative weights eps=2*ln(2)")
    @for_range(m)
    def loop(i):
        # Reuse the mantissa of 1.0 and add the matched volume to its
        # exponent (presumably 2**matched_vol[i], assuming sfloat encodes
        # v * 2**p -- confirm); the zero and sign arguments are passed as 0.
        weights[i] = sfloat(float1.v, float1.p + matched_vol[i], 0, 0, err = 0, size = None)
        # Turn the weights into a running sum; index 0 has no predecessor.
        if_then(i>0)
        weights[i] = weights[i] + weights[i-1]
        end_if()
    #"""

    # EPS = ln(2)  define selection probability for each price according to utility and epsilon
    """
    d = 1
    float1 = sfloat(1)
    clear1 = sfloat(1)
    root = sfloat(1.41421356)
    #print_ln("Cumulative weights eps=ln(2) (d=1)")
    @for_range(m)
    def loop(i):
        aprime = sint()
        bprime = sint()
        AdvInteger.Trunc(aprime, matched_vol[i], 32, 1, 40, True)
        AdvInteger.Mod2(bprime, matched_vol[i], 32, 40, True)
        power = sfloat(float1.v, float1.p + aprime, 0, 0, err = 0, size = None)
        choice = clear1 + (root-clear1) * bprime
        weights[i] = power * choice
        if_then(i>0)
        weights[i] = weights[i] + weights[i-1]
        end_if()
    """ 

    # EPS = ln(2)/2  define selection probability for each price according to utility and epsilon
    """
    d = 2
    float1 = sfloat(1)
    clear1 = sfloat(1)
    root = sfloat(1.41421356)
    root3 = sfloat(1.25992105)
    root4 = sfloat(1.18920712) 
    #print_ln("Cumulative weights eps=ln(2)/2 (d=2)")
    @for_range(m)
    def loop(i):
        aprime = sint()
        bprime = sint()
        AdvInteger.Trunc(aprime, matched_vol[i], 32, 2, 40, True)
        AdvInteger.Mod2m(bprime, matched_vol[i], 32, 2, 40, True)
        power = sfloat(float1.v, float1.p + aprime, 0, 0, err = 0, size = None)
        choice = clear1*(bprime==0) + (root)*(bprime==1) + (root3)*(bprime==2) + (root4)*(bprime==3)
        weights[i] = power * choice
        if_then(i>0)
        weights[i] = weights[i] + weights[i-1]
        end_if()
    """ 




    # sample random number z
    #print_ln("Sampling z: ")
    bits = 80

    # Draw a secret random integer of `bits` bits and convert it to the
    # four components of a secret float.
    # NOTE(review): the literal 81 looks like bits + 1 -- confirm and tie
    # the two together.
    randint = get_random_int(bits)
    a1,a2,a3,a4 = floatingpoint.Int2FL(randint, 81, sfloat.vlen, sfloat.kappa)
    randfloat = sfloat(a1, a2, a3, a4, err = 0, size = None)
    
    # Same mantissa with the exponent lowered by `bits`.
    # NOTE(review): the zero and sign arguments are forced to 0 instead of
    # reusing randfloat's -- confirm this is acceptable when randint is 0.
    z = sfloat(randfloat.v, randfloat.p-sint(bits), 0, 0, err = 0, size =None)
    # Scale z by the last cumulative weight, i.e. the sum of all weights.
    z_norm = z * weights[m-1]


    # sample price
    #print_ln("Select weight: ")
    # Binary search over the cumulative weights with clear (public) bounds.
    low = MemValue(cint(0))
    high = MemValue(cint(m-1))
    index = MemValue(cint(0))
    @do_while
    def loop():
        middle = low + high
        # The parity is subtracted before halving so the numerator is even.
        # NOTE(review): presumably because `/` on cint is not a floor
        # division -- confirm.
        index.write((middle - middle%2)/2)
        # The comparison result is opened; it decides which half is kept.
        c = (weights[index] < z_norm).reveal()
        if_then(c==1)
        low.write(index+1)
        else_then()
        high.write(index)
        end_if()
        # Keep iterating while the search interval is non-empty.
        return (low < high) == 1

    # `low` now holds the selected index into `weights`.
    print_ln()
    print_ln("Chosen weight = %s", low)
    print_ln()

#=======================================================================================#
# RUN EXPERIMENTS
"""
print_ln()
print_ln("Matrix before")
@for_range(n)
def loop(i):
    @for_range(p)
    def loop(j):
        print_str("%s ", orders_matrix[i][j].reveal())
    print_ln()
print_ln()
print_ln()
"""



# Run the price selection ten times, bracketing each call with timer 2.
# This is a plain Python loop rather than @for_range.
# NOTE(review): presumably find_price() is therefore emitted ten times into
# the compiled program -- confirm.
for i in range(10):
    start_timer(2)
    find_price()
    stop_timer(2)


"""
print_ln()
print_ln("Ouputs sell / buy")
@for_range(n)
def loop(i):
    print_str("%s   ", total_outsell[i].reveal())
    print_str("%s ", total_outbuy[i].reveal())
    print_ln()
print_ln()
print_ln()
"""






# back to top