https://github.com/dotbot2000/elli
Tip revision: 9d45bdf6c08ad0d5cc831306f7b02640d10471de authored by Jane Lin on 21 October 2016, 09:52:03 UTC
Update mcmc_ages_2.py
Update mcmc_ages_2.py
Tip revision: 9d45bdf
star_ages.py
from __future__ import print_function
from numpy import *
from numpy.linalg import inv, det
import emcee
from numpy.random import rand, seed
from scipy.interpolate import interp1d
from Isochrones import DSED_Isochrones
from asteroseismic import *
# Module constants.
# ln(2*pi): the per-dimension normalisation term of the Gaussian
# log-likelihood evaluated in lnProb.
log2pi = log(2 * pi)
class star:
    """Observational constraints for one star, plus its emcee age fit.

    All observables default to None.  The caller sets whichever values
    (and sigmas) are available.  Then:

    * ``pack()`` collects the present ones into ``data``.
    * ``set_covariance_matrix()`` stores the covariance matrix and derived
      quantities.
    * ``run_emcee()`` samples (age, mass, [Fe/H]).
    """

    def __init__(self):
        self.Teff = None
        self.sigma_Teff = None
        self.logg = None
        self.sigma_logg = None
        self.FeH = None
        self.sigma_FeH = None
        self.kmag = None  # apparent
        self.sigma_kmag = None
        self.Kmag = None  # absolute
        self.sigma_Kmag = None
        self.parallax = None
        self.sigma_parallax = None
        self.delta_nu = None
        self.sigma_delta_nu = None
        self.nu_max = None
        self.sigma_nu_max = None
        self.ID = None
        self.initial_guess = None
        self.sampler = None

    def pack(self):
        """Gather the observables that are set into ``self.data``.

        Slots are, in order: Teff, logg, FeH, Kmag, delta_nu, nu_max.  This
        is the same order as the model vector that get_star returns.

        ``self.mask`` is a length-6 boolean array marking which slots are
        present.  ``self.data`` holds only the present values, and
        ``self.dim`` is their count.
        """
        observables = (self.Teff, self.logg, self.FeH,
                       self.Kmag, self.delta_nu, self.nu_max)
        self.mask = array([value is not None for value in observables])
        self.data = array([value for value in observables if value is not None])
        self.dim = len(self.data)

    def set_absolute_Kmag(self):
        """Set distance from parallax, then absolute Kmag from apparent kmag."""
        self.distance, self.sigma_distance = parallax_distance(self.parallax, self.sigma_parallax)
        self.Kmag, self.sigma_Kmag = Kmag_from_distance(self.kmag, self.sigma_kmag, self.distance, self.sigma_distance)

    def set_covariance_matrix(self, cov):
        """Store ``cov`` together with its inverse, determinant and log-determinant.

        Raises:
            numpy.linalg.LinAlgError: if ``cov`` is singular (raised by inv).
            ValueError: if the determinant of ``cov`` is not positive, so it
                cannot be a covariance matrix.
        """
        from numpy.linalg import slogdet
        self.cov = cov
        self.icov = inv(self.cov)
        self.det_cov = det(self.cov)
        # slogdet stays finite where det() would underflow to 0 (many small
        # variances) or overflow, which would make log(det) +/-inf.
        sign, logdet = slogdet(self.cov)
        if sign <= 0:
            raise ValueError('covariance matrix must have a positive determinant')
        self.log_det_cov = logdet

    @staticmethod
    def _scatter_walkers(guess, nwalkers):
        """Return ``nwalkers`` starting points in a small ball around ``guess``.

        ``guess`` is [age, mass, FeH].  Age and mass get a +/-10%
        multiplicative scatter.  [Fe/H] gets a +/-0.05 additive scatter.

        The additive scatter matters when [Fe/H] is 0.0 (solar).  There a
        multiplicative scatter would put every walker at exactly 0, and
        emcee's stretch move could never leave that value.
        """
        walkers = []
        for _ in range(nwalkers):
            scatter = 0.5 - rand(len(guess))
            walker = guess * (1 - 0.2 * scatter)
            walker[2] = guess[2] + 0.1 * scatter[2]
            walkers.append(walker)
        return walkers

    def run_emcee(self, nwalkers=200, nburn=100, nrun=400, a=3.):
        """Sample (age, mass, [Fe/H]) with emcee, using lnProb as the target.

        ``pack()`` and ``set_covariance_matrix()`` must have been called
        first.  Runs ``nburn`` burn-in steps, resets the chain, then runs
        ``nrun`` steps.  ``a`` is the stretch-move scale passed to
        emcee.EnsembleSampler.  The sampler is kept in ``self.sampler``.

        Raises:
            ValueError: if ``self.FeH`` is None (it seeds the initial guess).
        """
        # Reseed so each process gets an independent random number sequence.
        seed()
        if self.FeH is None:
            raise ValueError('run_emcee needs self.FeH for the initial [Fe/H] guess')
        ndim = 3
        age_guess = 7.0
        mass_guess = 0.8
        feh_guess = self.FeH
        guess = array([age_guess, mass_guess, feh_guess])
        self.initial_guess = guess
        p0 = self._scatter_walkers(guess, nwalkers)
        self.sampler = emcee.EnsembleSampler(nwalkers, ndim, lnProb, args=[self], a=a)
        # Burn-in: keep only the final walker positions, then discard the chain.
        pos, prob, state, blob = self.sampler.run_mcmc(p0, nburn)
        self.sampler.reset()
        # Main run.
        self.sampler.run_mcmc(pos, nrun)
