https://github.com/google/cayley
Raw File
Tip revision: 15ab599d6aa8215cbbcd4699fe9012230eb3069c authored by Barak Michener on 23 March 2016, 20:15:00 UTC
Merge pull request #386 from barakmich/optional_path
Tip revision: 15ab599
sql_link_iterator.go
// Copyright 2015 The Cayley Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
	"fmt"
	"strings"
	"sync/atomic"

	"github.com/barakmich/glog"
	"github.com/google/cayley/graph"
	"github.com/google/cayley/quad"
)

// sqlTableID counts the table aliases handed out by newTableName. Access it
// only through sync/atomic. Its zero value is the correct starting point, so
// it needs no explicit initialization.
var sqlTableID uint64

// newTableName returns a fresh table alias of the form "t_<n>". n is 1 on
// the first call and increases by one on every subsequent call. The counter
// is advanced atomically, so concurrent callers never receive the same alias.
func newTableName() string {
	id := atomic.AddUint64(&sqlTableID, 1)
	return fmt.Sprintf("t_%d", id)
}

// constraint restricts one position of a quad to a set of node values.
// buildWhere renders it as an equality test (one value) or an IN list
// (several values) on that position's _hash column. quickContains compares
// the raw values directly.
type constraint struct {
	// dir is the quad position being constrained.
	dir quad.Direction

	// vals are the accepted node values, stored unhashed.
	vals []string
}

// tagDir describes one tagged value to select: the tag name, the table alias
// it is read from, and which column of that table supplies it. String renders
// it as a pair of SELECT-list expressions (value and hash).
type tagDir struct {
	// tag is the output column name.
	tag string

	// dir names the quad column to read. quad.Any means the value comes from
	// a column named after the tag itself, or from __execd when justLocal is
	// set.
	dir quad.Direction

	// table is the alias of the table the columns are read from.
	table string

	// justLocal makes a quad.Any tag read from the __execd/__execd_hash
	// columns instead of columns named after the tag.
	justLocal bool
}

// String renders t as two SELECT-list expressions: the tagged value aliased
// to the tag name, then its hash aliased to "<tag>_hash".
//
// NOTE(review): the value alias is double-quoted but the _hash alias is not,
// and tag text is not escaped. Tags containing quotes, or with mixed case on
// case-folding databases, may not round-trip. Confirm against callers before
// changing.
func (t tagDir) String() string {
	switch {
	case t.dir != quad.Any:
		// A fixed quad column, named by dir's string form.
		return fmt.Sprintf("%s.%s as \"%s\", %s.%s_hash as %s_hash", t.table, t.dir, t.tag, t.table, t.dir, t.tag)
	case t.justLocal:
		// The table's own __execd result columns.
		return fmt.Sprintf("%s.__execd as \"%s\", %s.__execd_hash as %s_hash", t.table, t.tag, t.table, t.tag)
	default:
		// Columns the table already exposes under the tag's own name.
		return fmt.Sprintf("%s.\"%s\" as \"%s\", %s.%s_hash as %s_hash", t.table, t.tag, t.tag, t.table, t.tag, t.tag)
	}
}

// tableDef is one entry of a FROM list. buildSQL renders it as
// "<table> as <name>".
type tableDef struct {
	// table is the table expression placed before "as".
	table string

	// name is the alias the expression is bound to.
	name string

	// values are positional arguments contributed by this entry. buildSQL
	// passes them ahead of all WHERE arguments, presumably for placeholders
	// inside table. Confirm with the iterators that set it.
	values []string
}

// sqlItDir attaches a nested SQL iterator to a link iterator. buildWhere
// equates the link table's dir hash column with the hash column identified
// by the nested iterator's tableID.
type sqlItDir struct {
	// dir is the quad position of the link that the nested iterator joins on.
	dir quad.Direction

	// it is the nested iterator.
	it sqlIterator
}

