import os
import scipy as sp
from scipy import linalg
import h5py
import time
import sys
from ldpred import reporting
from ldpred import plinkfiles
from ldpred import util

def get_prs(genotype_file, rs_id_map, phen_map=None, only_score=False, verbose=False):
    """Calculate a polygenic risk score (PRS) for the individuals in a PLINK file.

    Parameters
    ----------
    genotype_file : str
        Prefix of the PLINK bed/bim/fam file set, passed to ``PlinkFile``.
    rs_id_map : dict
        Maps SNP id -> dict with at least ``'nts'`` (the two weight-file
        nucleotides) and ``'upd_pval_beta'`` (the SNP weight).
    phen_map : dict or None
        Maps individual id -> dict with ``'phen'`` and optionally ``'pcs'``,
        ``'sex'`` and ``'covariates'``.  When ``None`` (and ``only_score`` is
        false) the phenotype is taken from ``sample.affection`` and samples
        whose affection equals 2 are skipped.
    only_score : bool
        When true, no phenotypes are used and every sample is scored.
    verbose : bool
        Print progress details.

    Returns
    -------
    dict
        Always contains ``'pval_derived_effects_prs'``, ``'iids'``,
        ``'num_snps'``, ``'num_non_matching_nts'``, ``'num_flipped_nts'``,
        ``'perc_missing'`` and ``'duplicated_snps'``.  When ``only_score`` is
        false it also contains ``'true_phens'``, ``'prs_r2'``, ``'prs_corr'``
        and, when available, ``'pcs'``, ``'sex'`` and ``'covariates'``.

    Raises
    ------
    Warning
        If more than 1% of the SNPs in ``rs_id_map`` were not found in the
        genotype file.
    AssertionError
        If no individuals are left, phenotypes contain NaNs, or covariate
        information is only available for some individuals.
    """
    plinkf = plinkfiles.plinkfile.PlinkFile(genotype_file)
    samples = plinkf.get_samples()

    # Defined up front so that the result assembly below can always inspect
    # them, also when no phenotype map was supplied.
    true_phens = []
    pcs = []
    sex = []
    covariates = []
    indiv_filter = None

    if not only_score:
        # 1. Figure out indiv filter and get true phenotypes
        indiv_filter = sp.zeros(len(samples), dtype=bool)
        iids = []
        if phen_map is not None:
            for samp_i, sample in enumerate(samples):
                if sample.iid not in phen_map:
                    continue
                sample_info = phen_map[sample.iid]
                indiv_filter[samp_i] = True
                true_phens.append(sample_info['phen'])
                iids.append(sample.iid)
                if 'pcs' in sample_info:
                    pcs.append(sample_info['pcs'])
                if 'sex' in sample_info:
                    sex.append(sample_info['sex'])
                if 'covariates' in sample_info:
                    covariates.append(sample_info['covariates'])
            if len(pcs) > 0:
                assert len(pcs) == len(
                    true_phens), 'PC information missing for some individuals with phenotypes'
            if len(sex) > 0:
                assert len(sex) == len(
                    true_phens), 'Sex information missing for some individuals with phenotypes'
            if len(covariates) > 0:
                assert len(covariates) == len(
                    true_phens), 'Covariates missing for some individuals with phenotypes'
        else:
            # No phenotype map: fall back on the affection status stored with
            # the samples.
            for samp_i, sample in enumerate(samples):
                if sample.affection != 2:
                    indiv_filter[samp_i] = True
                    true_phens.append(sample.affection)
                    iids.append(sample.iid)

        num_individs = sp.sum(indiv_filter)
        assert num_individs > 0, 'Issues in parsing the phenotypes and/or PCs?'
        assert not sp.any(sp.isnan(
            true_phens)), 'Phenotypes appear to have some NaNs, or parsing failed.'
        if verbose:
            print('%d individuals have phenotype and genotype information.' % num_individs)
    else:
        iids = [sample.iid for sample in samples]
        num_individs = len(samples)

    num_non_matching_nts = 0
    num_flipped_nts = 0

    pval_derived_effects_prs = sp.zeros(num_individs)
    # If these indices are not in order then we place them in the right place
    # while parsing SNPs.
    if verbose:
        print('Iterating over BED file to calculate risk scores.')
    locus_list = plinkf.get_loci()
    num_loci = len(locus_list)
    loci_i = 0
    snp_i = 0
    duplicated_snps = 0
    found_loci = set()
    for locus, row in zip(locus_list, plinkf):
        loci_i += 1
        if loci_i % 100 == 0:
            sys.stdout.write('\r%0.2f%%' % (100.0 * (float(loci_i) / (num_loci - 1.0))))
            sys.stdout.flush()

        sid = locus.name
        if sid not in rs_id_map:
            continue
        if sid in found_loci:
            # Only the first occurrence of a SNP id contributes to the score.
            duplicated_snps += 1
            continue
        found_loci.add(sid)
        rs_info = rs_id_map[sid]
        if rs_info['upd_pval_beta'] == 0:
            continue

        # Check whether the nucleotides are OK, and potentially flip it.
        ss_nt = rs_info['nts']
        g_nt = [locus.allele1, locus.allele2]
        os_g_nt = sp.array(
            [util.opp_strand_dict[g_nt[0]], util.opp_strand_dict[g_nt[1]]])
        if not (sp.all(g_nt == ss_nt) or sp.all(os_g_nt == ss_nt)):
            # Opposite strand nucleotides
            flip_nts = (g_nt[1] == ss_nt[0] and g_nt[0] == ss_nt[1]) or (
                os_g_nt[1] == ss_nt[0] and os_g_nt[0] == ss_nt[1])
            if flip_nts:
                upd_pval_beta = -rs_info['upd_pval_beta']
                num_flipped_nts += 1
            else:
                # Alleles cannot be reconciled with the weights: skip the SNP.
                num_non_matching_nts += 1
                continue
        else:
            upd_pval_beta = rs_info['upd_pval_beta']

        # Parse SNP, and fill in the blanks if necessary.
        if only_score:
            snp = sp.array(row, dtype='int8')
        else:
            snp = sp.array(row, dtype='int8')[indiv_filter]
        bin_counts = row.allele_counts()
        if bin_counts[-1] > 0:
            # Genotypes coded 3 are replaced by the more common of the first
            # two genotype classes.
            mode_v = sp.argmax(bin_counts[:2])
            snp[snp == 3] = mode_v

        # Update scores and move on.
        pval_derived_effects_prs += -1 * snp * upd_pval_beta
        assert not sp.any(sp.isnan(pval_derived_effects_prs)
                          ), 'Some individual weighted effects risk scores are NANs (not a number).  They are corrupted.'

        snp_i += 1

    # NOTE(review): close() is assumed to exist on PlinkFile (plinkio API) -
    # confirm against ldpred.plinkfiles.
    plinkf.close()

    perc_missing = 1 - float(len(found_loci)) / float(len(rs_id_map))
    if perc_missing > 0.01:
        raise Warning('More than %0.2f %% of variants for which weights were calculated were not found in '
                      'validation data.  This can lead to poor prediction accuracies.  Please consider '
                      'using the --vbim flag in the coord step to identify overlapping SNPs.'
                      % (perc_missing * 100))
    if not verbose:
        sys.stdout.write('\r%0.2f%%\n' % (100.0))
        sys.stdout.flush()

    if verbose:
        print('Number of non-matching NTs: %d' % num_non_matching_nts)
        print('Number of flipped NTs: %d' % num_flipped_nts)

    if not only_score:
        pval_eff_corr = sp.corrcoef(pval_derived_effects_prs, true_phens)[0, 1]
        pval_eff_r2 = pval_eff_corr ** 2

        if verbose:
            print('Current PRS correlation: %0.4f' % pval_eff_corr)
            print('Current PRS r2: %0.4f' % pval_eff_r2)
        ret_dict = {'pval_derived_effects_prs': pval_derived_effects_prs,
                    'true_phens': true_phens[:], 'iids': iids, 'prs_r2': pval_eff_r2, 'prs_corr': pval_eff_corr,
                    'num_snps': snp_i, 'num_non_matching_nts': num_non_matching_nts,
                    'num_flipped_nts': num_flipped_nts, 'perc_missing': perc_missing,
                    'duplicated_snps': duplicated_snps}
        if len(pcs) > 0:
            ret_dict['pcs'] = pcs
        if len(sex) > 0:
            ret_dict['sex'] = sex
        if len(covariates) > 0:
            ret_dict['covariates'] = covariates
    else:
        ret_dict = {'pval_derived_effects_prs': pval_derived_effects_prs, 'iids': iids,
                    'num_snps': snp_i, 'num_non_matching_nts': num_non_matching_nts,
                    'num_flipped_nts': num_flipped_nts, 'perc_missing': perc_missing,
                    'duplicated_snps': duplicated_snps}

    return ret_dict

