# validate.py
# Source: https://github.com/bvilhjal/ldpred
# Revision: aa6b27c8ba9b88e77731ca3ccceb585f7f2362fc (Bjarni J. Vilhjalmsson, 21 November 2019)
# "Merge pull request #126 from foresitelabs/master"
import os
import scipy as sp
from scipy import linalg
import h5py
import time
import sys
import warnings
from ldpred import reporting
from ldpred import plinkfiles
from ldpred import util

def get_prs(genotype_file, rs_id_map, phen_map=None, only_score = False, verbose=False):
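    """
    Compute polygenic risk scores for the individuals in the PLINK bed/bim/fam
    fileset given by the genotype_file prefix, using the SNP weights in
    rs_id_map.  If only_score is False, true phenotypes (and sex, PCs and
    covariates, when present) are pulled from phen_map and returned alongside
    the scores, together with per-SNP bookkeeping counts.
    """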
    plinkf = plinkfiles.plinkfile.PlinkFile(genotype_file)
    samples = plinkf.get_samples()

    if not only_score:
        # 1. Figure out indiv filter and get true phenotypes
        indiv_filter = sp.zeros(len(samples), dtype='bool8')
        true_phens = []
        iids = []
        if phen_map is not None:
            pcs = []
            sex = []
            covariates = []
            phen_iids = set(phen_map.keys())
            for samp_i, sample in enumerate(samples):
                if sample.iid in phen_iids:
                    indiv_filter[samp_i] = True
                    true_phens.append(phen_map[sample.iid]['phen'])
                    iids.append(sample.iid)
                    if 'pcs' in phen_map[sample.iid]:
                        pcs.append(phen_map[sample.iid]['pcs'])
                    if 'sex' in phen_map[sample.iid]:
                        sex.append(phen_map[sample.iid]['sex'])
                    if 'covariates' in phen_map[sample.iid]:
                        covariates.append(phen_map[sample.iid]['covariates'])
            if len(pcs) > 0:
                assert len(pcs) == len(
                    true_phens), 'PC information missing for some individuals with phenotypes'
            if len(sex) > 0:
                assert len(sex) == len(
                    true_phens), 'Sex information missing for some individuals with phenotypes'
            if len(covariates) > 0:
                assert len(covariates) == len(
                    true_phens), 'Covariates missing for some individuals with phenotypes'
        else:
            for samp_i, sample in enumerate(samples):
                if sample.affection != 2:
                    indiv_filter[samp_i] = True
                    true_phens.append(sample.affection)
                    iids.append(sample.iid)
    
        num_individs = sp.sum(indiv_filter)
        assert num_individs > 0, 'Issues in parsing the phenotypes and/or PCs?'
    
        assert not sp.any(sp.isnan(
            true_phens)), 'Phenotypes appear to have some NaNs, or parsing failed.'
    
        if verbose:
            print('%d individuals have phenotype and genotype information.' % num_individs)
    else:
        iids = [sample.iid for sample in samples]
        num_individs = len(samples)
    num_non_matching_nts = 0
    num_flipped_nts = 0

    pval_derived_effects_prs = sp.zeros(num_individs)
    # If these indices are not in order then we place them in the right place
    # while parsing SNPs.
    if verbose:
        print('Iterating over BED file to calculate risk scores.')
    locus_list = plinkf.get_loci()
    num_loci = len(locus_list)
    loci_i = 0
    snp_i = 0
    duplicated_snps = 0
    found_loci = set()
    for locus, row in zip(locus_list, plinkf):
        loci_i += 1
        if loci_i%100==0:
            sys.stdout.write('\r%0.2f%%' % (100.0 * (float(loci_i) / (num_loci-1.0))))
            sys.stdout.flush()            
        upd_pval_beta = 0
        sid = locus.name
        if sid not in rs_id_map:
            continue
        if sid in found_loci:
            duplicated_snps+=1
            continue
        
        found_loci.add(locus.name)
        rs_info = rs_id_map[sid]
        if rs_info['upd_pval_beta'] == 0:
            continue

        # Check whether the nucleotides are OK, and potentially flip it.
        ss_nt = rs_info['nts']
        g_nt = [locus.allele1, locus.allele2]
        flip_nts = False
        os_g_nt = sp.array(
            [util.opp_strand_dict[g_nt[0]], util.opp_strand_dict[g_nt[1]]])
        if not (sp.all(g_nt == ss_nt) or sp.all(os_g_nt == ss_nt)):
            # Opposite strand nucleotides
            flip_nts = (g_nt[1] == ss_nt[0] and g_nt[0] == ss_nt[1]) or (
                os_g_nt[1] == ss_nt[0] and os_g_nt[0] == ss_nt[1])
            if flip_nts:
                upd_pval_beta = -rs_info['upd_pval_beta']
                num_flipped_nts += 1
            else:
                num_non_matching_nts += 1
                continue
        else:
            upd_pval_beta = rs_info['upd_pval_beta']

        # Parse SNP, and fill in the blanks if necessary.
        if only_score:
            snp = sp.array(row, dtype='int8')
        else:
            snp = sp.array(row, dtype='int8')[indiv_filter]
        bin_counts = row.allele_counts()
        if bin_counts[-1] > 0:
            # Fill in missing genotypes (coded as 3) with the more frequent of
            # the first two genotype classes.
            mode_v = sp.argmax(bin_counts[:2])
            snp[snp == 3] = mode_v

        # Update scores and move on.
        pval_derived_effects_prs += -1*snp * upd_pval_beta
        assert not sp.any(sp.isnan(pval_derived_effects_prs)
                          ), 'Some individual weighted effects risk scores are NANs (not a number).  They are corrupted.'

        snp_i +=1
    
    perc_missing = 1 - float(len(found_loci)) / float(len(rs_id_map))
    if perc_missing > 0.01:
        # Warn (rather than abort) when a notable fraction of the weighted variants
        # is absent from the validation genotypes.
        warnings.warn('%0.2f%% of the variants for which weights were calculated were not found in the '
                      'validation data.  This can lead to poor prediction accuracy.  Please consider '
                      'using the --vbim flag in the coord step to identify overlapping SNPs.'
                      % (perc_missing * 100))
    
    plinkf.close()
    if not verbose:
        sys.stdout.write('\r%0.2f%%\n' % (100.0))
        sys.stdout.flush()            

    if verbose:
        print('Number of non-matching NTs: %d' % num_non_matching_nts)
        print('Number of flipped NTs: %d' % num_flipped_nts)
    if not only_score:
        pval_eff_corr = sp.corrcoef(pval_derived_effects_prs, true_phens)[0, 1]
        pval_eff_r2 = pval_eff_corr ** 2

        if verbose:
            print('Current PRS correlation: %0.4f' % pval_eff_corr)
            print('Current PRS r2: %0.4f' % pval_eff_r2)
    
        ret_dict = {'pval_derived_effects_prs': pval_derived_effects_prs,
                    'true_phens': true_phens[:], 'iids': iids, 'prs_r2':pval_eff_r2, 'prs_corr':pval_eff_corr,
                    'num_snps':snp_i, 'num_non_matching_nts':num_non_matching_nts, 
                    'num_flipped_nts':num_flipped_nts, 'perc_missing':perc_missing, 
                    'duplicated_snps':duplicated_snps}
    
        if len(pcs) > 0:
            ret_dict['pcs'] = pcs
        if len(sex) > 0:
            ret_dict['sex'] = sex
        if len(covariates) > 0:
            ret_dict['covariates'] = covariates

    else:    
        ret_dict = {'pval_derived_effects_prs': pval_derived_effects_prs,'iids': iids, 
                    'num_snps':snp_i, 'num_non_matching_nts':num_non_matching_nts, 
                    'num_flipped_nts':num_flipped_nts, 'perc_missing':perc_missing, 
                    'duplicated_snps':duplicated_snps
                    }

    
    return ret_dict


def parse_phen_file(pf, pf_format, verbose=False, summary_dict=None):
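    """
    Parse a phenotype file in the given format ('FAM', 'STANDARD', 'LSTANDARD'
    or 'S2') and return a dict mapping each IID to its phenotype (plus sex and
    age where the format provides them).
    """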
    phen_map = {}
    num_phens_found = 0
    if pf is not None:
        if verbose:
            print('Parsing phenotypes from file: %s' % pf)
        if pf_format == 'FAM':
            """
            Individual's family ID ('FID')
            Individual's within-family ID ('IID'; cannot be '0')
            Within-family ID of father ('0' if father isn't in dataset)
            Within-family ID of mother ('0' if mother isn't in dataset)
            Sex code ('1' = male, '2' = female, '0' = unknown)
            Phenotype value ('1' = control, '2' = case, '-9'/'0'/non-numeric = missing data if case/control)
            """
            with open(pf, 'r') as f:
                for line in f:
                    l = line.split()
                    iid = l[1]
                    sex = int(l[4])
                    phen = float(l[5])
                    if sex != 0 and phen != -9:
                        phen_map[iid] = {'phen': phen, 'sex': sex}
                        num_phens_found +=1

            summary_dict[1]={'name':'Phenotype file (plink format):','value':pf}
            
        elif pf_format == 'STANDARD':
            """
            IID   PHE
            """
            with open(pf, 'r') as f:
                for line in f:
                    l = line.split()
                    iid = l[0]
                    phen = float(l[1])
                    phen_map[iid] = {'phen': phen}
                    num_phens_found +=1

            summary_dict[1]={'name':'Phenotype file (STANDARD format):','value':pf}

        elif pf_format == 'LSTANDARD':
            """
            FID IID PHE
            """
            with open(pf, 'r') as f:
                # Read first line and see if it is header
                line = f.readline()
                l = line.split()
                try:
                    iid = l[1]
                    phen = float(l[2])
                    phen_map[iid] = {'phen': phen}
                    num_phens_found += 1
                except (IndexError, ValueError):
                    # The first line appears to be a header; skip it.
                    pass
                for line in f:
                    l = line.split()
                    iid=l[1]
                    phen=float(l[2])
                    phen_map[iid] = {'phen':phen}
                    num_phens_found+=1
            summary_dict[1]={'name':'Phenotype file (LSTANDARD format):','value':pf}
 
        elif pf_format == 'S2':
            """
            IID Age Sex Height_Inches
            """
            with open(pf, 'r') as f:
                for line in f:
                    l = line.split()
                    iid = l[0]
                    age = float(l[1])
                    if l[2] == 'Male':
                        sex = 1
                    elif l[2] == 'Female':
                        sex = 2
                    else:
                        raise Exception('Sex missing')
                    phen = float(l[3])
                    phen_map[iid] = {'phen': phen, 'age': age, 'sex': sex}
                    num_phens_found +=1
            summary_dict[1]={'name':'Phenotype file (S2 format):','value':pf}

    print("Parsed %d phenotypes successfully"%num_phens_found)
    return phen_map


def parse_ldpred_res(file_name):
    """
    chrom    pos    sid    nt1    nt2    raw_beta    ldpred_inf_beta    ldpred_beta
    1, 798959, rs11240777, C, T, -1.1901e-02, 3.2443e-03, 2.6821e-04
    """
    rs_id_map = {}
    with open(file_name, 'r') as f:
        next(f)
        for line in f:
            l = line.split()
            chrom_str = l[0]
            chrom = int(chrom_str[6:])
            pos = int(l[1])
            rs_id = l[2].strip()
            nt1 = l[3].strip()
            nt2 = l[4].strip()
            nts = [nt1, nt2]
            raw_beta = float(l[5])
            upd_pval_beta = float(l[6])
            rs_id_map[rs_id] = {'chrom': chrom, 'pos': pos, 'nts': nts, 'raw_beta': raw_beta,
                                'upd_pval_beta': upd_pval_beta}
    return rs_id_map


def parse_pt_res(file_name):
    """
    chrom    pos    sid    nt1    nt2    raw_beta    raw_pval_beta    upd_beta    upd_pval_beta
    1    798959    rs11240777    C    T    -1.1901e-02    -1.1901e-02    2.6821e-04    2.6821e-04
    """
    non_zero_chromosomes = set()
    rs_id_map = {}
    with open(file_name, 'r') as f:
        next(f)
        for line in f:
            l = line.split()
            chrom_str = l[0]
            chrom = int(chrom_str[6:])
            pos = int(l[1])
            rs_id = l[2].strip()
            nt1 = l[3].strip()
            nt2 = l[4].strip()
            nts = [nt1, nt2]
            raw_beta = float(l[5])
            upd_pval_beta = float(l[8])
            if raw_beta != 0:
                rs_id_map[rs_id] = {'chrom': chrom, 'pos': pos, 'nts': nts, 'raw_beta': raw_beta,
                                    'upd_pval_beta': upd_pval_beta}
                non_zero_chromosomes.add(chrom)

    return rs_id_map, non_zero_chromosomes

def write_scores_file(out_file, prs_dict, pval_derived_effects_prs, adj_pred_dict, 
                      output_regression_betas=False, weights_dict=None, verbose=False):
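    """
    Write IIDs, true phenotypes and polygenic scores (plus sex and PCs, when
    present in prs_dict) to out_file as comma-separated text.  Adjusted
    predictions, if any, are written to out_file + '.adj', and regression
    weights can optionally be stored in out_file + '.weights.hdf5'.
    """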
    num_individs = len(prs_dict['iids'])
    with open(out_file, 'w') as f:
        if verbose:
            print ('Writing polygenic scores to file %s'%out_file)
        out_str = 'IID, true_phens, PRS'
        if 'sex' in prs_dict:
            out_str = out_str + ', sex'
        if 'pcs' in prs_dict:
            pcs_str = ', '.join(['PC%d' % (1 + pc_i)
                                 for pc_i in range(len(prs_dict['pcs'][0]))])
            out_str = out_str + ', ' + pcs_str
        out_str += '\n'
        f.write(out_str)
        for i in range(num_individs):
            out_str = '%s, %0.6e, %0.6e' % (prs_dict['iids'][i], prs_dict['true_phens'][i],
                                                     pval_derived_effects_prs[i])
            if 'sex' in prs_dict:
                out_str = out_str + ', %d' % prs_dict['sex'][i]
            if 'pcs' in prs_dict:
                pcs_str = ', '.join(map(str, prs_dict['pcs'][i]))
                out_str = out_str +', '+ pcs_str
            out_str += '\n'
            f.write(out_str)

    if len(list(adj_pred_dict.keys())) > 0:
        with open(out_file + '.adj', 'w') as f:
            adj_prs_labels = list(adj_pred_dict.keys())
            out_str = 'IID, true_phens, PRS, ' + \
                ', '.join(adj_prs_labels)
            out_str += '\n'
            f.write(out_str)
            for i in range(num_individs):
                out_str = '%s, %0.6e, %0.6e' % (prs_dict['iids'][i], prs_dict['true_phens'][i],
                                                       pval_derived_effects_prs[i])
                for adj_prs in adj_prs_labels:
                    out_str += ', %0.4f' % adj_pred_dict[adj_prs][i]
                out_str += '\n'
                f.write(out_str)
    if output_regression_betas and weights_dict is not None:
        hdf5file = out_file + '.weights.hdf5'
        if verbose:
            print ('Writing PRS regression weights to file %s'%hdf5file)
        oh5f = h5py.File(hdf5file, 'w')
        for k1 in list(weights_dict.keys()):
            kg = oh5f.create_group(k1)
            for k2 in weights_dict[k1]:
                kg.create_dataset(k2, data=sp.array(weights_dict[k1][k2]))
        oh5f.close()


def write_only_scores_file(out_file, prs_dict, pval_derived_effects_prs):
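    """Write IIDs and polygenic scores (no phenotypes) to out_file."""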
    num_individs = len(prs_dict['iids'])
    with open(out_file, 'w') as f:
        print ('Writing polygenic scores to file %s'%out_file)
        out_str = 'IID, PRS \n'
        f.write(out_str)
        for i in range(num_individs):
            out_str = '%s, %0.6e\n' % (prs_dict['iids'][i],
                                                     pval_derived_effects_prs[i])
            f.write(out_str)


def calc_risk_scores(bed_file, rs_id_map, phen_map, out_file=None, 
                     split_by_chrom=False, adjust_for_sex=False,
                     adjust_for_covariates=False, adjust_for_pcs=False, 
                     non_zero_chromosomes=None, only_score = False, 
                     verbose=False, summary_dict=None):
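    """
    Calculate polygenic risk scores for the given PLINK bed file (optionally
    per chromosome when split_by_chrom is set) and, unless only_score is set,
    assess prediction accuracy by regressing the true phenotypes on the PRS,
    with optional sex, PC and covariate adjustments.  Returns a dict of
    accuracy statistics and per-SNP bookkeeping counts.
    """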
    if verbose:
        print('Parsing PLINK bed file: %s' % bed_file)
    if split_by_chrom:
        num_individs = len(phen_map)
        assert num_individs > 0, 'No individuals found.  Problems parsing the phenotype file?'
        pval_derived_effects_prs = sp.zeros(num_individs)

        for i in range(1, 23):
            if non_zero_chromosomes is None or i in non_zero_chromosomes:
                genotype_file = bed_file + '_%i_keep' % i
                if os.path.isfile(genotype_file + '.bed'):
                    if verbose:
                        print('Working on chromosome %d' % i)
                    prs_dict = get_prs(genotype_file, rs_id_map, phen_map, only_score=only_score, verbose=verbose)

                    pval_derived_effects_prs += prs_dict['pval_derived_effects_prs']
            elif verbose:
                print('Skipping chromosome %d' % i)

    else:
        prs_dict = get_prs(bed_file, rs_id_map, phen_map, only_score=only_score, verbose=verbose)
        num_individs = len(prs_dict['iids'])
        pval_derived_effects_prs = prs_dict['pval_derived_effects_prs']
        
    
    res_dict = {'num_snps':prs_dict['num_snps'], 'num_non_matching_nts':prs_dict['num_non_matching_nts'], 
                'num_flipped_nts':prs_dict['num_flipped_nts'], 'perc_missing':prs_dict['perc_missing'], 
                'duplicated_snps':prs_dict['duplicated_snps'], 'pred_r2': 0, 'corr_r2':0}
        
    if only_score:
        write_only_scores_file(out_file, prs_dict, pval_derived_effects_prs)
        
    elif sp.std(prs_dict['true_phens'])==0:
        if verbose:
            print('No variance left to explain in phenotype.')
    else:
        # Report prediction accuracy
        assert len(phen_map) > 0, 'No individuals found.  Problems parsing the phenotype file?'
        # Store covariate weights, slope, etc.
        weights_dict = {}
    
        # Store Adjusted predictions
        adj_pred_dict = {}

        # If the PRS has no variance, there is nothing to fit; output zero weights.
        if sp.std(pval_derived_effects_prs) == 0:
            weights_dict['unadjusted'] = {'Intercept': 0, 'ldpred_prs_effect': 0}
        else:
            pval_eff_corr = sp.corrcoef(pval_derived_effects_prs, prs_dict['true_phens'])[0, 1]
            pval_eff_r2 = pval_eff_corr ** 2
        
            res_dict['pred_r2'] = pval_eff_r2
            res_dict['corr_r2'] = pval_eff_corr
        
            pval_derived_effects_prs.shape = (len(pval_derived_effects_prs), 1)
            true_phens = sp.array(prs_dict['true_phens'])
            true_phens.shape = (len(true_phens), 1)
        
        
            # Direct effect: fit phenotype ~ PRS (plus intercept) and compare it
            # against an intercept-only null model to obtain the prediction R2.
            Xs = sp.hstack([pval_derived_effects_prs, sp.ones((len(true_phens), 1))])
            (betas, rss00, r, s) = linalg.lstsq(
                sp.ones((len(true_phens), 1)), true_phens)
            (betas, rss, r, s) = linalg.lstsq(Xs, true_phens)
            pred_r2 = 1 - rss / rss00
            weights_dict['unadjusted'] = {
                'Intercept': betas[1][0], 'ldpred_prs_effect': betas[0][0]}
        
            if verbose:
                print('PRS trait correlation: %0.4f  R2: %0.4f' % (pval_eff_corr,pred_r2))
        
            # Adjust for sex
            if adjust_for_sex and 'sex' in prs_dict and len(prs_dict['sex']) > 0:
                sex = sp.array(prs_dict['sex'])
                sex.shape = (len(sex), 1)
                (betas, rss0, r, s) = linalg.lstsq(
                    sp.hstack([sex, sp.ones((len(true_phens), 1))]), true_phens)
                Xs = sp.hstack([pval_derived_effects_prs, sex,
                                sp.ones((len(true_phens), 1))])
                (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                weights_dict['sex_adj'] = {
                    'Intercept': betas[2][0], 'ldpred_prs_effect': betas[0][0], 'sex': betas[1][0]}
                if verbose:
                    print('Fitted effects (betas) for PRS, sex, and intercept on true phenotype:', betas)
                adj_pred_dict['sex_prs'] = sp.dot(Xs, betas)
                pred_r2 = 1 - rss_pd / rss0
                print('Variance explained (Pearson R2) by PRS adjusted for Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['Sex_adj_pred_r2'] = pred_r2
                pred_r2 = 1 - rss_pd / rss00
                print('Variance explained (Pearson R2) by PRS + Sex : %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['Sex_adj_pred_r2+Sex'] = pred_r2
        
            # Adjust for PCs
            if adjust_for_pcs and 'pcs' in prs_dict and len(prs_dict['pcs']) > 0:
                pcs = prs_dict['pcs']
                (betas, rss0, r, s) = linalg.lstsq(
                    sp.hstack([pcs, sp.ones((len(true_phens), 1))]), true_phens)
                Xs = sp.hstack([pval_derived_effects_prs,
                                sp.ones((len(true_phens), 1)), pcs])
                (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                weights_dict['pc_adj'] = {
                    'Intercept': betas[1][0], 'ldpred_prs_effect': betas[0][0], 'pcs': betas[2][0]}
                adj_pred_dict['pc_prs'] = sp.dot(Xs, betas)
                pred_r2 = 1 - rss_pd / rss0
                print('Variance explained (Pearson R2) by PRS adjusted for PCs: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['PC_adj_pred_r2'] = pred_r2
                pred_r2 = 1 - rss_pd / rss00
                print('Variance explained (Pearson R2) by PRS + PCs: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['PC_adj_pred_r2+PC'] = pred_r2
        
                # Adjust for both PCs and Sex
                if adjust_for_sex and 'sex' in prs_dict and len(prs_dict['sex']) > 0:
                    sex = sp.array(prs_dict['sex'])
                    sex.shape = (len(sex), 1)
                    (betas, rss0, r, s) = linalg.lstsq(
                        sp.hstack([sex, pcs, sp.ones((len(true_phens), 1))]), true_phens)
                    Xs = sp.hstack([pval_derived_effects_prs, sex,
                                    sp.ones((len(true_phens), 1)), pcs])
                    (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                    weights_dict['sex_pc_adj'] = {
                        'Intercept': betas[2][0], 'ldpred_prs_effect': betas[0][0], 'sex': betas[1][0], 'pcs': betas[3][0]}
                    adj_pred_dict['sex_pc_prs'] = sp.dot(Xs, betas)
                    pred_r2 = 1 - rss_pd / rss0
                    print('Variance explained (Pearson R2) by PRS adjusted for PCs and Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                    res_dict['PC_Sex_adj_pred_r2'] = pred_r2
                    pred_r2 = 1 - rss_pd / rss00
                    print('Variance explained (Pearson R2) by PRS+PCs+Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                    res_dict['PC_Sex_adj_pred_r2+PC_Sex'] = pred_r2
        
            # Adjust for covariates
            if adjust_for_covariates and 'covariates' in prs_dict and len(prs_dict['covariates']) > 0:
                covariates = prs_dict['covariates']
                (betas, rss0, r, s) = linalg.lstsq(
                    sp.hstack([covariates, sp.ones((len(true_phens), 1))]), true_phens)
                Xs = sp.hstack([pval_derived_effects_prs, covariates,
                                sp.ones((len(true_phens), 1))])
                (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                adj_pred_dict['cov_prs'] = sp.dot(Xs, betas)
                pred_r2 = 1 - rss_pd / rss0
                print('Variance explained (Pearson R2) by PRS adjusted for Covariates: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['Cov_adj_pred_r2'] = pred_r2
                pred_r2 = 1 - rss_pd / rss00
                print('Variance explained (Pearson R2) by PRS + Cov: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                res_dict['Cov_adj_pred_r2+Cov'] = pred_r2
        
                if adjust_for_pcs and 'pcs' in prs_dict and len(prs_dict['pcs']) and 'sex' in prs_dict and len(prs_dict['sex']) > 0:
                    pcs = prs_dict['pcs']
                    sex = sp.array(prs_dict['sex'])
                    sex.shape = (len(sex), 1)
                    (betas, rss0, r, s) = linalg.lstsq(
                        sp.hstack([covariates, sex, pcs, sp.ones((len(true_phens), 1))]), true_phens)
                    Xs = sp.hstack([pval_derived_effects_prs, covariates,
                                    sex, pcs, sp.ones((len(true_phens), 1))])
                    (betas, rss_pd, r, s) = linalg.lstsq(Xs, true_phens)
                    adj_pred_dict['cov_sex_pc_prs'] = sp.dot(Xs, betas)
                    pred_r2 = 1 - rss_pd / rss0
                    print('Variance explained (Pearson R2) by PRS adjusted for Cov+PCs+Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                    res_dict['Cov_PC_Sex_adj_pred_r2'] = pred_r2
                    pred_r2 = 1 - rss_pd / rss00
                    print('Variance explained (Pearson R2) by PRS+Cov+PCs+Sex: %0.4f (%0.6f)' % (pred_r2, (1 - pred_r2) / sp.sqrt(num_individs)))
                    res_dict['Cov_PC_Sex_adj_pred_r2+Cov_PC_Sex'] = pred_r2
        
        
            # Now calibration
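            # The calibration slope below is the ordinary least-squares slope
            # (without intercept) of the standardized phenotype regressed on
            # the PRS, i.e. (X'y) / (X'X).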
            y_norm = (true_phens - sp.mean(true_phens)) / sp.std(true_phens)
            denominator = sp.dot(pval_derived_effects_prs.T, pval_derived_effects_prs)
            numerator = sp.dot(pval_derived_effects_prs.T, y_norm)
            regression_slope = (numerator / denominator)[0][0]
            if verbose:
                print('The slope for predictions with weighted effects is: %0.4f'% regression_slope)
        
    
        num_individs = len(prs_dict['pval_derived_effects_prs'])

        # Write PRS out to file.
        if out_file is not None:
            write_scores_file(out_file, prs_dict, pval_derived_effects_prs, adj_pred_dict, 
                              weights_dict=weights_dict, verbose=verbose)

    return res_dict


def parse_covariates(p_dict,phen_map,summary_dict,verbose):
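    """
    Read covariates from p_dict['cov_file'] and attach them to phen_map.
    Each line is expected to list the IID followed by one or more numeric
    covariate values (whitespace-separated, no header line).
    """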
    with open(p_dict['cov_file'], 'r') as f:
        num_missing = 0
        for line in f:
            l = line.split()
            iid = l[0]
            if iid in phen_map:
                covariates = list(map(float, l[1:]))
                phen_map[iid]['covariates'] = covariates
            else:
                num_missing += 1
        if num_missing > 0:
            summary_dict[2.1]={'name':'Individuals w missing covariate information:','value':num_missing}
            if verbose:
                print('Unable to find %d iids in phen file!' % num_missing)
    summary_dict[2]={'name':'Parsed covariates file:','value':p_dict['cov_file']}


def parse_pcs(p_dict,phen_map,summary_dict,verbose):
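    """
    Read principal components from p_dict['pcs_file'] and attach them to
    phen_map.  The first column of each line (presumably the family ID) is
    ignored; the second column is taken as the IID and the remaining columns
    as numeric PC values.
    """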
    with open(p_dict['pcs_file'], 'r') as f:
        num_missing = 0
        for line in f:
            l = line.split()
            iid = l[1]
            if iid in phen_map:
                pcs = list(map(float, l[2:]))
                phen_map[iid]['pcs'] = pcs
            else:
                num_missing += 1
        if num_missing > 0:
            summary_dict[3.1]={'name':'Individuals w missing PCs:','value':num_missing}
            if verbose:
                print('Unable to find %d iids in phen file!' % num_missing)
    summary_dict[3]={'name':'Parsed PCs file:','value':p_dict['pcs_file']}

def main(p_dict):
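    """
    Run the LDpred scoring/validation step: parse phenotypes (unless
    only_score is set), compute polygenic scores for every available SNP
    weights file, and summarize the resulting prediction accuracies.
    """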
    assert p_dict['summary_file'] is None or not p_dict['only_score'], 'Prediction summary file cannot be produced when the --only-score flag is set.'
    
    summary_dict = {}
    non_zero_chromosomes = set()
    verbose = p_dict['debug']

    t0 = time.time()

    summary_dict[0]={'name':'Validation genotype file (prefix):','value':p_dict['gf']}
    summary_dict[0.1]={'name':'Input weight file(s) (prefix):','value':p_dict['rf']}
    summary_dict[0.2]={'name':'Output scores file(s) (prefix):','value':p_dict['out']}

    adjust_for_pcs=False
    adjust_for_covs=False

    if not p_dict['only_score']:
        summary_dict[0.9]={'name':'dash', 'value':'Phenotypes'}
        if verbose:
            print('Parsing phenotypes')
        if p_dict['pf'] is None:
            if p_dict['gf'] is not None:
                phen_map = parse_phen_file(p_dict['gf'] + '.fam', 'FAM', verbose=verbose, summary_dict=summary_dict)
            else:
                raise Exception('Validation phenotypes were not found.')
        else:
            phen_map = parse_phen_file(p_dict['pf'], p_dict['pf_format'], verbose=verbose, summary_dict=summary_dict)
        t1 = time.time()
        t = (t1 - t0)
        summary_dict[1.1]={'name':'Individuals with phenotype information:','value':len(phen_map)}
        summary_dict[1.2]={'name':'Running time for parsing phenotypes:','value':'%d min and %0.2f secs'% (t / 60, t % 60)}
    
        if p_dict['cov_file'] is not None:
            adjust_for_covs=True
            if verbose:
                print('Parsing additional covariates')
            parse_covariates(p_dict, phen_map, summary_dict, verbose)

    
        if p_dict['pcs_file']:
            adjust_for_pcs=True
            if verbose:
                print('Parsing PCs')
            parse_pcs(p_dict,phen_map,summary_dict,verbose)

    
        num_individs = len(phen_map)
        assert num_individs > 0, 'No phenotypes were found!'
    else:
        phen_map = None

    t0 = time.time()
    prs_file_is_missing = True
    res_dict = {}
    if p_dict['rf_format'] == 'LDPRED' or p_dict['rf_format']=='ANY':
        weights_file = '%s_LDpred-inf.txt' % (p_dict['rf'])
        
        if os.path.isfile(weights_file):
            print('')
            print('Calculating LDpred-inf risk scores')
            rs_id_map = parse_ldpred_res(weights_file)
            out_file = '%s_LDpred-inf.txt' % (p_dict['out'])
            res_dict['LDpred_inf'] = calc_risk_scores(p_dict['gf'], rs_id_map, phen_map, out_file=out_file, 
                             split_by_chrom=p_dict['split_by_chrom'],
                             adjust_for_pcs=adjust_for_pcs,
                             adjust_for_covariates=adjust_for_covs,
                             only_score=p_dict['only_score'],
                             verbose=verbose, summary_dict=summary_dict)
            if not p_dict['only_score']:
                summary_dict[5.2]={'name':'LDpred_inf (unadjusted) Pearson R2:','value':'%0.4f'%res_dict['LDpred_inf']['pred_r2']}
            prs_file_is_missing = False
        
       
        best_ldpred_pred_r2 = 0
        best_p = None
        for p in p_dict['f']:
            weights_file = '%s_LDpred_p%0.4e.txt' % (p_dict['rf'], p)
            if os.path.isfile(weights_file):
                print('')
                print('Calculating LDpred risk scores using f=%0.3e' % p)
                rs_id_map = parse_ldpred_res(weights_file)
                out_file = '%s_LDpred_p%0.4e.txt' % (p_dict['out'], p)
                method_str = 'LDpred_p%0.4e' % (p)
                res_dict[method_str] = calc_risk_scores(p_dict['gf'], rs_id_map, phen_map, out_file=out_file,
                                                        split_by_chrom=p_dict['split_by_chrom'],
                                                        adjust_for_pcs=adjust_for_pcs,
                                                        adjust_for_covariates=adjust_for_covs,
                                                        only_score=p_dict['only_score'],
                                                        verbose=verbose, summary_dict=summary_dict)
                if len(res_dict[method_str]) and (res_dict[method_str]['pred_r2']) >best_ldpred_pred_r2:
                    best_ldpred_pred_r2 = res_dict[method_str]['pred_r2']
                    best_p = p
        
                prs_file_is_missing=False
        if best_ldpred_pred_r2>0 and not p_dict['only_score']:         
            summary_dict[5.3]={'name':'Best LDpred (f=%0.2e) (unadjusted) R2:'%(best_p),'value':'%0.4f'%best_ldpred_pred_r2}

        best_ldpred_fast_pred_r2 = 0
        best_p = None
        for p in p_dict['f']:
            weights_file = '%s_LDpred_fast_p%0.4e.txt' % (p_dict['rf'], p)
            if os.path.isfile(weights_file):
                print('')
                print('Calculating LDpred-fast risk scores using f=%0.3e' % p)
                rs_id_map = parse_ldpred_res(weights_file)
                out_file = '%s_LDpred_fast_p%0.4e.txt' % (p_dict['out'], p)
                method_str = 'LDpred_fast_p%0.4e' % (p)
                res_dict[method_str] = calc_risk_scores(p_dict['gf'], rs_id_map, phen_map, out_file=out_file,
                                                        split_by_chrom=p_dict['split_by_chrom'],
                                                        adjust_for_pcs=adjust_for_pcs,
                                                        adjust_for_covariates=adjust_for_covs,
                                                        only_score=p_dict['only_score'],
                                                        verbose=verbose, summary_dict=summary_dict)
                if len(res_dict[method_str]) and (res_dict[method_str]['pred_r2']) >best_ldpred_fast_pred_r2:
                    best_ldpred_fast_pred_r2 = res_dict[method_str]['pred_r2']
                    best_p = p
        
                prs_file_is_missing=False
        if best_ldpred_fast_pred_r2>0 and not p_dict['only_score']:         
            summary_dict[5.4]={'name':'Best LDpred-fast (f=%0.2e) (unadjusted) R2:'%(best_p),'value':'%0.4f'%best_ldpred_fast_pred_r2}

        
        # Plot results?

    if p_dict['rf_format'] == 'P+T' or p_dict['rf_format']=='ANY':

        best_pt_pred_r2 = 0
        best_t = None
        best_r2 = None
        for max_r2 in p_dict['r2']:
            for p_thres in p_dict['p']:
                weights_file = '%s_P+T_r%0.2f_p%0.4e.txt' % (p_dict['rf'], max_r2, p_thres)
                if os.path.isfile(weights_file):
                    print('Calculating P+T risk scores using p-value threshold of %0.3e, and r2 threshold of %0.2f' % (p_thres, max_r2))
                    rs_id_map, non_zero_chromosomes = parse_pt_res(weights_file)
                    if len(rs_id_map)>0:
                        out_file = '%s_P+T_r%0.2f_p%0.4e.txt' % (p_dict['out'], max_r2, p_thres)
                        method_str = 'P+T_r%0.2f_p%0.4e' % (max_r2,p_thres)
                        res_dict[method_str] = calc_risk_scores(p_dict['gf'], rs_id_map, phen_map, out_file=out_file,
                                                                split_by_chrom=p_dict['split_by_chrom'],
                                                                non_zero_chromosomes=non_zero_chromosomes, 
                                                                adjust_for_pcs=adjust_for_pcs,
                                                                adjust_for_covariates=adjust_for_covs,
                                                                only_score=p_dict['only_score'],
                                                                verbose=verbose, summary_dict=summary_dict)
                        if len(res_dict[method_str]) and (res_dict[method_str]['pred_r2']) >best_pt_pred_r2:
                            best_pt_pred_r2 = res_dict[method_str]['pred_r2']
                            best_t = p_thres
                            best_r2 = max_r2
                    else:
                        print('No SNPs found with p-values below the given threshold.')
                    prs_file_is_missing=False
        if best_pt_pred_r2>0 and not p_dict['only_score']:                
            summary_dict[5.5]={'name':'Best P+T (r2=%0.2f, p=%0.2e) (unadjusted) R2:'%(best_r2, best_t),'value':'%0.4f'%best_pt_pred_r2}

    # Plot results?
    assert not prs_file_is_missing, 'No SNP weights file was found.  A prefix for these files should be provided via the --rf flag.  Note that the prefix should exclude the _LDpred_* suffix and file ending.'

    res_summary_file = p_dict['summary_file']
    if res_summary_file is not None and not p_dict['only_score']:
        with open(res_summary_file,'w') as f:
            if verbose:
                print ('Writing Results Summary to file %s'%res_summary_file)
            out_str = 'Pred_Method    Pred_corr    Pred_R2    SNPs_used\n'
            f.write(out_str)
            for method_str in sorted(res_dict):
                out_str = '%s    %0.4f    %0.4f    %i\n'%(method_str, res_dict[method_str]['corr_r2'], res_dict[method_str]['pred_r2'], res_dict[method_str]['num_snps'])
                f.write(out_str)
                
    #Identifying the best prediction
    if not p_dict['only_score']:
        best_pred_r2 = 0
        best_method_str = None
        for method_str in res_dict:
            if len(res_dict[method_str]) and (res_dict[method_str]['pred_r2']) >best_pred_r2:
                best_pred_r2 = res_dict[method_str]['pred_r2']
                best_method_str = method_str
                
        if best_method_str is not None:
            print('The highest (unadjusted) Pearson R2 was %0.4f, and provided by %s'%(best_pred_r2,best_method_str))
            summary_dict[5.99]={'name':'dash', 'value':'Optimal polygenic score'}
            summary_dict[6]={'name':'Method with highest (unadjusted) Pearson R2:','value':best_method_str}
            summary_dict[6.1]={'name':'Best (unadjusted) Pearson R2:','value':'%0.4f'%best_pred_r2}
            if verbose:
                summary_dict[6.2]={'name':'Number of SNPs used','value':'%d'%res_dict[best_method_str]['num_snps']}
                summary_dict[6.3]={'name':'Number of SNPs flipped','value':'%d'%res_dict[best_method_str]['num_flipped_nts']}
                summary_dict[6.4]={'name':'Fraction of SNPs not found in validation data','value':'%0.4f'%res_dict[best_method_str]['perc_missing']}
                summary_dict[6.5]={'name':'Number of duplicated SNPs','value':'%d'%res_dict[best_method_str]['duplicated_snps']}
                summary_dict[6.6]={'name':'Number of SNPs with non-matching nucleotides','value':'%d'%res_dict[best_method_str]['num_non_matching_nts']}
    t1 = time.time()
    t = (t1 - t0)
    summary_dict[4.9]={'name':'dash', 'value':'Scoring'}
    summary_dict[5.9]={'name':'Running time for calculating scores:','value':'%d min and %0.2f secs'% (t / 60, t % 60)}

    if prs_file_is_missing:
        print('SNP weights files were not found.  This could be due to a mis-specified --rf flag, or other issues.')
    
    reporting.print_summary(summary_dict,'Scoring Summary')