// sqlIterator is implemented by SQL-backed iterators that can be composed
// into a single SELECT statement.
type sqlIterator interface {
	// sqlClone returns an independent copy of the iterator tree.
	sqlClone() sqlIterator

	// buildSQL renders the full statement and its positional arguments.
	// With next set it enumerates results; otherwise it is restricted to val.
	buildSQL(next bool, val graph.Value) (string, []string)

	// getTables lists the FROM entries needed by this iterator and its children.
	getTables() []tableDef

	// getTags lists the tagged columns to add to the SELECT list.
	getTags() []tagDir

	// buildWhere returns AND-joinable WHERE conditions and their arguments.
	buildWhere() (string, []string)

	// tableID identifies the column a parent iterator joins against.
	tableID() tagDir

	// quickContains tries to decide membership without a query. ok reports
	// whether result is a definitive answer.
	quickContains(graph.Value) (ok bool, result bool)

	// buildResult consumes one result row, given its column names, and
	// returns the tag values it carries.
	buildResult(result []string, cols []string) map[string]string

	// sameTopResult reports whether two rows describe the same primary result.
	sameTopResult(target []string, test []string) bool

	Result() graph.Value
	Size(*QuadStore) (int64, bool)
	Describe() string
	Type() sqlQueryType
	Tagger() *graph.Tagger
}

// SQLLinkIterator is a SQL-backed iterator over quads. It reads the quads
// table under its own alias. Value constraints may filter it, and nested
// node iterators may be joined to it. The whole tree is compiled into a
// single SELECT statement by buildSQL.
type SQLLinkIterator struct {
	// tagger holds the tags attached to this iterator.
	tagger graph.Tagger

	// nodeIts are nested iterators joined on a quad position.
	nodeIts []sqlItDir

	// constraints filter the quads by position value.
	constraints []constraint

	// tableName is this iterator's alias for the quads table.
	tableName string

	// size caches a store-computed Size result; 0 means "not cached".
	size int64

	// tagdirs tag individual quad positions of each result.
	tagdirs []tagDir

	// resultQuad is the quad decoded from the most recent result row.
	resultQuad quad.Quad
}

// sqlClone returns a copy of l whose node sub-iterators are cloned
// recursively. The constraint and tag-direction slices are copied element by
// element. The vals slice inside each constraint is still shared with l.
func (l *SQLLinkIterator) sqlClone() sqlIterator {
	constraints := make([]constraint, len(l.constraints))
	copy(constraints, l.constraints)

	tagdirs := make([]tagDir, len(l.tagdirs))
	copy(tagdirs, l.tagdirs)

	var nodeIts []sqlItDir
	for _, sub := range l.nodeIts {
		nodeIts = append(nodeIts, sqlItDir{dir: sub.dir, it: sub.it.sqlClone()})
	}

	clone := &SQLLinkIterator{
		nodeIts:     nodeIts,
		constraints: constraints,
		tableName:   l.tableName,
		size:        l.size,
		tagdirs:     tagdirs,
	}
	clone.tagger.CopyFromTagger(l.Tagger())
	return clone
}

// Tagger returns a pointer to the iterator's own tagger. Tags added through
// it modify l in place.
func (l *SQLLinkIterator) Tagger() *graph.Tagger {
	tagger := &l.tagger
	return tagger
}

// Result returns the quad decoded by the most recent buildResult call, or
// the zero quad if no row has been decoded yet.
func (l *SQLLinkIterator) Result() graph.Value {
	current := l.resultQuad
	return current
}

// Size estimates how many quads this iterator yields. The boolean is true
// only for the cached, store-computed count (presumably "exact", per the
// graph iterator convention).
//
// The estimate is chosen in this order:
//   - If the first constraint has at least one value, qs is asked for the
//     count of quads matching its first value. The result is cached on l.
//   - Otherwise, with more than one node sub-iterator, the first
//     sub-iterator's size times a heuristic factor of 20 is returned.
//   - Otherwise the whole store's size is returned.
//
// Only a non-zero cached size short-circuits, so a zero count is recomputed
// on every call.
func (l *SQLLinkIterator) Size(qs *QuadStore) (int64, bool) {
	if l.size != 0 {
		return l.size, true
	}
	switch {
	case len(l.constraints) > 0 && len(l.constraints[0].vals) > 0:
		// NOTE(review): only the first value is counted even when the
		// constraint lists several, so this may under-estimate.
		c := l.constraints[0]
		l.size = qs.sizeForIterator(false, c.dir, c.vals[0])
		return l.size, true
	case len(l.nodeIts) > 1:
		// Size is part of sqlIterator, so no concrete-type assertion (which
		// would panic for a non-node sub-iterator) is needed.
		// NOTE(review): a single node sub-iterator falls through to the
		// store-wide size below; confirm whether "> 1" was meant to be "> 0".
		subsize, _ := l.nodeIts[0].it.Size(qs)
		return subsize * 20, false
	default:
		return qs.Size(), false
	}
}

