https://github.com/ShikamaruZhang/MANN
Raw File
Tip revision: 914f5ddb06a3853f2f60a09d156702154b084cb3 authored by HE ZHANG on 27 July 2018, 20:03:31 UTC
change-index
Tip revision: 914f5dd
ExpertWeights.py
"""
Class of ExpertWeights
"""

import numpy as np
import tensorflow as tf

class ExpertWeights(object):
    """Trainable parameters for a set of experts, blended by control weights.

    ``alpha`` holds one weight matrix per expert and ``beta`` one bias column
    per expert. ``get_NNweight`` / ``get_NNbias`` combine them into per-sample
    parameters as a weighted sum over the expert axis.

    Args:
        rng: random generator exposing ``uniform(low, high, size)`` (e.g.
            ``np.random.RandomState``); used to initialise ``alpha``.
        shape: ``(num_experts, out, in)`` shape of the expert weights.
        name: prefix for the TensorFlow variable names.
    """

    def __init__(self, rng, shape, name):
        # Random generator consumed by initial_alpha_np.
        self.initialRNG = rng

        # Expert weights are (experts, out, in); biases are (experts, out, 1).
        self.weight_shape = shape
        self.bias_shape = (shape[0], shape[1], 1)

        # alpha: expert weights (uniform init); beta: expert biases (zeros).
        self.alpha = tf.Variable(self.initial_alpha(), name=name + 'alpha')
        self.beta = tf.Variable(self.initial_beta(), name=name + 'beta')

    # ---- parameter initialisation -------------------------------------

    def initial_alpha_np(self):
        """Return a float32 array of shape ``weight_shape`` for ``alpha``.

        Values are drawn from U(-b, b) with b = sqrt(6 / (out * in)). The
        bound uses the product of the last two dimensions of the shape.
        """
        shape = self.weight_shape
        alpha_bound = np.sqrt(6. / np.prod(shape[-2:]))
        samples = self.initialRNG.uniform(low=-alpha_bound, high=alpha_bound, size=shape)
        return np.asarray(samples, dtype=np.float32)

    def initial_alpha(self):
        """Return ``initial_alpha_np()`` as a float32 tensor."""
        return tf.convert_to_tensor(self.initial_alpha_np(), dtype=tf.float32)

    def initial_beta(self):
        """Return a float32 zero tensor of shape ``bias_shape``."""
        return tf.zeros(self.bias_shape, tf.float32)

    # ---- blending -----------------------------------------------------

    @staticmethod
    def _blend(params, controlweights):
        """Return sum_k controlweights[k, b] * params[k] for every sample b.

        params: (K, out, m) tensor; controlweights: (K, B) tensor.
        Result: (B, out, m) tensor.
        """
        # (K, B) -> (K, B, 1, 1): one scalar coefficient per expert and sample.
        coeff = tf.expand_dims(tf.expand_dims(controlweights, -1), -1)
        # (K, out, m) -> (K, 1, out, m); broadcasting against coeff yields
        # (K, B, out, m) without first materialising a tiled parameter copy.
        expanded = tf.expand_dims(params, 1)
        return tf.reduce_sum(coeff * expanded, axis=0)

    def get_NNweight(self, controlweights, batch_size):
        """Blend the expert weights into one weight matrix per sample.

        Args:
            controlweights: (num_experts, batch) blending coefficients.
            batch_size: kept for backward compatibility; the batch
                dimension is taken from ``controlweights`` by broadcasting.

        Returns:
            Tensor of shape (batch, out, in).
        """
        return self._blend(self.alpha, controlweights)

    def get_NNbias(self, controlweights, batch_size):
        """Blend the expert biases into one bias column per sample.

        Args:
            controlweights: (num_experts, batch) blending coefficients.
            batch_size: kept for backward compatibility; the batch
                dimension is taken from ``controlweights`` by broadcasting.

        Returns:
            Tensor of shape (batch, out, 1).
        """
        return self._blend(self.beta, controlweights)



def save_EP(alpha, beta, filename, num_experts):
    """Write every expert's parameters to separate raw binary files.

    For layer ``i`` and expert ``j``, ``alpha[i][j]`` is written to
    ``<filename>/cp<i>_a<j>.bin`` and ``beta[i][j]`` to
    ``<filename>/cp<i>_b<j>.bin`` via ``ndarray.tofile``. ``filename`` is
    used as a directory prefix, which must already exist.
    """
    for layer, layer_alpha in enumerate(alpha):
        for expert in range(num_experts):
            expert_alpha = layer_alpha[expert]
            # beta is indexed per expert (not hoisted) so it is only touched
            # when there is at least one expert to save.
            expert_beta = beta[layer][expert]
            expert_alpha.tofile(filename+'/cp%0i_a%0i.bin' % (layer, expert))
            expert_beta.tofile(filename+'/cp%0i_b%0i.bin' % (layer, expert))



    
    
    
    
"""
def regularization_penalty(alpha, gamma):
    number_alpha = len(alpha)
    penalty = 0
    for i in range(number_alpha):
        penalty += tf.reduce_mean(tf.abs(alpha[i]))
    return gamma * penalty / number_alpha
"""

back to top