def parse_phen_file(pf, pf_format, verbose=False, summary_dict=None):
    """Parse a phenotype file into a map from individual id to phenotype info.

    Parameters
    ----------
    pf : str or None
        Path of the phenotype file.  When ``None`` nothing is parsed and an
        empty map is returned.
    pf_format : str
        One of ``'FAM'``, ``'STANDARD'``, ``'LSTANDARD'`` or ``'S2'``; any
        other value leaves the map empty.
    verbose : bool
        Print progress details.
    summary_dict : dict or None
        When given, entry ``1`` is set to a description of the parsed file.

    Returns
    -------
    dict
        ``iid -> {'phen': float, ...}``; FAM entries also carry ``'sex'`` and
        S2 entries carry ``'age'`` and ``'sex'``.

    Raises
    ------
    Exception
        For S2 files when the sex column is neither ``Male`` nor ``Female``.
    """
    phen_map = {}
    num_phens_found = 0
    summary_name = None
    if pf is not None:
        if verbose:
            print('Parsing Phenotypes')
        if pf_format == 'FAM':
            # Columns of a PLINK fam file:
            #   Individual's family ID ('FID')
            #   Individual's within-family ID ('IID'; cannot be '0')
            #   Within-family ID of father ('0' if father isn't in dataset)
            #   Within-family ID of mother ('0' if mother isn't in dataset)
            #   Sex code ('1' = male, '2' = female, '0' = unknown)
            #   Phenotype value ('1' = control, '2' = case, '-9'/'0'/non-numeric = missing data if case/control)
            with open(pf, 'r') as f:
                for line in f:
                    l = line.split()
                    iid = l[1]
                    sex = int(l[4])
                    phen = float(l[5])
                    # Individuals with unknown sex or missing phenotype are skipped.
                    if sex != 0 and phen != -9:
                        phen_map[iid] = {'phen': phen, 'sex': sex}
                        num_phens_found += 1
            summary_name = 'Phenotype file (plink format):'

        elif pf_format == 'STANDARD':
            # Columns: IID   PHE
            with open(pf, 'r') as f:
                for line in f:
                    l = line.split()
                    iid = l[0]
                    phen = float(l[1])
                    phen_map[iid] = {'phen': phen}
                    num_phens_found += 1
            summary_name = 'Phenotype file (STANDARD format):'

        elif pf_format == 'LSTANDARD':
            # Columns: FID IID PHE
            with open(pf, 'r') as f:
                # Read first line and see if it is header
                line = f.readline()
                l = line.split()
                try:
                    iid = l[1]
                    phen = float(l[2])
                except (IndexError, ValueError):
                    # Do nothing: the first line is a header (or empty).
                    pass
                else:
                    phen_map[iid] = {'phen': phen}
                    num_phens_found += 1
                for line in f:
                    l = line.split()
                    iid = l[1]
                    phen = float(l[2])
                    phen_map[iid] = {'phen': phen}
                    num_phens_found += 1
            summary_name = 'Phenotype file (LSTANDARD format):'

        elif pf_format == 'S2':
            # Columns: IID Age Sex Height_Inches
            with open(pf, 'r') as f:
                for line in f:
                    l = line.split()
                    iid = l[0]
                    age = float(l[1])
                    if l[2] == 'Male':
                        sex = 1
                    elif l[2] == 'Female':
                        sex = 2
                    else:
                        raise Exception('Sex missing')
                    phen = float(l[3])
                    phen_map[iid] = {'phen': phen, 'age': age, 'sex': sex}
                    num_phens_found += 1
            summary_name = 'Phenotype file (S2 format):'

    # summary_dict defaults to None, so only record the file when one was given.
    if summary_name is not None and summary_dict is not None:
        summary_dict[1] = {'name': summary_name, 'value': pf}

    print("Parsed %d phenotypes successfully"%num_phens_found)
    return phen_map

