swh:1:snp:c614e79db7a4ccd4655b1a5633b8240eb96afd07
Tip revision: 4618aaf358b3c6cb83dc1647bc8ee1b44bd75fc9 authored by YujiOshima on 17 September 2018, 08:08:50 UTC
update only when the instance was changed
update only when the instance was changed
Tip revision: 4618aaf
interface.go
package db
import (
crand "crypto/rand"
"database/sql"
"errors"
"fmt"
"github.com/golang/protobuf/jsonpb"
"log"
"math/big"
"math/rand"
"strings"
"time"
api "github.com/kubeflow/katib/pkg/api"
_ "github.com/go-sql-driver/mysql"
)
const (
db_driver = "mysql"
db_name = "root:test@tcp(vizier-db:3306)/vizier"
mysql_time_fmt = "2006-01-02 15:04:05.999999"
)
type GetWorkerLogOpts struct {
Name string
SinceTime *time.Time
Descending bool
Limit int32
Objective bool
}
type WorkerLog struct {
Time time.Time
Name string
Value string
}
type VizierDBInterface interface {
DB_Init()
GetStudyConfig(string) (*api.StudyConfig, error)
GetStudyList() ([]string, error)
CreateStudy(*api.StudyConfig) (string, error)
DeleteStudy(string) error
GetTrial(string) (*api.Trial, error)
GetTrialList(string) ([]*api.Trial, error)
CreateTrial(*api.Trial) error
DeleteTrial(string) error
GetWorker(string) (*api.Worker, error)
GetWorkerStatus(string) (*api.State, error)
GetWorkerList(string, string) ([]*api.Worker, error)
GetWorkerLogs(string, *GetWorkerLogOpts) ([]*WorkerLog, error)
GetWorkerTimestamp(string) (*time.Time, error)
StoreWorkerLogs(string, []*api.MetricsLog) error
CreateWorker(*api.Worker) (string, error)
UpdateWorker(string, api.State) error
DeleteWorker(string) error
SetSuggestionParam(string, string, []*api.SuggestionParameter) (string, error)
UpdateSuggestionParam(string, []*api.SuggestionParameter) error
GetSuggestionParam(string) ([]*api.SuggestionParameter, error)
GetSuggestionParamList(string) ([]*api.SuggestionParameterSet, error)
SetEarlyStopParam(string, string, []*api.EarlyStoppingParameter) (string, error)
UpdateEarlyStopParam(string, []*api.EarlyStoppingParameter) error
GetEarlyStopParam(string) ([]*api.EarlyStoppingParameter, error)
GetEarlyStopParamList(string) ([]*api.EarlyStoppingParameterSet, error)
}
type db_conn struct {
db *sql.DB
}
var rs1Letters = []rune("abcdefghijklmnopqrstuvwxyz")
func NewWithSqlConn(db *sql.DB) VizierDBInterface {
d := new(db_conn)
d.db = db
seed, err := crand.Int(crand.Reader, big.NewInt(1<<63-1))
if err != nil {
log.Fatalf("RNG initialization failed: %v", err)
}
// We can do the following instead, but it creates a locking issue
//d.rng = rand.New(rand.NewSource(seed.Int64()))
rand.Seed(seed.Int64())
return d
}
func New() VizierDBInterface {
db, err := sql.Open(db_driver, db_name)
if err != nil {
log.Fatalf("DB open failed: %v", err)
}
return NewWithSqlConn(db)
}
func generate_randid() string {
// UUID isn't quite handy in the Go world
id_ := make([]byte, 8)
_, err := rand.Read(id_)
if err != nil {
log.Printf("Error reading random: %v", err)
return ""
}
return string(rs1Letters[rand.Intn(len(rs1Letters))]) + fmt.Sprintf("%016x", id_)[1:]
}
func (d *db_conn) GetStudyConfig(id string) (*api.StudyConfig, error) {
row := d.db.QueryRow("SELECT * FROM studies WHERE id = ?", id)
study := new(api.StudyConfig)
var dummy_id, configs, tags, metrics string
err := row.Scan(&dummy_id,
&study.Name,
&study.Owner,
&study.OptimizationType,
&study.OptimizationGoal,
&configs,
&tags,
&study.ObjectiveValueName,
&metrics,
&study.JobId,
)
if err != nil {
return nil, err
}
study.ParameterConfigs = new(api.StudyConfig_ParameterConfigs)
err = jsonpb.UnmarshalString(configs, study.ParameterConfigs)
if err != nil {
return nil, err
}
var tags_array []string
if len(tags) > 0 {
tags_array = strings.Split(tags, ",\n")
}
study.Tags = make([]*api.Tag, len(tags_array))
for i, j := range tags_array {
tag := new(api.Tag)
err = jsonpb.UnmarshalString(j, tag)
if err != nil {
log.Printf("err unmarshal %s", j)
return nil, err
}
study.Tags[i] = tag
}
study.Metrics = strings.Split(metrics, ",\n")
return study, nil
}
func (d *db_conn) GetStudyList() ([]string, error) {
rows, err := d.db.Query("SELECT id FROM studies")
if err != nil {
return nil, err
}
defer rows.Close()
var result []string
for rows.Next() {
var id string
err = rows.Scan(&id)
if err != nil {
log.Printf("err scanning studies.id: %v", err)
continue
}
result = append(result, id)
}
return result, nil
}
func (d *db_conn) CreateStudy(in *api.StudyConfig) (string, error) {
if in.ParameterConfigs == nil {
return "", errors.New("ParameterConfigs must be set")
}
if in.JobId != "" {
row := d.db.QueryRow("SELECT * FROM studies WHERE job_id = ?", in.JobId)
dummy_study := new(api.StudyConfig)
var dummy_id, dummy_configs, dummy_tags, dummy_metrics, dummy_job_id string
err := row.Scan(&dummy_id,
&dummy_study.Name,
&dummy_study.Owner,
&dummy_study.OptimizationType,
&dummy_study.OptimizationGoal,
&dummy_configs,
&dummy_tags,
&dummy_study.ObjectiveValueName,
&dummy_metrics,
&dummy_job_id,
)
if err == nil {
return "", fmt.Errorf("Study %s in Job %s already exist.", in.Name, in.JobId)
}
}
configs, err := (&jsonpb.Marshaler{}).MarshalToString(in.ParameterConfigs)
if err != nil {
log.Fatalf("Error marshaling configs: %v", err)
}
tags := make([]string, len(in.Tags))
for i, elem := range in.Tags {
tags[i], err = (&jsonpb.Marshaler{}).MarshalToString(elem)
if err != nil {
log.Printf("Error marshalling %v: %v", elem, err)
continue
}
}
var isin bool = false
for _, m := range in.Metrics {
if m == in.ObjectiveValueName {
isin = true
}
}
if !isin {
in.Metrics = append(in.Metrics, in.ObjectiveValueName)
}
var study_id string
i := 3
for true {
study_id = generate_randid()
_, err := d.db.Exec(
"INSERT INTO studies VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
study_id,
in.Name,
in.Owner,
in.OptimizationType,
in.OptimizationGoal,
configs,
strings.Join(tags, ",\n"),
in.ObjectiveValueName,
strings.Join(in.Metrics, ",\n"),
in.JobId,
)
if err == nil {
break
} else {
errmsg := strings.ToLower(err.Error())
if strings.Contains(errmsg, "unique") || strings.Contains(errmsg, "duplicate") {
i--
if i > 0 {
continue
}
}
}
return "", err
}
for _, perm := range in.AccessPermissions {
_, err := d.db.Exec(
"INSERT INTO study_permissions (study_id, access_permission) "+
"VALUES (?, ?)",
study_id, perm)
if err != nil {
log.Printf("Error storing permission (%s, %s): %v",
study_id, perm, err)
}
}
return study_id, nil
}
func (d *db_conn) DeleteStudy(id string) error {
_, err := d.db.Exec("DELETE FROM studies WHERE id = ?", id)
return err
}
func (d *db_conn) getTrials(trial_id string, study_id string) ([]*api.Trial, error) {
var rows *sql.Rows
var err error
if trial_id != "" {
rows, err = d.db.Query("SELECT * FROM trials WHERE id = ?", trial_id)
} else if study_id != "" {
rows, err = d.db.Query("SELECT * FROM trials WHERE study_id = ?", study_id)
} else {
return nil, errors.New("trial_id or study_id must be set")
}
if err != nil {
return nil, err
}
var result []*api.Trial
for rows.Next() {
trial := new(api.Trial)
var parameters, tags string
err := rows.Scan(&trial.TrialId,
&trial.StudyId,
¶meters,
&trial.ObjectiveValue,
&tags,
)
if err != nil {
return nil, err
}
params := strings.Split(parameters, ",\n")
p := make([]*api.Parameter, len(params))
for i, pstr := range params {
if pstr == "" {
continue
}
p[i] = &api.Parameter{}
err := jsonpb.UnmarshalString(pstr, p[i])
if err != nil {
return nil, err
}
}
trial.ParameterSet = p
taglist := strings.Split(tags, ",\n")
t := make([]*api.Tag, len(taglist))
for i, tstr := range taglist {
t[i] = &api.Tag{}
if tstr == "" {
continue
}
err := jsonpb.UnmarshalString(tstr, t[i])
if err != nil {
return nil, err
}
}
trial.Tags = t
result = append(result, trial)
}
return result, nil
}
func (d *db_conn) GetTrial(id string) (*api.Trial, error) {
trials, err := d.getTrials(id, "")
if err != nil {
return nil, err
}
if len(trials) > 1 {
return nil, errors.New("multiple trials found")
} else if len(trials) == 0 {
return nil, errors.New("trials not found")
}
return trials[0], nil
}
func (d *db_conn) GetTrialList(id string) ([]*api.Trial, error) {
trials, err := d.getTrials("", id)
return trials, err
}
func (d *db_conn) CreateTrial(trial *api.Trial) error {
// This function sets trial.id, unlike old dbInsertTrials().
// Users should not overwrite trial.id
var err, lastErr error
params := make([]string, len(trial.ParameterSet))
for i, elem := range trial.ParameterSet {
params[i], err = (&jsonpb.Marshaler{}).MarshalToString(elem)
if err != nil {
log.Printf("Error marshalling trial.ParameterSet %v: %v",
elem, err)
lastErr = err
}
}
tags := make([]string, len(trial.Tags))
for i := range tags {
tags[i], err = (&jsonpb.Marshaler{}).MarshalToString(trial.Tags[i])
if err != nil {
log.Printf("Error marshalling trial.Tags %v: %v",
trial.Tags[i], err)
lastErr = err
}
}
var trial_id string
i := 3
for true {
trial_id = generate_randid()
_, err = d.db.Exec("INSERT INTO trials VALUES (?, ?, ?, ?, ?)",
trial_id, trial.StudyId, strings.Join(params, ",\n"),
trial.ObjectiveValue, strings.Join(tags, ",\n"))
if err == nil {
trial.TrialId = trial_id
break
} else {
errmsg := strings.ToLower(err.Error())
if strings.Contains(errmsg, "unique") || strings.Contains(errmsg, "duplicate") {
i--
if i > 0 {
continue
}
}
}
return err
}
return lastErr
}
func (d *db_conn) DeleteTrial(id string) error {
_, err := d.db.Exec("DELETE FROM trials WHERE id = ?", id)
return err
}
func (d *db_conn) GetWorkerLogs(id string, opts *GetWorkerLogOpts) ([]*WorkerLog, error) {
qstr := ""
qfield := []interface{}{id}
order := ""
if opts != nil {
if opts.SinceTime != nil {
qstr += " AND time >= ?"
qfield = append(qfield, opts.SinceTime)
}
if opts.Name != "" {
qstr += " AND name = ?"
qfield = append(qfield, opts.Name)
}
if opts.Objective {
qstr += " AND is_objective = 1"
}
if opts.Descending {
order = " DESC"
}
if opts.Limit > 0 {
order += fmt.Sprintf(" LIMIT %d", opts.Limit)
}
}
rows, err := d.db.Query("SELECT time, name, value FROM worker_metrics WHERE worker_id = ?"+
qstr+" ORDER BY time"+order, qfield...)
if err != nil {
return nil, err
}
var result []*WorkerLog
for rows.Next() {
log1 := new(WorkerLog)
var time_str string
err := rows.Scan(&time_str, &((*log1).Name), &((*log1).Value))
if err != nil {
log.Printf("Error scanning log: %v", err)
continue
}
log1.Time, err = time.Parse(mysql_time_fmt, time_str)
if err != nil {
log.Printf("Error parsing time %s: %v", time_str, err)
continue
}
result = append(result, log1)
}
return result, nil
}
func (d *db_conn) getWorkerLastlog(id string, value *string) (*time.Time, error) {
var last_timestamp string
var err error
if value != nil {
row := d.db.QueryRow("SELECT time, value FROM worker_lastlogs WHERE worker_id = ?", id)
err = row.Scan(&last_timestamp, value)
} else {
row := d.db.QueryRow("SELECT time FROM worker_lastlogs WHERE worker_id = ?", id)
err = row.Scan(&last_timestamp)
}
switch {
case err == sql.ErrNoRows:
return nil, nil
case err != nil:
return nil, err
default:
mt, err := time.Parse(mysql_time_fmt, last_timestamp)
if err != nil {
log.Printf("Error parsing time in log %s: %v",
last_timestamp, err)
return nil, err
}
return &mt, nil
}
}
func (d *db_conn) GetWorkerTimestamp(id string) (*time.Time, error) {
return d.getWorkerLastlog(id, nil)
}
func (d *db_conn) storeWorkerLog(worker_id string, time string, metrics_name string, metrics_value string, objective_value_name string) error {
is_objective := 0
if metrics_name == objective_value_name {
is_objective = 1
}
_, err := d.db.Exec("INSERT INTO worker_metrics (worker_id, time, name, value, is_objective) VALUES (?, ?, ?, ?, ?)",
worker_id, time, metrics_name, metrics_value, is_objective)
if err != nil {
return err
}
return nil
}
func (d *db_conn) StoreWorkerLogs(worker_id string, logs []*api.MetricsLog) error {
var lasterr error
var last_value string
db_t, err := d.getWorkerLastlog(worker_id, &last_value)
if err != nil {
log.Printf("Error getting last log timestamp: %v", err)
}
row := d.db.QueryRow("SELECT objective_value_name FROM workers "+
"JOIN (studies) ON (workers.study_id = studies.id) WHERE "+
"workers.id = ?", worker_id)
var objective_value_name string
err = row.Scan(&objective_value_name)
if err != nil {
log.Printf("Cannot get objective_value_name or metrics: %v", err)
return err
}
var formatted_time string
var ls []string
for _, mlog := range logs {
metrics_name := mlog.Name
for _, mv := range mlog.Values {
t, err := time.Parse(time.RFC3339Nano, mv.Time)
if err != nil {
log.Printf("Error parsing time %s: %v", mv.Time, err)
lasterr = err
continue
}
if db_t != nil && !t.After(*db_t) {
// db_t is from mysql and has microsec precision.
// This code assumes nanosec fractions are rounded down.
continue
}
// use UTC as mysql DATETIME lacks timezone
formatted_time = t.UTC().Format(mysql_time_fmt)
if db_t != nil {
// Parse again to get rounding effect
//reparsed_time, err := time.Parse(mysql_time_fmt, formatted_time)
//if reparsed_time == *db_t {
// if mv.Value == last_value {
// stored_logs are already in DB
// This assignment ensures the remaining
// logs will be stored in DB.
// db_t = nil
// continue
// }
// // We don't know this is necessary or not yet.
// stored_logs = append(stored_logs, &mv.Value)
// continue
//}
// (reparsed_time > *db_t) can be assumed
err = d.storeWorkerLog(worker_id,
db_t.UTC().Format(mysql_time_fmt),
metrics_name, mv.Value,
objective_value_name)
if err != nil {
log.Printf("Error storing log %s: %v", mv.Value, err)
lasterr = err
}
db_t = nil
} else {
err = d.storeWorkerLog(worker_id,
formatted_time,
metrics_name, mv.Value,
objective_value_name)
if err != nil {
log.Printf("Error storing log %s: %v", mv.Value, err)
lasterr = err
}
}
}
}
if lasterr != nil {
// If lastlog were updated, logs that couldn't be saved
// would be lost.
return lasterr
}
if len(ls) == 2 {
_, err = d.db.Exec("REPLACE INTO worker_lastlogs VALUES (?, ?, ?)",
worker_id, formatted_time, ls[1])
}
return err
}
func (d *db_conn) getWorkers(worker_id string, trial_id string, study_id string) ([]*api.Worker, error) {
var rows *sql.Rows
var err error
if worker_id != "" {
rows, err = d.db.Query("SELECT * FROM workers WHERE id = ?", worker_id)
} else if trial_id != "" {
rows, err = d.db.Query("SELECT * FROM workers WHERE trial_id = ?", trial_id)
} else if study_id != "" {
rows, err = d.db.Query("SELECT * FROM workers WHERE study_id = ?", study_id)
} else {
return nil, errors.New("worker_id, trial_id or study_id must be set")
}
if err != nil {
return nil, err
}
var result []*api.Worker
for rows.Next() {
worker := new(api.Worker)
var tags string
err := rows.Scan(
&worker.WorkerId,
&worker.StudyId,
&worker.TrialId,
&worker.Type,
&worker.Status,
&worker.TemplatePath,
&tags,
)
if err != nil {
return nil, err
}
taglist := strings.Split(tags, ",\n")
t := make([]*api.Tag, len(taglist))
for i, tstr := range taglist {
t[i] = &api.Tag{}
if tstr == "" {
continue
}
err := jsonpb.UnmarshalString(tstr, t[i])
if err != nil {
return nil, err
}
}
worker.Tags = t
result = append(result, worker)
}
return result, nil
}
func (d *db_conn) GetWorker(id string) (*api.Worker, error) {
workers, err := d.getWorkers(id, "", "")
if err != nil {
return nil, err
}
if len(workers) > 1 {
return nil, errors.New("multiple workers found")
} else if len(workers) == 0 {
return nil, errors.New("worker not found")
}
return workers[0], nil
}
func (d *db_conn) GetWorkerStatus(id string) (*api.State, error) {
status := api.State_ERROR
row := d.db.QueryRow("SELECT status FROM workers WHERE id = ?", id)
err := row.Scan(&status)
if err != nil {
return &status, err
}
return &status, nil
}
func (d *db_conn) GetWorkerList(sid string, tid string) ([]*api.Worker, error) {
workers, err := d.getWorkers("", tid, sid)
return workers, err
}
func (d *db_conn) CreateWorker(worker *api.Worker) (string, error) {
// Users should not overwrite worker.id
var err, lastErr error
tags := make([]string, len(worker.Tags))
for i := range tags {
tags[i], err = (&jsonpb.Marshaler{}).MarshalToString(worker.Tags[i])
if err != nil {
log.Printf("Error marshalling worker.Tags %v: %v",
worker.Tags[i], err)
lastErr = err
}
}
var worker_id string
i := 3
for true {
worker_id = generate_randid()
_, err = d.db.Exec("INSERT INTO workers VALUES (?, ?, ?, ?, ?, ?, ?)",
worker_id, worker.StudyId, worker.TrialId, worker.Type,
api.State_PENDING, worker.TemplatePath, strings.Join(tags, ",\n"))
if err == nil {
worker.WorkerId = worker_id
break
} else {
errmsg := strings.ToLower(err.Error())
if strings.Contains(errmsg, "unique") || strings.Contains(errmsg, "duplicate") {
i--
if i > 0 {
continue
}
}
}
return "", err
}
return worker.WorkerId, lastErr
}
func (d *db_conn) UpdateWorker(id string, newstatus api.State) error {
_, err := d.db.Exec("UPDATE workers SET status = ? WHERE id = ?", newstatus, id)
return err
}
func (d *db_conn) DeleteWorker(id string) error {
_, err := d.db.Exec("DELETE FROM workers WHERE id = ?", id)
return err
}
func (d *db_conn) SetSuggestionParam(algorithm string, studyId string, params []*api.SuggestionParameter) (string, error) {
var err error
ps := make([]string, len(params))
for i, elem := range params {
ps[i], err = (&jsonpb.Marshaler{}).MarshalToString(elem)
if err != nil {
log.Printf("Error marshalling %v: %v", elem, err)
return "", err
}
}
var paramId string
for true {
paramId = generate_randid()
_, err = d.db.Exec("INSERT INTO suggestion_param VALUES (?, ?, ?, ?)",
paramId, algorithm, studyId, strings.Join(ps, ",\n"))
if err == nil {
break
}
}
return paramId, err
}
func (d *db_conn) UpdateSuggestionParam(paramId string, params []*api.SuggestionParameter) error {
var err error
ps := make([]string, len(params))
for i, elem := range params {
ps[i], err = (&jsonpb.Marshaler{}).MarshalToString(elem)
if err != nil {
log.Printf("Error marshalling %v: %v", elem, err)
return err
}
}
_, err = d.db.Exec("UPDATE suggestion_param SET parameters = ? WHERE id = ?",
strings.Join(ps, ",\n"), paramId)
return err
}
func (d *db_conn) GetSuggestionParam(paramId string) ([]*api.SuggestionParameter, error) {
var params string
row := d.db.QueryRow("SELECT parameters FROM suggestion_param WHERE id = ?", paramId)
err := row.Scan(¶ms)
if err != nil {
return nil, err
}
var p_array []string
if len(params) > 0 {
p_array = strings.Split(params, ",\n")
} else {
return nil, nil
}
ret := make([]*api.SuggestionParameter, len(p_array))
for i, j := range p_array {
p := new(api.SuggestionParameter)
err = jsonpb.UnmarshalString(j, p)
if err != nil {
log.Printf("err unmarshal %s", j)
return nil, err
}
ret[i] = p
}
return ret, nil
}
func (d *db_conn) GetSuggestionParamList(studyId string) ([]*api.SuggestionParameterSet, error) {
var rows *sql.Rows
var err error
rows, err = d.db.Query("SELECT * FROM suggestion_param WHERE study_id = ?", studyId)
if err != nil {
return nil, err
}
var result []*api.SuggestionParameterSet
for rows.Next() {
var id string
var algorithm string
var params string
var sID string
err := rows.Scan(&id, &sID, &algorithm, ¶ms)
if err != nil {
return nil, err
}
if studyId != sID {
continue
}
var p_array []string
if len(params) > 0 {
p_array = strings.Split(params, ",\n")
} else {
return nil, nil
}
suggestparams := make([]*api.SuggestionParameter, len(p_array))
for i, j := range p_array {
p := new(api.SuggestionParameter)
err = jsonpb.UnmarshalString(j, p)
if err != nil {
log.Printf("err unmarshal %s", j)
return nil, err
}
suggestparams[i] = p
}
result = append(result, &api.SuggestionParameterSet{
ParamId: id,
SuggestionAlgorithm: algorithm,
SuggestionParameters: suggestparams,
})
}
return result, nil
}
func (d *db_conn) SetEarlyStopParam(algorithm string, studyId string, params []*api.EarlyStoppingParameter) (string, error) {
ps := make([]string, len(params))
var err error
for i, elem := range params {
ps[i], err = (&jsonpb.Marshaler{}).MarshalToString(elem)
if err != nil {
log.Printf("Error marshalling %v: %v", elem, err)
return "", err
}
}
var paramId string
for true {
paramId := generate_randid()
_, err = d.db.Exec("INSERT INTO earlystopping_param VALUES (?,?, ?, ?)",
paramId, algorithm, studyId, strings.Join(ps, ",\n"))
if err == nil {
break
}
}
return paramId, nil
}
func (d *db_conn) UpdateEarlyStopParam(paramId string, params []*api.EarlyStoppingParameter) error {
ps := make([]string, len(params))
var err error
for i, elem := range params {
ps[i], err = (&jsonpb.Marshaler{}).MarshalToString(elem)
if err != nil {
log.Printf("Error marshalling %v: %v", elem, err)
return err
}
}
_, err = d.db.Exec("UPDATE earlystopping_param SET parameters = ? WHERE id = ?",
strings.Join(ps, ",\n"), paramId)
return err
}
func (d *db_conn) GetEarlyStopParam(paramId string) ([]*api.EarlyStoppingParameter, error) {
var params string
row := d.db.QueryRow("SELECT parameters FROM earlystopping_param WHERE id = ?", paramId)
err := row.Scan(¶ms)
if err != nil {
return nil, err
}
var p_array []string
if len(params) > 0 {
p_array = strings.Split(params, ",\n")
} else {
return nil, nil
}
ret := make([]*api.EarlyStoppingParameter, len(p_array))
for i, j := range p_array {
p := new(api.EarlyStoppingParameter)
err = jsonpb.UnmarshalString(j, p)
if err != nil {
log.Printf("err unmarshal %s", j)
return nil, err
}
ret[i] = p
}
return ret, nil
}
func (d *db_conn) GetEarlyStopParamList(studyId string) ([]*api.EarlyStoppingParameterSet, error) {
var rows *sql.Rows
var err error
rows, err = d.db.Query("SELECT * FROM earlystopping_param WHERE study_id = ?", studyId)
if err != nil {
return nil, err
}
var result []*api.EarlyStoppingParameterSet
for rows.Next() {
var id string
var algorithm string
var params string
var sID string
err := rows.Scan(&id, &sID, &algorithm, ¶ms)
if err != nil {
return nil, err
}
if studyId != sID {
continue
}
var p_array []string
if len(params) > 0 {
p_array = strings.Split(params, ",\n")
} else {
return nil, nil
}
esparams := make([]*api.EarlyStoppingParameter, len(p_array))
for i, j := range p_array {
p := new(api.EarlyStoppingParameter)
err = jsonpb.UnmarshalString(j, p)
if err != nil {
log.Printf("err unmarshal %s", j)
return nil, err
}
esparams[i] = p
}
result = append(result, &api.EarlyStoppingParameterSet{
ParamId: id,
EarlyStoppingAlgorithm: algorithm,
EarlyStoppingParameters: esparams,
})
}
return result, nil
}