// Describe returns this iterator's enumeration query for debugging, prefixed
// with "SQL_LINK_QUERY: ". Building the query also logs it at glog V(4).
func (l *SQLLinkIterator) Describe() string {
	stmt, _ := l.buildSQL(true, nil)
	return fmt.Sprintf("SQL_LINK_QUERY: %s", stmt)
}

// Type reports this iterator's query kind, the link constant.
func (l *SQLLinkIterator) Type() sqlQueryType {
	kind := link
	return kind
}

// quickContains tries to decide, without querying the database, whether the
// quad v belongs to this iterator. The first result reports whether the
// second is a definitive answer.
//
// A quad whose value at any constrained position is not among that
// constraint's values is definitely absent. A quad passing every constraint
// is definitely present only when there are no node sub-iterators to check.
// Otherwise the answer is left to the database. v must be a quad.Quad
// whenever there are constraints.
func (l *SQLLinkIterator) quickContains(v graph.Value) (bool, bool) {
nextConstraint:
	for _, c := range l.constraints {
		want := v.(quad.Quad).Get(c.dir)
		for _, candidate := range c.vals {
			if candidate == want {
				continue nextConstraint
			}
		}
		return true, false
	}
	if len(l.nodeIts) != 0 {
		return false, false
	}
	return true, true
}

// buildResult decodes one result row.
//
// The first four entries of result are taken as the quad's subject,
// predicate, object and label, the column order buildSQL selects. That quad
// is stored for Result. Every remaining entry is returned keyed by its column
// name in cols. Both slices must have at least four entries.
func (l *SQLLinkIterator) buildResult(result []string, cols []string) map[string]string {
	l.resultQuad = quad.Quad{
		Subject:   result[0],
		Predicate: result[1],
		Object:    result[2],
		Label:     result[3],
	}
	tags := make(map[string]string)
	extra := cols[4:]
	for j, name := range extra {
		tags[name] = result[4+j]
	}
	return tags
}

// getTables lists the FROM entries for this iterator. The quads table comes
// first, under this iterator's alias, followed by every table required by
// each node sub-iterator, in order.
func (l *SQLLinkIterator) getTables() []tableDef {
	tables := []tableDef{{table: "quads", name: l.tableName}}
	for _, sub := range l.nodeIts {
		tables = append(tables, sub.it.getTables()...)
	}
	return tables
}

// getTags lists the tagged columns to select, in this order:
//   - this iterator's own tags, as quad.Any entries on its alias;
//   - its direction tags, each read from the named quad column on its alias;
//   - the tags of each node sub-iterator.
//
// For direction tags only tag and dir are carried over. table is replaced by
// this iterator's alias and justLocal is left false.
func (l *SQLLinkIterator) getTags() []tagDir {
	var tags []tagDir
	addLocal := func(name string, dir quad.Direction) {
		tags = append(tags, tagDir{tag: name, dir: dir, table: l.tableName})
	}
	for _, name := range l.tagger.Tags() {
		addLocal(name, quad.Any)
	}
	for _, td := range l.tagdirs {
		addLocal(td.tag, td.dir)
	}
	for _, sub := range l.nodeIts {
		tags = append(tags, sub.it.getTags()...)
	}
	return tags
}

// buildWhere returns the AND-joined WHERE conditions for this iterator and
// its node sub-iterators, plus their arguments in placeholder order. An
// empty string means there are no conditions.
//
// Conditions are emitted in three groups:
//   - value constraints on this iterator's table, matched by hash;
//   - join conditions linking each node sub-iterator to the quad position it
//     hangs off;
//   - each sub-iterator's own conditions.
//
// Join conditions carry no placeholders, so the argument order is the
// constraint arguments followed by each sub-iterator's arguments.
func (l *SQLLinkIterator) buildWhere() (string, []string) {
	var q []string
	var vals []string
	for _, c := range l.constraints {
		switch len(c.vals) {
		case 0:
			// No values to match; contributes no condition.
		case 1:
			q = append(q, fmt.Sprintf("%s.%s_hash = ?", l.tableName, c.dir))
			vals = append(vals, hashOf(c.vals[0]))
		default:
			// One placeholder per value: "?, ?, ..., ?".
			valslots := strings.TrimSuffix(strings.Repeat("?, ", len(c.vals)), ", ")
			q = append(q, fmt.Sprintf("%s.%s_hash IN (%s)", l.tableName, c.dir, valslots))
			for _, v := range c.vals {
				vals = append(vals, hashOf(v))
			}
		}
	}
	for _, i := range l.nodeIts {
		t := i.it.tableID()
		// A quad.Any identity refers to a column named after its tag rather
		// than a fixed quad position.
		col := t.dir.String()
		if t.dir == quad.Any {
			col = t.tag
		}
		q = append(q, fmt.Sprintf("%s.%s_hash = %s.%s_hash", l.tableName, i.dir, t.table, col))
	}
	for _, i := range l.nodeIts {
		s, v := i.it.buildWhere()
		// A sub-iterator without conditions returns "". Joining it would
		// leave a dangling " AND " and produce invalid SQL.
		if s != "" {
			q = append(q, s)
		}
		vals = append(vals, v...)
	}
	return strings.Join(q, " AND "), vals
}