def parse_ldpred_res(file_name):
    """Parse an LDpred weights file into a map keyed by SNP id.

    The file is whitespace separated with one header row, e.g.::

        chrom    pos    sid    nt1    nt2    raw_beta    ldpred_inf_beta    ldpred_beta

    The first six characters of the chromosome field are dropped before it
    is converted to an integer (presumably a ``chrom_`` prefix - verify
    against the code that writes these files).  The weight is read from the
    seventh column.

    Parameters
    ----------
    file_name : str
        Path of the weights file.

    Returns
    -------
    dict
        ``sid -> {'chrom', 'pos', 'nts', 'raw_beta', 'upd_pval_beta'}``.
    """
    rs_id_map = {}
    with open(file_name, 'r') as f:
        # Skip the header row; the default avoids StopIteration on an empty file.
        next(f, None)
        for line in f:
            l = line.split()
            if not l:
                # Ignore blank lines (e.g. a trailing newline).
                continue
            chrom_str = l[0]
            chrom = int(chrom_str[6:])
            pos = int(l[1])
            rs_id = l[2].strip()
            nt1 = l[3].strip()
            nt2 = l[4].strip()
            nts = [nt1, nt2]
            raw_beta = float(l[5])
            upd_pval_beta = float(l[6])
            rs_id_map[rs_id] = {'chrom': chrom, 'pos': pos, 'nts': nts, 'raw_beta': raw_beta,
                                'upd_pval_beta': upd_pval_beta}
    return rs_id_map

def parse_pt_res(file_name):
    """Parse a P+T (pruning + thresholding) weights file.

    The file is whitespace separated with one header row, e.g.::

        chrom    pos    sid    nt1    nt2    raw_beta    raw_pval_beta    upd_beta    upd_pval_beta

    The first six characters of the chromosome field are dropped before it
    is converted to an integer (presumably a ``chrom_`` prefix - verify
    against the code that writes these files).  The weight is read from the
    ninth column, and SNPs whose ``raw_beta`` is zero are left out.

    Parameters
    ----------
    file_name : str
        Path of the weights file.

    Returns
    -------
    tuple
        ``(rs_id_map, non_zero_chromosomes)`` where ``rs_id_map`` maps
        ``sid -> {'chrom', 'pos', 'nts', 'raw_beta', 'upd_pval_beta'}`` and
        ``non_zero_chromosomes`` is the set of chromosomes with at least one
        retained SNP.
    """
    non_zero_chromosomes = set()
    rs_id_map = {}
    with open(file_name, 'r') as f:
        # Skip the header row; the default avoids StopIteration on an empty file.
        next(f, None)
        for line in f:
            l = line.split()
            if not l:
                # Ignore blank lines (e.g. a trailing newline).
                continue
            chrom_str = l[0]
            chrom = int(chrom_str[6:])
            pos = int(l[1])
            rs_id = l[2].strip()
            nt1 = l[3].strip()
            nt2 = l[4].strip()
            nts = [nt1, nt2]
            raw_beta = float(l[5])
            upd_pval_beta = float(l[8])
            if raw_beta != 0:
                rs_id_map[rs_id] = {'chrom': chrom, 'pos': pos, 'nts': nts, 'raw_beta': raw_beta,
                                    'upd_pval_beta': upd_pval_beta}
                non_zero_chromosomes.add(chrom)

    return rs_id_map, non_zero_chromosomes

