swh:1:snp:c614e79db7a4ccd4655b1a5633b8240eb96afd07
Raw File
Tip revision: b14b4002604ed97a23c0bb94c3f77ffdfe29972f authored by hougang liu on 14 January 2019, 16:39:51 UTC
katib should be able to be deployed in any namespace
Tip revision: b14b400
medianstopping.go
package earlystopping

import (
	"context"
	"errors"
	"log"
	"sort"
	"strconv"

	"github.com/kubeflow/katib/pkg/api"
	vdb "github.com/kubeflow/katib/pkg/db"
)

const (
	defaultLeastStep = 20
	defaultMargin    = 0
	defaultBurnIn    = 0
)

type MedianStoppingParam struct {
	LeastStep  int
	Margin     float64
	EvalMetric string
	BurnIn     int
}

type MedianStoppingRule struct {
	dbIf vdb.VizierDBInterface
}

func NewMedianStoppingRule() *MedianStoppingRule {
	var err error
	m := &MedianStoppingRule{}
	m.dbIf, err = vdb.New()
	if err != nil {
		log.Fatalf("Failed to open db connection: %v", err)
	}
	return m
}

func (m *MedianStoppingRule) parseEarlyStoppingParameters(sc *api.StudyConfig, eps []*api.EarlyStoppingParameter) (*MedianStoppingParam, error) {
	p := &MedianStoppingParam{LeastStep: defaultLeastStep, Margin: defaultMargin, EvalMetric: sc.ObjectiveValueName, BurnIn: defaultBurnIn}
	for _, ep := range eps {
		switch ep.Name {
		case "LeastStep":
			l, err := strconv.Atoi(ep.Value)
			if err != nil {
				log.Printf("Fail to puerse parameter %s : %s", ep.Name, ep.Value)
			} else {
				p.LeastStep = l
			}
		case "EvalMargin":
			mar, err := strconv.ParseFloat(ep.Value, 64)
			if err != nil {
				log.Printf("Fail to puerse parameter %s : %s", ep.Name, ep.Value)
			} else {
				p.Margin = mar
			}
		case "EvalMetrics":
			p.EvalMetric = ep.Value
		case "BurnInPeriod":
			b, err := strconv.Atoi(ep.Value)
			if err != nil {
				log.Printf("Fail to puerse parameter %s : %s", ep.Name, ep.Value)
			} else {
				p.BurnIn = b
			}
		default:
			log.Printf("Unknown EarlyStopping Parameter %v", ep.Name)
		}
	}
	log.Printf("Parameter: LeastStep %d, Margin %v, EvalMetric %s, BurnInPeriod %d", p.LeastStep, p.Margin, p.EvalMetric, p.BurnIn)
	return p, nil
}

func (m *MedianStoppingRule) getMedianRunningAverage(completedWorkerslogs [][]*vdb.WorkerLog, step int, burnin int) float64 {
	r := []float64{}
	var ra float64
	for _, cwl := range completedWorkerslogs {
		ra = 0
		var st int
		var errParce bool = false
		if step > len(cwl) {
			st = len(cwl)
		} else {
			st = step
		}
		for s := burnin; s < st; s++ {
			v, err := strconv.ParseFloat(cwl[s].Value, 64)
			if err != nil {
				log.Printf("Fail to Parse %s : %s", cwl[s].Name, cwl[s].Value)
				errParce = true
				break
			}
			ra += v
		}
		if errParce {
			continue
		}
		ra = ra / float64(st)
		r = append(r, ra)
	}
	if len(r) == 0 {
		return 0
	} else {
		sort.Float64s(r)
		return r[len(r)/2]
	}
}

func (m *MedianStoppingRule) getBestValue(sid string, sc *api.StudyConfig, logs []*vdb.WorkerLog) (float64, error) {
	if len(logs) == 0 {
		return 0, errors.New("Evaluation Log is missing")
	}
	ot := sc.OptimizationType
	if ot != api.OptimizationType_MAXIMIZE && ot != api.OptimizationType_MINIMIZE {
		return 0, errors.New("OptimizationType Unknown.")
	}
	var ret float64
	var targetObjlog []float64
	for _, l := range logs {
		v, err := strconv.ParseFloat(l.Value, 64)
		if err != nil {
			log.Printf("Fail to Parse %s : %s", l.Name, l.Value)
			continue
		}
		targetObjlog = append(targetObjlog, v)
	}
	if len(targetObjlog) == 0 {
		return 0, errors.New("No Objective value log in Logs")
	}
	sort.Float64s(targetObjlog)
	if ot == api.OptimizationType_MAXIMIZE {
		ret = targetObjlog[len(targetObjlog)-1]
	} else if ot == api.OptimizationType_MINIMIZE {
		ret = targetObjlog[0]
	}
	return ret, nil
}
func (m *MedianStoppingRule) GetShouldStopWorkers(ctx context.Context, in *api.GetShouldStopWorkersRequest) (*api.GetShouldStopWorkersReply, error) {
	wl, err := m.dbIf.GetWorkerList(in.StudyId, "")
	if err != nil {
		return &api.GetShouldStopWorkersReply{}, err
	}
	sc, err := m.dbIf.GetStudyConfig(in.StudyId)
	if err != nil {
		return &api.GetShouldStopWorkersReply{}, err
	}
	eparam, err := m.dbIf.GetEarlyStopParam(in.ParamId)
	if err != nil {
		return &api.GetShouldStopWorkersReply{}, err
	}
	p, err := m.parseEarlyStoppingParameters(sc, eparam)
	if err != nil {
		return &api.GetShouldStopWorkersReply{}, err
	}

	rwids := []string{}
	cwl := make([][]*vdb.WorkerLog, 0, len(wl))
	sW := []string{}
	for _, w := range wl {
		switch w.Status {
		case api.State_RUNNING:
			rwids = append(rwids, w.WorkerId)
		case api.State_COMPLETED:
			wl, err := m.dbIf.GetWorkerLogs(w.WorkerId, &vdb.GetWorkerLogOpts{Name: p.EvalMetric})
			if err != nil {
				log.Printf("Fail to get worker %v logs", w.WorkerId)
				continue
			}
			if len(wl) > p.BurnIn {
				cwl = append(cwl, wl)
			}
		default:
		}
	}
	if len(cwl) == 0 {
		return &api.GetShouldStopWorkersReply{}, err
	}
	for _, w := range rwids {
		wl, err := m.dbIf.GetWorkerLogs(w, &vdb.GetWorkerLogOpts{Name: p.EvalMetric})
		if err != nil {
			log.Printf("Fail to get worker %v logs", w)
			continue
		}
		if len(wl) < p.LeastStep || len(wl) <= p.BurnIn {
			continue
		}
		v, err := m.getBestValue(in.StudyId, sc, wl)
		if err != nil {
			log.Printf("Fail to Get Best Value at %s: %v Log:%v", w, err, wl)
			continue
		}
		om := m.getMedianRunningAverage(cwl, len(wl), p.BurnIn)
		log.Printf("Worker %s, In step %d Current value: %v Median value: %v\n", w, len(wl), v, om)
		if (v < (om-p.Margin) && sc.OptimizationType == api.OptimizationType_MAXIMIZE) || v > (om+p.Margin) && sc.OptimizationType == api.OptimizationType_MINIMIZE {
			log.Printf("Worker %s shuold be stopped", w)
			sW = append(sW, w)
		}
	}
	return &api.GetShouldStopWorkersReply{ShouldStopWorkerIds: sW}, nil
}
back to top