// tableID identifies how a parent iterator joins against this one. It
// returns a quad.Any tagDir on this iterator's table alias with no tag set.
func (l *SQLLinkIterator) tableID() tagDir {
	var id tagDir
	id.dir = quad.Any
	id.table = l.tableName
	return id
}

// buildSQL renders the complete SELECT statement for this iterator and its
// positional arguments, listed in placeholder order.
//
// The first four selected columns are always this iterator's subject,
// predicate, object and label; buildResult and sameTopResult rely on that
// order. They are followed by every tagged column from getTags.
//
// When next is true the statement enumerates all matching rows. When next is
// false it is further restricted to the single quad val, which must then be
// a quad.Quad.
//
// At glog verbosity 4 the statement is also logged with its arguments
// substituted inline. This is for debugging only.
func (l *SQLLinkIterator) buildSQL(next bool, val graph.Value) (string, []string) {
	cols := []string{
		fmt.Sprintf("%s.subject", l.tableName),
		fmt.Sprintf("%s.predicate", l.tableName),
		fmt.Sprintf("%s.object", l.tableName),
		fmt.Sprintf("%s.label", l.tableName),
	}
	for _, v := range l.getTags() {
		cols = append(cols, v.String())
	}

	// Arguments contributed by FROM entries come first, matching where their
	// placeholders appear in the statement.
	var tables, values []string
	for _, k := range l.getTables() {
		values = append(values, k.values...)
		tables = append(tables, fmt.Sprintf("%s as %s", k.table, k.name))
	}

	var clauses []string
	where, wherevalues := l.buildWhere()
	if where != "" {
		clauses = append(clauses, where)
	}
	values = append(values, wherevalues...)
	if !next {
		v := val.(quad.Quad)
		clauses = append(clauses,
			fmt.Sprintf("%s.subject_hash = ?", l.tableName),
			fmt.Sprintf("%s.predicate_hash = ?", l.tableName),
			fmt.Sprintf("%s.object_hash = ?", l.tableName),
			fmt.Sprintf("%s.label_hash = ?", l.tableName),
		)
		values = append(values, hashOf(v.Subject), hashOf(v.Predicate), hashOf(v.Object), hashOf(v.Label))
	}

	query := "SELECT " + strings.Join(cols, ", ") + " FROM " + strings.Join(tables, ", ")
	if len(clauses) > 0 {
		query += " WHERE " + strings.Join(clauses, " AND ")
	}
	query += ";"

	if glog.V(4) {
		glog.V(4).Infoln(renderSQLForLog(query, values))
	}
	return query, values
}

// renderSQLForLog replaces each "?" in query, left to right, with the
// matching value wrapped in single quotes. The output is for logs only, and
// values are not escaped.
//
// Only the original query text is scanned for placeholders, so a "?" inside
// an already-substituted value is never mistaken for one. Placeholders
// without a value, and values without a placeholder, are left as they are.
func renderSQLForLog(query string, values []string) string {
	parts := make([]string, 0, 2*len(values)+1)
	rest := query
	for _, v := range values {
		i := strings.Index(rest, "?")
		if i < 0 {
			break
		}
		parts = append(parts, rest[:i], "'"+v+"'")
		rest = rest[i+1:]
	}
	parts = append(parts, rest)
	return strings.Join(parts, "")
}

// sameTopResult reports whether two result rows carry the same quad, that
// is, whether they agree on their first four columns (subject, predicate,
// object, label). Both rows must have at least four entries.
func (l *SQLLinkIterator) sameTopResult(target []string, test []string) bool {
	for col := 0; col < 4; col++ {
		if target[col] != test[col] {
			return false
		}
	}
	return true
}
back to top