def write_scores_file(out_file, prs_dict, pval_derived_effects_prs, adj_pred_dict,
                      output_regression_betas=False, weights_dict=None, verbose=False):
    """Write polygenic scores, and optionally adjusted scores and weights, to disk.

    Parameters
    ----------
    out_file : str
        Path of the scores file.  Adjusted scores go to ``out_file + '.adj'``
        and regression weights to ``out_file + '.weights.hdf5'``.
    prs_dict : dict
        Must contain ``'iids'`` and ``'true_phens'``; ``'sex'`` and ``'pcs'``
        are written as extra columns when present.
    pval_derived_effects_prs : sequence
        One score per individual, in the order of ``prs_dict['iids']``.
    adj_pred_dict : dict
        Maps label -> per-individual adjusted predictions.  The ``.adj`` file
        is only written when this is non-empty.
    output_regression_betas : bool
        When true (and ``weights_dict`` is given) write the weights as HDF5.
    weights_dict : dict or None
        Maps group name -> {dataset name -> value}.
    verbose : bool
        Print progress details.
    """
    num_individs = len(prs_dict['iids'])
    with open(out_file, 'w') as f:
        if verbose:
            print ('Writing polygenic scores to file %s'%out_file)
        out_str = 'IID, true_phens, PRS'
        if 'sex' in prs_dict:
            out_str = out_str + ', sex'
        if 'pcs' in prs_dict:
            pcs_str = ', '.join(['PC%d' % (1 + pc_i)
                                 for pc_i in range(len(prs_dict['pcs'][0]))])
            out_str = out_str + ', ' + pcs_str
        out_str += '\n'
        f.write(out_str)
        for i in range(num_individs):
            out_str = '%s, %0.6e, %0.6e' % (prs_dict['iids'][i], prs_dict['true_phens'][i],
                                            pval_derived_effects_prs[i])
            if 'sex' in prs_dict:
                out_str = out_str + ', %d' % prs_dict['sex'][i]
            if 'pcs' in prs_dict:
                pcs_str = ', '.join(map(str, prs_dict['pcs'][i]))
                out_str = out_str + ', ' + pcs_str
            out_str += '\n'
            f.write(out_str)

    if len(adj_pred_dict) > 0:
        with open(out_file + '.adj', 'w') as f:
            adj_prs_labels = list(adj_pred_dict.keys())
            out_str = 'IID, true_phens, PRS, ' + \
                ', '.join(adj_prs_labels)
            out_str += '\n'
            f.write(out_str)
            for i in range(num_individs):
                out_str = '%s, %0.6e, %0.6e' % (prs_dict['iids'][i], prs_dict['true_phens'][i],
                                                pval_derived_effects_prs[i])
                for adj_prs in adj_prs_labels:
                    out_str += ', %0.4f' % adj_pred_dict[adj_prs][i]
                out_str += '\n'
                f.write(out_str)

    if output_regression_betas and weights_dict is not None:
        hdf5file = out_file + '.weights.hdf5'
        if verbose:
            print ('Writing PRS regression weights to file %s'%hdf5file)
        # The context manager guarantees the HDF5 file is flushed and closed.
        with h5py.File(hdf5file, 'w') as oh5f:
            for k1 in weights_dict:
                kg = oh5f.create_group(k1)
                for k2 in weights_dict[k1]:
                    kg.create_dataset(k2, data=sp.array(weights_dict[k1][k2]))

def write_only_scores_file(out_file, prs_dict, pval_derived_effects_prs):
    """Write individual ids and their polygenic scores (no phenotypes) to a file.

    Parameters
    ----------
    out_file : str
        Path of the output file; it is overwritten.
    prs_dict : dict
        Must contain ``'iids'``, the individual ids in output order.
    pval_derived_effects_prs : sequence
        One score per individual, in the order of ``prs_dict['iids']``.
    """
    with open(out_file, 'w') as f:
        print ('Writing polygenic scores to file %s'%out_file)
        f.write('IID, PRS \n')
        for i, iid in enumerate(prs_dict['iids']):
            f.write('%s, %0.6e\n' % (iid, pval_derived_effects_prs[i]))