def parallax_distance(parallax, err_parallax):  # both in mas
    """Invert a parallax into a distance, with first-order error propagation.

    Args:
        parallax: parallax in milliarcseconds.
        err_parallax: its uncertainty, also in milliarcseconds.

    Returns:
        (d, sigma_d) in parsecs, with d = 1000 / parallax and
        sigma_d = |d * err_parallax / parallax|.

    Raises:
        ValueError: if parallax is not strictly positive (including NaN),
            since it then has no physical distance.
    """
    parallax = float(parallax)
    if not parallax > 0:
        raise ValueError('parallax must be positive to convert it to a distance')
    # A 1 arcsec parallax corresponds to 1 pc, so p [mas] -> 1000/p [pc].
    # This is the conversion astropy's u.parallax() equivalency performs.
    d = 1000.0 / parallax
    sigma_d = abs(d * err_parallax / parallax)
    print('distance: '+str(d)+'+/-'+str(sigma_d)+' pc')
    return d, sigma_d  # both in pc
def Kmag_from_distance(kmag, err_kmag, d, err_d):
    """Convert an apparent K magnitude to an absolute one via the distance modulus.

    Args:
        kmag, err_kmag: apparent magnitude and its uncertainty.
        d, err_d: distance and its uncertainty.  The 5*(log10(d) - 1)
            modulus assumes parsecs.

    Returns:
        (absolute K magnitude, its uncertainty).  The uncertainty combines
        err_kmag and the propagated distance error in quadrature.
    """
    distance_modulus = 5*(log10(d)-1)
    abs_mag = kmag - distance_modulus
    # d(5 log10 d)/dd = 5 / (d ln 10), squared gives the 25 factor.
    mag_var = pow(err_kmag, 2)
    dist_var = 25*pow(err_d/(float(d)*log(10)), 2)
    abs_mag_err = sqrt(mag_var + dist_var)
    print('abs K mag: '+ str(abs_mag)+'+/-'+str(abs_mag_err))
    return abs_mag, abs_mag_err  # absolute kmag
def interp(x, y):
    """Return a piecewise-linear interpolant of y(x) built with scipy's interp1d.

    The interpolant raises ValueError for points outside [min(x), max(x)]
    (interp1d's default bounds checking).
    """
    method = 'linear'
    return interp1d(x, y, kind=method)
def get_star(input_params, y):
    """Interpolate the isochrone grid ``y`` to a given (age, mass, [Fe/H]).

    Args:
        input_params: sequence [age, mass, FeH].
        y: list of isochrone sets sorted by increasing ``.FeH``.  Each set is
            evaluated with get_one_star.

    Returns:
        (ok, result).  ``result`` is a length-6 array
        [Teff, logg, FeH, Kmag, delta_nu, nu_max], obtained by linear
        interpolation in [Fe/H] between the two bracketing isochrone sets.
        ``ok`` is True only when both sets produced a model and ``result``
        was filled.  Otherwise ``result`` is uninitialised.
    """
    age = input_params[0]
    mass = input_params[1]
    FeH = input_params[2]
    result = empty(6)
    # Interpolating in [Fe/H] needs two isochrone sets and FeH inside the grid.
    if len(y) < 2 or not (y[0].FeH <= FeH <= y[-1].FeH):
        return False, result
    # Find the pair y[iy], y[iy+1] that brackets FeH.  The loop stops one
    # short of the end so y[i+1] always exists.
    iy = None
    for i in range(len(y) - 1):
        if y[i].FeH <= FeH < y[i + 1].FeH:
            iy = i
            break
    # FeH exactly at the top of the grid: use the last pair.
    if iy is None and FeH == y[-1].FeH:
        iy = len(y) - 2
    if iy is None:
        return False, result
    met=[]; Teff=[]; logg=[]; Kmag=[]; dnu=[]; numax=[]
    for x in y[iy:iy + 2]:
        star_ok, params = get_one_star(age, mass, x)
        if star_ok:
            met.append(x.FeH)
            Teff.append(params[0])
            logg.append(params[1])
            Kmag.append(params[2])
            dnu.append(params[3])
            numax.append(params[4])
    # Success requires both neighbours.  This is not just the outcome of the
    # last get_one_star call, which could leave result unfilled.
    ok = len(met) == 2 and met[0] <= FeH <= met[-1]
    if ok:
        result[0] = interp(met,Teff)(FeH)
        result[1] = interp(met,logg)(FeH)
        result[2] = FeH
        result[3] = interp(met,Kmag)(FeH)
        result[4] = interp(met,dnu)(FeH)
        result[5] = interp(met,numax)(FeH)
    return bool(ok), result
def get_one_star(age, mass, x):
    """Interpolate one isochrone set ``x`` to a given (age, mass).

    ``x.ages`` must be ascending, and ``x.data[i]`` is the isochrone at
    ``x.ages[i]``.  Each isochrone's 'M_Mo' column must be ascending, since
    the range check below compares against its first and last entries.

    Interpolation is linear in mass within each of the two bracketing
    isochrones, then linear in age between them.  Teff is blended in
    linear (not log) space.  delta_nu and nu_max come from delta_nu_func /
    nu_max_func applied to the interpolated mass, Teff and luminosity.

    Returns:
        (ok, params) with params = [Teff, logg, Kmag, delta_nu, nu_max].
        When ok is False, params is a NaN-filled array of the same length.
    """
    params = full(5, nan)
    ages = x.ages
    if not (ages[0] <= age <= ages[-1]):
        return False, params
    i0 = None
    for i in range(len(ages) - 1):
        if ages[i] <= age <= ages[i + 1]:
            i0 = i
            break
    if i0 is None:
        # Fewer than two ages (or unsorted ages): nothing to interpolate between.
        return False, params
    i1 = i0 + 1
    iso0 = x.data[i0]
    iso1 = x.data[i1]
    m0 = iso0['M_Mo']
    m1 = iso1['M_Mo']
    if not (m0[0] <= mass <= m0[-1] and m1[0] <= mass <= m1[-1]):
        return False, params
    T0 = interp(m0, iso0['LogTeff'])(mass)
    T1 = interp(m1, iso1['LogTeff'])(mass)
    g0 = interp(m0, iso0['LogG'])(mass)
    g1 = interp(m1, iso1['LogG'])(mass)
    K0 = interp(m0, iso0['Ks '])(mass)
    K1 = interp(m1, iso1['Ks '])(mass)
    L0 = interp(m0, iso0['LogL_Lo'])(mass)
    L1 = interp(m1, iso1['LogL_Lo'])(mass)
    # Linear age weights: alfa -> the older isochrone, beta -> the younger.
    alfa = (age - ages[i0]) / (ages[i1] - ages[i0])
    beta = 1.0 - alfa
    Teff = alfa*pow(10, T1) + beta*pow(10, T0)
    logg = alfa*g1 + beta*g0
    Kmag = alfa*K1 + beta*K0
    logL = alfa*L1 + beta*L0
    luminosity = pow(10, logL)
    dnu = delta_nu_func(mass, Teff, luminosity)
    numax = nu_max_func(mass, Teff, luminosity)
    return True, array([Teff, logg, Kmag, dnu, numax])