def calc_risk_scores(bed_file, rs_id_map, phen_map, out_file=None,
                     split_by_chrom=False, adjust_for_sex=False,
                     adjust_for_covariates=False, adjust_for_pcs=False,
                     non_zero_chromosomes=None, only_score=False,
                     verbose=False, summary_dict=None):
    """Calculate polygenic scores, evaluate them and write them to file.

    Parameters
    ----------
    bed_file : str
        Prefix of the PLINK genotype file(s).  With ``split_by_chrom`` the
        per-chromosome prefix ``bed_file + '_<chrom>_keep'`` is used.
    rs_id_map : dict
        SNP weights, as consumed by ``get_prs``.
    phen_map : dict or None
        Phenotype information per individual, as consumed by ``get_prs``.
    out_file : str or None
        Where to write the scores.
    split_by_chrom : bool
        Accumulate scores over chromosomes 1-22 from separate files.
    adjust_for_sex, adjust_for_covariates, adjust_for_pcs : bool
        Additionally fit models that include these terms, when the
        corresponding data is present in the ``get_prs`` result.
    non_zero_chromosomes : set or None
        With ``split_by_chrom``, restrict scoring to these chromosomes.
    only_score : bool
        Only write scores; no phenotype based evaluation.
    verbose : bool
        Print progress details.
    summary_dict : dict or None
        Accepted for interface compatibility; not used here.

    Returns
    -------
    dict
        SNP bookkeeping copied from ``get_prs`` plus ``'pred_r2'`` and
        ``'corr_r2'`` (both 0 unless a prediction could be evaluated) and one
        entry per fitted adjusted model.
    """
    if verbose:
        print('Parsing PLINK bed file: %s' % bed_file)

    if split_by_chrom:
        num_individs = len(phen_map)
        assert num_individs > 0, 'No individuals found.  Problems parsing the phenotype file?'
        pval_derived_effects_prs = sp.zeros(num_individs)
        prs_dict = None

        for i in range(1, 23):
            if non_zero_chromosomes is None or i in non_zero_chromosomes:
                genotype_file = bed_file + '_%i_keep' % i
                if os.path.isfile(genotype_file + '.bed'):
                    if verbose:
                        print('Working on chromosome %d' % i)
                    prs_dict = get_prs(genotype_file, rs_id_map, phen_map, only_score=only_score, verbose=verbose)

                    pval_derived_effects_prs += prs_dict['pval_derived_effects_prs']
            elif verbose:
                print('Skipping chromosome')

        if prs_dict is None:
            # Without this check the code below would fail on an unbound name.
            raise Exception('No per-chromosome genotype files were found for prefix %s' % bed_file)
    else:
        prs_dict = get_prs(bed_file, rs_id_map, phen_map, only_score=only_score, verbose=verbose)
        num_individs = len(prs_dict['iids'])
        pval_derived_effects_prs = prs_dict['pval_derived_effects_prs']

    res_dict = {'num_snps':prs_dict['num_snps'], 'num_non_matching_nts':prs_dict['num_non_matching_nts'],
                'num_flipped_nts':prs_dict['num_flipped_nts'], 'perc_missing':prs_dict['perc_missing'],
                'duplicated_snps':prs_dict['duplicated_snps'], 'pred_r2': 0, 'corr_r2':0}

    if only_score:
        write_only_scores_file(out_file, prs_dict, pval_derived_effects_prs)
    elif sp.std(prs_dict['true_phens'])==0:
        if verbose:
            print('No variance left to explain in phenotype.')
    else:
        # Report prediction accuracy
        assert len(phen_map) > 0, 'No individuals found.  Problems parsing the phenotype file?'
        # Store covariate weights, slope, etc.
        weights_dict = {}
        # Store Adjusted predictions
        adj_pred_dict = {}

        #If there is no prediction, then output 0s.
        if sp.std(pval_derived_effects_prs)==0:
            weights_dict['unadjusted'] = {'Intercept': 0, 'ldpred_prs_effect': 0}
        else:
            pval_eff_corr = sp.corrcoef(pval_derived_effects_prs, prs_dict['true_phens'])[0, 1]
            pval_eff_r2 = pval_eff_corr ** 2
            res_dict['pred_r2'] = pval_eff_r2
            res_dict['corr_r2'] = pval_eff_corr

            # Column vectors are needed for the least-squares fits below.
            pval_derived_effects_prs.shape = (len(pval_derived_effects_prs), 1)
            true_phens = sp.array(prs_dict['true_phens'])
            true_phens.shape = (len(true_phens), 1)
            intercept = sp.ones((len(true_phens), 1))

            # Direct effect
            Xs = sp.hstack([pval_derived_effects_prs, intercept])
            # rss00 is the residual sum of squares of the intercept-only model.
            (betas, rss00, r, s) = linalg.lstsq(intercept, true_phens)
            (betas, rss, r, s) = linalg.lstsq(Xs, true_phens)
            pred_r2 = 1 - rss / rss00
            weights_dict['unadjusted'] = {
                'Intercept': betas[1][0], 'ldpred_prs_effect': betas[0][0]}
            if verbose:
                print('PRS trait correlation: %0.4f  R2: %0.4f' % (pval_eff_corr,pred_r2))

            # Adjust for sex
            if adjust_for_sex and 'sex' in prs_dict and len(prs_dict['sex']) > 0:
                sex = sp.array(prs_dict['sex'])
                sex.shape = (len(sex), 1)
                (betas, rss0, r, s) = linalg.lstsq(
                    sp.hstack([sex, intercept]), true_phens)
                Xs = sp.hstack([pval_derived_effects_prs, sex, intercept])
                (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                weights_dict['sex_adj'] = {
                    'Intercept': betas[2][0], 'ldpred_prs_effect': betas[0][0], 'sex': betas[1][0]}
                if verbose:
                    print('Fitted effects (betas) for PRS, sex, and intercept on true phenotype:', betas)
                adj_pred_dict['sex_prs'] = sp.dot(Xs, betas)
                pred_r2 = 1 - rss_pd / rss0
                print('Variance explained (Pearson R2) by PRS adjusted for Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['Sex_adj_pred_r2'] = pred_r2
                pred_r2 = 1 - rss_pd / rss00
                print('Variance explained (Pearson R2) by PRS + Sex : %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['Sex_adj_pred_r2+Sex'] = pred_r2

            # Adjust for PCs
            if adjust_for_pcs and 'pcs' in prs_dict and len(prs_dict['pcs']) > 0:
                pcs = prs_dict['pcs']
                (betas, rss0, r, s) = linalg.lstsq(
                    sp.hstack([pcs, intercept]), true_phens)
                Xs = sp.hstack([pval_derived_effects_prs, intercept, pcs])
                (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                weights_dict['pc_adj'] = {
                    'Intercept': betas[1][0], 'ldpred_prs_effect': betas[0][0], 'pcs': betas[2][0]}
                adj_pred_dict['pc_prs'] = sp.dot(Xs, betas)
                pred_r2 = 1 - rss_pd / rss0
                print('Variance explained (Pearson R2) by PRS adjusted for PCs: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['PC_adj_pred_r2'] = pred_r2
                pred_r2 = 1 - rss_pd / rss00
                print('Variance explained (Pearson R2) by PRS + PCs: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['PC_adj_pred_r2+PC'] = pred_r2

                # Adjust for both PCs and Sex
                if adjust_for_sex and 'sex' in prs_dict and len(prs_dict['sex']) > 0:
                    sex = sp.array(prs_dict['sex'])
                    sex.shape = (len(sex), 1)
                    (betas, rss0, r, s) = linalg.lstsq(
                        sp.hstack([sex, pcs, intercept]), true_phens)
                    Xs = sp.hstack([pval_derived_effects_prs, sex, intercept, pcs])
                    (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                    weights_dict['sex_pc_adj'] = {
                        'Intercept': betas[2][0], 'ldpred_prs_effect': betas[0][0], 'sex': betas[1][0], 'pcs': betas[3][0]}
                    adj_pred_dict['sex_pc_prs'] = sp.dot(Xs, betas)
                    pred_r2 = 1 - rss_pd / rss0
                    print('Variance explained (Pearson R2) by PRS adjusted for PCs and Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                    res_dict['PC_Sex_adj_pred_r2'] = pred_r2
                    pred_r2 = 1 - rss_pd / rss00
                    print('Variance explained (Pearson R2) by PRS+PCs+Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                    res_dict['PC_Sex_adj_pred_r2+PC_Sex'] = pred_r2

            # Adjust for covariates
            if adjust_for_covariates and 'covariates' in prs_dict and len(prs_dict['covariates']) > 0:
                covariates = prs_dict['covariates']
                (betas, rss0, r, s) = linalg.lstsq(
                    sp.hstack([covariates, intercept]), true_phens)
                Xs = sp.hstack([pval_derived_effects_prs, covariates, intercept])
                (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                adj_pred_dict['cov_prs'] = sp.dot(Xs, betas)
                pred_r2 = 1 - rss_pd / rss0
                print('Variance explained (Pearson R2) by PRS adjusted for Covariates: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['Cov_adj_pred_r2'] = pred_r2
                pred_r2 = 1 - rss_pd / rss00
                print('Variance explained (Pearson R2) by PRS + Cov: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['Cov_adj_pred_r2+Cov'] = pred_r2

                if adjust_for_pcs and 'pcs' in prs_dict and len(prs_dict['pcs']) and 'sex' in prs_dict and len(prs_dict['sex']) > 0:
                    pcs = prs_dict['pcs']
                    sex = sp.array(prs_dict['sex'])
                    sex.shape = (len(sex), 1)
                    (betas, rss0, r, s) = linalg.lstsq(
                        sp.hstack([covariates, sex, pcs, intercept]), true_phens)
                    Xs = sp.hstack([pval_derived_effects_prs, covariates,
                                    sex, pcs, intercept])
                    (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                    adj_pred_dict['cov_sex_pc_prs'] = sp.dot(Xs, betas)
                    pred_r2 = 1 - rss_pd / rss0
                    print('Variance explained (Pearson R2) by PRS adjusted for Cov+PCs+Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                    res_dict['Cov_PC_Sex_adj_pred_r2'] = pred_r2
                    pred_r2 = 1 - rss_pd / rss00
                    print('Variance explained (Pearson R2) by PRS+Cov+PCs+Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                    res_dict['Cov_PC_Sex_adj_pred_r2+Cov_PC_Sex'] = pred_r2

            # Now calibration
            y_norm = (true_phens - sp.mean(true_phens)) / sp.std(true_phens)
            denominator = sp.dot(pval_derived_effects_prs.T, pval_derived_effects_prs)
            numerator = sp.dot(pval_derived_effects_prs.T, y_norm)
            regression_slope = (numerator / denominator)[0][0]
            if verbose:
                print('The slope for predictions with weighted effects is: %0.4f'% regression_slope)

        # Write PRS out to file.
        if out_file is not None:
            write_scores_file(out_file, prs_dict, pval_derived_effects_prs, adj_pred_dict,
                              weights_dict=weights_dict, verbose=verbose)

    return res_dict

def parse_covariates(p_dict,phen_map,summary_dict,verbose):
    """Read covariates from ``p_dict['cov_file']`` into ``phen_map``.

    Each line holds an individual id followed by numeric covariates.  For
    ids present in ``phen_map`` the values are stored under ``'covariates'``;
    other ids are counted as missing.

    Parameters
    ----------
    p_dict : dict
        Must contain ``'cov_file'``, the path of the covariates file.
    phen_map : dict
        Individual id -> info dict; updated in place.
    summary_dict : dict
        Receives entry ``2`` (file parsed) and, if any ids were not found,
        entry ``2.1`` with their count.
    verbose : bool
        Print the number of ids that were not found.
    """
    with open(p_dict['cov_file'], 'r') as f:
        num_missing = 0
        for line in f:
            l = line.split()
            if not l:
                # Ignore blank lines (e.g. a trailing newline).
                continue
            iid = l[0]
            if iid in phen_map:
                covariates = list(map(float, l[1:]))
                phen_map[iid]['covariates'] = covariates
            else:
                num_missing += 1
        if num_missing > 0:
            summary_dict[2.1]={'name':'Individuals w missing covariate information:','value':num_missing}
            if verbose:
                print('Unable to find %d iids in phen file!' % num_missing)
    summary_dict[2]={'name':'Parsed covariates file:','value':p_dict['cov_file']}

def parse_pcs(p_dict,phen_map,summary_dict,verbose):
    """Read principal components from ``p_dict['pcs_file']`` into ``phen_map``.

    Each line holds a leading column (ignored), the individual id, and then
    the numeric PCs.  For ids present in ``phen_map`` the values are stored
    under ``'pcs'``; other ids are counted as missing.

    Parameters
    ----------
    p_dict : dict
        Must contain ``'pcs_file'``, the path of the PCs file.
    phen_map : dict
        Individual id -> info dict; updated in place.
    summary_dict : dict
        Receives entry ``3`` (file parsed) and, if any ids were not found,
        entry ``3.1`` with their count.
    verbose : bool
        Print the number of ids that were not found.
    """
    with open(p_dict['pcs_file'], 'r') as f:
        num_missing = 0
        for line in f:
            l = line.split()
            if not l:
                # Ignore blank lines (e.g. a trailing newline).
                continue
            iid = l[1]
            if iid in phen_map:
                pcs = list(map(float, l[2:]))
                phen_map[iid]['pcs'] = pcs
            else:
                num_missing += 1
        if num_missing > 0:
            summary_dict[3.1]={'name':'Individuals w missing PCs:','value':num_missing}
            if verbose:
                print('Unable to find %d iids in phen file!' % num_missing)
    summary_dict[3]={'name':'Parsed PCs file:','value':p_dict['pcs_file']}

def main(p_dict):
    """Run the scoring step: parse inputs, score every weights file found, report.

    ``p_dict`` holds the run parameters.  Keys read here: ``'summary_file'``,
    ``'only_score'``, ``'debug'``, ``'gf'`` (validation genotype prefix),
    ``'rf'`` (weights file prefix), ``'out'`` (output prefix), ``'pf'``,
    ``'pf_format'``, ``'cov_file'``, ``'pcs_file'``, ``'rf_format'``
    (``'LDPRED'``, ``'P+T'`` or ``'ANY'``), ``'f'``, ``'r2'`` and ``'p'``.
    ``'split_by_chrom'`` is read if present.

    Scores are written per weights file, an optional results summary is
    written to ``p_dict['summary_file']``, and a run summary is printed via
    ``reporting.print_summary``.

    Raises
    ------
    AssertionError
        If a summary file is requested together with ``only_score``, if no
        phenotypes were found, or if no weights file was found.
    Exception
        If neither a phenotype file nor a genotype prefix is available.
    """
    assert p_dict['summary_file'] is None or not p_dict['only_score'], 'Prediction summary file cannot be produced when the --only-score flag is set.'
    summary_dict = {}
    non_zero_chromosomes = set()
    verbose = p_dict['debug']

    t0 = time.time()

    summary_dict[0]={'name':'Validation genotype file (prefix):','value':p_dict['gf']}
    summary_dict[0.1]={'name':'Input weight file(s) (prefix):','value':p_dict['rf']}
    summary_dict[0.2]={'name':'Output scores file(s) (prefix):','value':p_dict['out']}

    # Adjusted models are only fitted when the corresponding file was parsed.
    adjust_for_pcs = False
    adjust_for_covs = False
    if not p_dict['only_score']:
        summary_dict[0.9]={'name':'dash', 'value':'Phenotypes'}
        if verbose:
            print('Parsing phenotypes')
        if p_dict['pf'] is None:
            if p_dict['gf'] is not None:
                # Fall back on the phenotypes in the genotype fam file.
                phen_map = parse_phen_file(p_dict['gf'] + '.fam', 'FAM', verbose=verbose, summary_dict=summary_dict)
            else:
                raise Exception('Validation phenotypes were not found.')
        else:
            phen_map = parse_phen_file(p_dict['pf'], p_dict['pf_format'], verbose=verbose, summary_dict=summary_dict)
        t1 = time.time()
        t = (t1 - t0)
        summary_dict[1.1]={'name':'Individuals with phenotype information:','value':len(phen_map)}
        summary_dict[1.2]={'name':'Running time for parsing phenotypes:','value':'%d min and %0.2f secs'% (t / 60, t % 60)}
        if p_dict['cov_file'] is not None:
            adjust_for_covs = True
            if verbose:
                print('Parsing additional covariates')
            parse_covariates(p_dict, phen_map, summary_dict, verbose)

        if p_dict['pcs_file']:
            adjust_for_pcs = True
            if verbose:
                print('Parsing PCs')
            parse_pcs(p_dict, phen_map, summary_dict, verbose)

        num_individs = len(phen_map)
        assert num_individs > 0, 'No phenotypes were found!'
    else:
        phen_map = None

    # Options shared by every scoring call below.
    # NOTE(review): 'split_by_chrom' is assumed to be an optional key of
    # p_dict - confirm against the argument parser.
    score_kwargs = {'split_by_chrom': p_dict.get('split_by_chrom', False),
                    'adjust_for_pcs': adjust_for_pcs,
                    'adjust_for_covariates': adjust_for_covs,
                    'only_score': p_dict['only_score'],
                    'verbose': verbose, 'summary_dict': summary_dict}

    t0 = time.time()
    prs_file_is_missing = True
    res_dict = {}
    if p_dict['rf_format'] == 'LDPRED' or p_dict['rf_format']=='ANY':
        weights_file = '%s_LDpred-inf.txt' % (p_dict['rf'])
        if os.path.isfile(weights_file):
            print('Calculating LDpred-inf risk scores')
            rs_id_map = parse_ldpred_res(weights_file)
            out_file = '%s_LDpred-inf.txt' % (p_dict['out'])
            res_dict['LDpred_inf'] = calc_risk_scores(p_dict['gf'], rs_id_map, phen_map, out_file=out_file,
                                                      **score_kwargs)
            if not p_dict['only_score']:
                summary_dict[5.2]={'name':'LDpred_inf (unadjusted) Pearson R2:','value':'%0.4f'%res_dict['LDpred_inf']['pred_r2']}
            prs_file_is_missing = False

        best_ldpred_pred_r2 = 0
        best_p = None
        for p in p_dict['f']:
            weights_file = '%s_LDpred_p%0.4e.txt' % (p_dict['rf'], p)
            if os.path.isfile(weights_file):
                print('Calculating LDpred risk scores using f=%0.3e' % p)
                rs_id_map = parse_ldpred_res(weights_file)
                out_file = '%s_LDpred_p%0.4e.txt' % (p_dict['out'], p)
                method_str = 'LDpred_p%0.4e' % (p)
                res_dict[method_str] = calc_risk_scores(p_dict['gf'], rs_id_map, phen_map, out_file=out_file,
                                                        **score_kwargs)
                if len(res_dict[method_str]) and (res_dict[method_str]['pred_r2']) >best_ldpred_pred_r2:
                    best_ldpred_pred_r2 = res_dict[method_str]['pred_r2']
                    best_p = p
                prs_file_is_missing = False
        if best_ldpred_pred_r2>0 and not p_dict['only_score']:
            summary_dict[5.3]={'name':'Best LDpred (f=%0.2e) (unadjusted) R2:'%(best_p),'value':'%0.4f'%best_ldpred_pred_r2}

        best_ldpred_fast_pred_r2 = 0
        best_p = None
        for p in p_dict['f']:
            weights_file = '%s_LDpred_fast_p%0.4e.txt' % (p_dict['rf'], p)
            if os.path.isfile(weights_file):
                print('Calculating LDpred-fast risk scores using f=%0.3e' % p)
                rs_id_map = parse_ldpred_res(weights_file)
                out_file = '%s_LDpred_fast_p%0.4e.txt' % (p_dict['out'], p)
                method_str = 'LDpred_fast_p%0.4e' % (p)
                res_dict[method_str] = calc_risk_scores(p_dict['gf'], rs_id_map, phen_map, out_file=out_file,
                                                        **score_kwargs)
                if len(res_dict[method_str]) and (res_dict[method_str]['pred_r2']) >best_ldpred_fast_pred_r2:
                    best_ldpred_fast_pred_r2 = res_dict[method_str]['pred_r2']
                    best_p = p
                prs_file_is_missing = False
        if best_ldpred_fast_pred_r2>0 and not p_dict['only_score']:
            summary_dict[5.4]={'name':'Best LDpred-fast (f=%0.2e) (unadjusted) R2:'%(best_p),'value':'%0.4f'%best_ldpred_fast_pred_r2}

        # Plot results?

    if p_dict['rf_format'] == 'P+T' or p_dict['rf_format']=='ANY':

        best_pt_pred_r2 = 0
        best_t = None
        best_r2 = None
        for max_r2 in p_dict['r2']:
            for p_thres in p_dict['p']:
                weights_file = '%s_P+T_r%0.2f_p%0.4e.txt' % (p_dict['rf'], max_r2, p_thres)
                if os.path.isfile(weights_file):
                    print('Calculating P+T risk scores using p-value threshold of %0.3e, and r2 threshold of %0.2f' % (p_thres, max_r2))
                    rs_id_map, non_zero_chromosomes = parse_pt_res(weights_file)
                    if len(rs_id_map)>0:
                        out_file = '%s_P+T_r%0.2f_p%0.4e.txt' % (p_dict['out'], max_r2, p_thres)
                        method_str = 'P+T_r%0.2f_p%0.4e' % (max_r2,p_thres)
                        res_dict[method_str] = calc_risk_scores(p_dict['gf'], rs_id_map, phen_map, out_file=out_file,
                                                                non_zero_chromosomes=non_zero_chromosomes,
                                                                **score_kwargs)
                        if len(res_dict[method_str]) and (res_dict[method_str]['pred_r2']) >best_pt_pred_r2:
                            best_pt_pred_r2 = res_dict[method_str]['pred_r2']
                            best_t = p_thres
                            best_r2 = max_r2
                    else:
                        print('No SNPs found with p-values below the given threshold.')
                    prs_file_is_missing = False
        if best_pt_pred_r2>0 and not p_dict['only_score']:
            summary_dict[5.5]={'name':'Best P+T (r2=%0.2f, p=%0.2e) (unadjusted) R2:'%(best_r2, best_t),'value':'%0.4f'%best_pt_pred_r2}

    # Plot results?
    assert not prs_file_is_missing, 'No SNP weights file was found.  A prefix to these should be provided via the --rf flag. Note that the prefix should exclude the _LDpred_.. extension or file ending. '

    res_summary_file = p_dict['summary_file']
    if res_summary_file is not None and not p_dict['only_score']:
        with open(res_summary_file,'w') as f:
            if verbose:
                print ('Writing Results Summary to file %s'%res_summary_file)
            out_str = 'Pred_Method    Pred_corr    Pred_R2    SNPs_used\n'
            f.write(out_str)
            for method_str in sorted(res_dict):
                out_str = '%s    %0.4f    %0.4f    %i\n'%(method_str, res_dict[method_str]['corr_r2'], res_dict[method_str]['pred_r2'], res_dict[method_str]['num_snps'])
                f.write(out_str)

    #Identifying the best prediction
    if not p_dict['only_score']:
        best_pred_r2 = 0
        best_method_str = None
        for method_str in res_dict:
            if len(res_dict[method_str]) and (res_dict[method_str]['pred_r2']) >best_pred_r2:
                best_pred_r2 = res_dict[method_str]['pred_r2']
                best_method_str = method_str
        if best_method_str is not None:
            print('The highest (unadjusted) Pearson R2 was %0.4f, and provided by %s'%(best_pred_r2,best_method_str))
            summary_dict[5.99]={'name':'dash', 'value':'Optimal polygenic score'}
            summary_dict[6]={'name':'Method with highest (unadjusted) Pearson R2:','value':best_method_str}
            summary_dict[6.1]={'name':'Best (unadjusted) Pearson R2:','value':'%0.4f'%best_pred_r2}
            if verbose:
                summary_dict[6.2]={'name':'Number of SNPs used','value':'%d'%res_dict[best_method_str]['num_snps']}
                summary_dict[6.3]={'name':'Number of SNPs flipped','value':'%d'%res_dict[best_method_str]['num_flipped_nts']}
                summary_dict[6.4]={'name':'Fraction of SNPs not found in validation data','value':'%0.4f'%res_dict[best_method_str]['perc_missing']}
                summary_dict[6.5]={'name':'Number of duplicated SNPs','value':'%d'%res_dict[best_method_str]['duplicated_snps']}
                summary_dict[6.6]={'name':'Number of non-matching nucleotides SNPs','value':'%d'%res_dict[best_method_str]['num_non_matching_nts']}
    t1 = time.time()
    t = (t1 - t0)
    summary_dict[4.9]={'name':'dash', 'value':'Scoring'}
    summary_dict[5.9]={'name':'Running time for calculating scores:','value':'%d min and %0.2f secs'% (t / 60, t % 60)}

    # Only reachable when assertions are disabled (python -O).
    if prs_file_is_missing:
        print('SNP weights files were not found.  This could be due to a mis-specified --rf flag, or other issues.')
    reporting.print_summary(summary_dict,'Scoring Summary')