def lnPrior(m, alpha=-2.35, m0=0.1):
    """Log of a normalised power-law mass prior (Salpeter IMF by default).

    p(m) = norm * m**alpha for m >= m0, and 0 below m0, where
    norm = -(alpha + 1) / m0**(alpha + 1).  This makes p integrate to 1
    over [m0, inf), which requires alpha < -1.

    Args:
        m: mass, scalar or array-like.
        alpha: power-law slope (-2.35 is the Salpeter IMF slope).
        m0: lower mass cutoff.

    Returns:
        ln p(m), as a numpy scalar for scalar input or an array otherwise.
        The value is -inf where m < m0, including non-positive and NaN masses.

    Raises:
        ValueError: if alpha >= -1 (the prior cannot be normalised).
    """
    if not alpha < -1:
        raise ValueError('power-law mass prior needs alpha < -1 to be normalisable')
    norm = -(alpha + 1.) / pow(m0, alpha + 1.)
    masses = asarray(m, dtype=float)
    inside = masses >= m0
    # Substitute m0 outside the support so log() never sees m <= 0.
    safe = where(inside, masses, m0)
    lnp = where(inside, log(norm) + alpha * log(safe), -inf)
    # [()] unwraps a 0-d result to a scalar and leaves arrays unchanged.
    return lnp[()]
def lnProb(params, star, isochrones=None):
    """Gaussian log-likelihood of a star's packed data given model parameters.

    Args:
        params: [age, mass, feh], the emcee parameter vector.
        star: a ``star`` instance on which pack() and
            set_covariance_matrix() have been called.
        isochrones: isochrone sets sorted by increasing [Fe/H], passed to
            get_star.  Defaults to the module-level ``y`` built in the
            ``__main__`` block, which keeps the original behaviour.

    Returns:
        (lnL, model).  ``model`` is the full length-6 vector from get_star
        (returned as an emcee blob).  lnL is -inf when no model could be
        interpolated.  Otherwise
        lnL = -0.5 * (diff^T C^-1 diff + ln|C| + N ln(2 pi)).
    """
    if isochrones is None:
        isochrones = y
    ok, model = get_star(params, isochrones)
    if not ok:
        return -inf, model
    # Keep only the model entries whose observables are present in the data.
    mod = model[star.mask]
    diff = star.data - mod
    chi2 = dot(diff, dot(star.icov, diff))
    return -0.5 * (chi2 + star.log_det_cov + star.dim * log2pi), model
if __name__ == '__main__':

    def _print_posterior_summary(sampler):
        """Print the mean and standard deviation of each fitted parameter.

        Columns of sampler.flatchain are [age, mass, [Fe/H]], in the order
        of the parameter vector built in star.run_emcee.
        """
        chain = sampler.flatchain
        labels = (('mean age = ', 'std age = '),
                  ('mean mass= ', 'std mass= '),
                  ('mean Fe/H= ', 'std Fe/H= '))
        for column, (mean_label, std_label) in enumerate(labels):
            print(mean_label, mean(chain[:, column]))
            print(std_label, std(chain[:, column]))

    iso_dir = '/home/dotter/science/mcmc/iso'
    iso_list = ['fehm25afep6.UBVRIJHKsKp',
                'fehm20afep4.UBVRIJHKsKp',
                'fehm15afep4.UBVRIJHKsKp',
                'fehm10afep4.UBVRIJHKsKp',
                'fehm05afep2.UBVRIJHKsKp',
                'fehp00afep0.UBVRIJHKsKp',
                'fehp02afep0.UBVRIJHKsKp',
                'fehp03afep0.UBVRIJHKsKp',
                'fehp05afep0.UBVRIJHKsKp']
    # lnProb looks the isochrone list up through the module-level name y,
    # so it must keep this name.
    y = [DSED_Isochrones(iso_dir + '/' + iso) for iso in iso_list]

    if False:
        print("Example: Sun")
        x = star()
        x.Teff = 5777; x.sigma_Teff = 3.
        x.logg = 4.4; x.sigma_logg = 0.03
        x.FeH = 0.0; x.sigma_FeH = 0.01
        x.Kmag = 3.302; x.sigma_Kmag = 0.005
        x.delta_nu = 135.1; x.sigma_delta_nu = 0.1
        x.nu_max = 3090.0; x.sigma_nu_max = 30
        sigmas = [x.sigma_Teff, x.sigma_logg, x.sigma_FeH,
                  x.sigma_Kmag, x.sigma_delta_nu, x.sigma_nu_max]
        covariance = zeros((6, 6))
        for i, sigma in enumerate(sigmas):
            covariance[i, i] = pow(sigma, 2)
        # add off-diagonal terms as needed...
        x.set_covariance_matrix(covariance)
        x.pack()
        x.run_emcee(nwalkers=100, nrun=250)
        C = cov(x.sampler.flatchain.T)
        print(C)
        _print_posterior_summary(x.sampler)

    print("\nExample: GALAH+Cannon+covariance")
    w = star()
    w.Teff = 6224.146
    w.logg = 3.78807
    w.FeH = -0.7095807
    # Not named `cov`, so numpy's cov() (used above) is not shadowed.
    cov_full = array([[ 1.37034154e+01, 2.54319931e-03, 7.33258078e-03],
                      [ 2.54319931e-03, 6.03449909e-05, -5.09215250e-07],
                      [ 7.33258078e-03, -5.09215250e-07, 1.39315419e-05]])
    w.pack()
    w.set_covariance_matrix(cov_full)
    w.run_emcee(nwalkers=100, nrun=250)
    _print_posterior_summary(w.sampler)

    print("\nExample: GALAH+Cannon+ no covariance")
    z = star()
    z.Teff = 6224.146
    z.logg = 3.78807
    z.FeH = -0.7095807
    cov_diag = array([[ 1.37034154e+01, 0.0, 0.0],
                      [ 0.0, 6.03449909e-05, 0.0],
                      [ 0.0, 0.0, 1.39315419e-05]])
    z.pack()
    z.set_covariance_matrix(cov_diag)
    z.run_emcee(nwalkers=100, nrun=250)
    _print_posterior_summary(z.sampler)

    if False:
        from multiprocessing import Process
        # read data from files or whatever into a list of stars() called star_data
        # but don't be greedy about memory!
        # NOTE(review): do_run_emcee and star_data are not defined in this
        # file; they must be provided before enabling this branch.
        max_threads = 4
        processes = []
        for thread in range(max_threads):
            p = Process(target=do_run_emcee, args=(star_data, thread, max_threads))
            p.start()
            processes.append(p)
        # Join only after every worker has started, so they run concurrently.
        for p in processes:
            p.join()