// Source: https://github.com/google/cayley
// Tip revision: 15ab599d6aa8215cbbcd4699fe9012230eb3069c authored by Barak Michener on 23 March 2016, 20:15:00 UTC
// Merge pull request #386 from barakmich/optional_path
// File: sql_link_iterator.go
// Copyright 2015 The Cayley Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"fmt"
"strings"
"sync/atomic"
"github.com/barakmich/glog"
"github.com/google/cayley/graph"
"github.com/google/cayley/quad"
)
// sqlTableID is the package-wide counter behind newTableName. Its zero value
// is the intended starting state, so no init function is needed to set it;
// an init that reset it could also discard IDs already handed out by another
// init function in this package.
var sqlTableID uint64

// newTableName returns a fresh table alias of the form "t_<n>". The counter
// is advanced with an atomic add, so concurrent callers receive distinct
// aliases.
func newTableName() string {
	id := atomic.AddUint64(&sqlTableID, 1)
	return fmt.Sprintf("t_%d", id)
}
// constraint restricts the value found in direction dir of a matching quad to
// one of vals. SQLLinkIterator.buildWhere turns it into a condition on the
// "<dir>_hash" column (equality for one value, IN for several, nothing when
// vals is empty), and SQLLinkIterator.quickContains checks it in memory
// against a candidate quad.
type constraint struct {
	dir  quad.Direction
	vals []string
}
// tagDir describes one tagged column pair (value and hash) to add to a SELECT
// list; String renders it as SQL.
type tagDir struct {
	// tag is the output column alias (also the prefix of the "<tag>_hash" alias).
	tag string
	// dir picks the source column: for a concrete direction the columns
	// "<dir>" and "<dir>_hash" are read; for quad.Any the source columns are
	// named after tag itself (or the __execd columns, see justLocal).
	dir quad.Direction
	// table is the table alias the source columns are qualified with.
	table string
	// justLocal only matters when dir is quad.Any: String then reads
	// __execd / __execd_hash instead of the tag-named columns.
	// NOTE(review): what populates __execd is not visible here — confirm.
	justLocal bool
}
// String renders the tag as two SELECT-list expressions: the value column
// aliased to the quoted tag name, and its hash column aliased to
// "<tag>_hash".
//
// NOTE(review): the tag name is interpolated into SQL identifiers without any
// escaping; a tag containing a double quote would break the statement.
// Confirm tags are sanitized before they reach here.
func (t tagDir) String() string {
	switch {
	case t.dir != quad.Any:
		// Concrete direction: read the direction's own columns.
		return fmt.Sprintf("%s.%s as \"%s\", %s.%s_hash as %s_hash", t.table, t.dir, t.tag, t.table, t.dir, t.tag)
	case t.justLocal:
		return fmt.Sprintf("%s.__execd as \"%s\", %s.__execd_hash as %s_hash", t.table, t.tag, t.table, t.tag)
	default:
		// quad.Any: the source columns are named after the tag.
		return fmt.Sprintf("%s.\"%s\" as \"%s\", %s.%s_hash as %s_hash", t.table, t.tag, t.tag, t.table, t.tag, t.tag)
	}
}
// tableDef is one entry of the FROM list assembled by buildSQL, where it is
// rendered as "<table> as <name>".
type tableDef struct {
	// table is the table expression placed in the FROM list ("quads" for a
	// link iterator's own table).
	table string
	// name is the alias the table expression is given.
	name string
	// values are appended to the query arguments ahead of the WHERE
	// arguments; presumably they bind "?" placeholders inside table — confirm
	// against the node iterator that fills them.
	values []string
}
// sqlItDir pairs a sub-iterator with the quad direction it is joined on:
// SQLLinkIterator.buildWhere equates the link table's "<dir>_hash" column
// with the hash column identified by the sub-iterator's tableID.
type sqlItDir struct {
	dir quad.Direction
	it  sqlIterator
}
// sqlIterator is implemented by the SQL-backed iterators that can be composed
// into a single SQL statement.
// NOTE(review): only SQLLinkIterator's implementation is visible in this
// file; the per-method notes describe how this file uses each method.
type sqlIterator interface {
	// sqlClone returns a copy of the iterator, cloning its sub-iterators too.
	sqlClone() sqlIterator
	// buildSQL renders the full statement and its placeholder arguments;
	// next selects enumeration (true) or a membership check against val (false).
	buildSQL(next bool, val graph.Value) (string, []string)
	// getTables lists the FROM entries needed by the iterator and its children.
	getTables() []tableDef
	// getTags lists the tagged columns to select for the iterator and its children.
	getTags() []tagDir
	// buildWhere returns the WHERE body (without the keyword) and its arguments.
	buildWhere() (string, []string)
	// tableID identifies the table/column that a parent iterator joins against.
	tableID() tagDir
	// quickContains tries to answer a membership test without SQL; ok reports
	// whether result is meaningful.
	quickContains(graph.Value) (ok bool, result bool)
	// buildResult records a result row as the current result and maps the
	// remaining columns to their values.
	buildResult(result []string, cols []string) map[string]string
	// sameTopResult reports whether two result rows share the same primary result.
	sameTopResult(target []string, test []string) bool
	Result() graph.Value
	Size(*QuadStore) (int64, bool)
	Describe() string
	Type() sqlQueryType
	Tagger() *graph.Tagger
}
// SQLLinkIterator yields quads from the quads table, optionally filtered by
// per-direction value constraints and joined against node sub-iterators.
type SQLLinkIterator struct {
	tagger      graph.Tagger
	nodeIts     []sqlItDir   // sub-iterators and the direction each is joined on
	constraints []constraint // value constraints on this iterator's own table
	tableName   string       // alias of this iterator's quads table in the query
	size        int64        // cached size estimate; 0 means not computed yet
	tagdirs     []tagDir     // extra per-direction tags selected from this table
	resultQuad  quad.Quad    // quad recorded by the most recent buildResult
}
// sqlClone returns a copy of l that shares no mutable state with it: the
// constraint and tag slices are copied, each node sub-iterator is cloned
// recursively, and the tagger's tags are copied. The cached size carries
// over; the last result does not.
func (l *SQLLinkIterator) sqlClone() sqlIterator {
	clone := &SQLLinkIterator{
		tableName:   l.tableName,
		size:        l.size,
		constraints: make([]constraint, len(l.constraints)),
		tagdirs:     make([]tagDir, len(l.tagdirs)),
	}
	copy(clone.constraints, l.constraints)
	copy(clone.tagdirs, l.tagdirs)
	clone.tagger.CopyFromTagger(l.Tagger())
	for idx := range l.nodeIts {
		src := l.nodeIts[idx]
		clone.nodeIts = append(clone.nodeIts, sqlItDir{dir: src.dir, it: src.it.sqlClone()})
	}
	return clone
}
// Tagger returns a pointer to the iterator's own tagger field, so tags added
// through it are the ones getTags later reports.
func (l *SQLLinkIterator) Tagger() *graph.Tagger {
	tagger := &l.tagger
	return tagger
}

// Result returns the quad recorded by the most recent call to buildResult.
func (l *SQLLinkIterator) Result() graph.Value {
	var current graph.Value = l.resultQuad
	return current
}
// Size estimates how many quads the iterator yields.
//
// A cached non-zero estimate is returned with true. Otherwise, when the first
// constraint has at least one value, the store is asked how many quads carry
// that value in that direction; the figure is cached and returned with true.
// Without usable constraints the result is a rough, uncached estimate
// returned with false.
func (l *SQLLinkIterator) Size(qs *QuadStore) (int64, bool) {
	if l.size != 0 {
		return l.size, true
	}
	switch {
	case len(l.constraints) > 0 && len(l.constraints[0].vals) > 0:
		// A constraint with no values adds no SQL filter, so it cannot
		// drive the estimate (and indexing vals[0] would panic).
		c := l.constraints[0]
		l.size = qs.sizeForIterator(false, c.dir, c.vals[0])
		return l.size, true
	case len(l.nodeIts) > 1:
		// Call Size through the sqlIterator interface instead of asserting
		// *SQLNodeIterator, which panicked for any other implementation.
		// The multiplier 20 is a fixed heuristic carried over unchanged.
		// NOTE(review): exactly one node iterator falls through to the
		// whole-store size below; confirm whether "> 0" was intended.
		subsize, _ := l.nodeIts[0].it.Size(qs)
		return subsize * 20, false
	default:
		return qs.Size(), false
	}
}
// Describe returns a human-readable description of the iterator: the SQL it
// would run to enumerate its results, prefixed with "SQL_LINK_QUERY: ".
func (l *SQLLinkIterator) Describe() string {
	query, _ := l.buildSQL(true, nil)
	return "SQL_LINK_QUERY: " + query
}

// Type reports the link query type for this iterator.
func (l *SQLLinkIterator) Type() sqlQueryType { return link }
// quickContains tries to decide whether v (expected to be a quad.Quad when
// any constraint exists) belongs to the iterator without running SQL.
//
// It returns (ok, result):
//   - (true, false) if v fails any value constraint;
//   - (true, true) if every constraint passes and there are no node
//     sub-iterators left to check;
//   - (false, false) otherwise, meaning the caller must ask the database.
func (l *SQLLinkIterator) quickContains(v graph.Value) (bool, bool) {
nextConstraint:
	for _, c := range l.constraints {
		got := v.(quad.Quad).Get(c.dir)
		for _, allowed := range c.vals {
			if allowed == got {
				continue nextConstraint
			}
		}
		// No allowed value matched: v is definitely not contained.
		return true, false
	}
	decided := len(l.nodeIts) == 0
	return decided, decided
}
// buildResult records the quad held in the first four columns of a result row
// (subject, predicate, object, label — the order buildSQL selects them in) as
// the current result. It returns the remaining columns keyed by column name.
func (l *SQLLinkIterator) buildResult(result []string, cols []string) map[string]string {
	l.resultQuad = quad.Quad{
		Subject:   result[0],
		Predicate: result[1],
		Object:    result[2],
		Label:     result[3],
	}
	tagCols := cols[4:]
	tags := make(map[string]string)
	for idx := range tagCols {
		tags[tagCols[idx]] = result[4+idx]
	}
	return tags
}
// getTables lists the FROM entries for this iterator: its own quads table
// under l.tableName first, then the tables of every node sub-iterator in
// order.
func (l *SQLLinkIterator) getTables() []tableDef {
	tables := []tableDef{{table: "quads", name: l.tableName}}
	for idx := range l.nodeIts {
		tables = append(tables, l.nodeIts[idx].it.getTables()...)
	}
	return tables
}
// getTags lists every tagged column to select, in this order:
//  1. tags set on l's tagger (direction quad.Any);
//  2. the per-direction tags in l.tagdirs;
//  3. each node sub-iterator's own tags.
//
// The tags from groups 1 and 2 are read from l.tableName.
func (l *SQLLinkIterator) getTags() []tagDir {
	var tags []tagDir
	for _, name := range l.tagger.Tags() {
		tags = append(tags, tagDir{tag: name, dir: quad.Any, table: l.tableName})
	}
	for idx := range l.tagdirs {
		td := l.tagdirs[idx]
		// Only the name and direction are kept: table is re-pointed at this
		// iterator's alias and justLocal is left false.
		tags = append(tags, tagDir{tag: td.tag, dir: td.dir, table: l.tableName})
	}
	for _, sub := range l.nodeIts {
		tags = append(tags, sub.it.getTags()...)
	}
	return tags
}
// buildWhere returns the WHERE-clause body (without the "WHERE" keyword) for
// this iterator and its node sub-iterators. It also returns the arguments for
// the "?" placeholders, in the order they appear.
//
// Conditions are joined with " AND " in three groups:
//  1. each constraint on this iterator's table, compared by hash;
//  2. one join condition per node sub-iterator, equating this table's
//     "<dir>_hash" column with the sub-iterator's tableID hash column;
//  3. each sub-iterator's own WHERE body, skipping empty ones so the result
//     never contains a dangling or doubled " AND ".
//
// An empty string means there is nothing to constrain.
func (l *SQLLinkIterator) buildWhere() (string, []string) {
	var q []string
	var vals []string
	for _, c := range l.constraints {
		switch n := len(c.vals); {
		case n == 1:
			q = append(q, fmt.Sprintf("%s.%s_hash = ?", l.tableName, c.dir))
			vals = append(vals, hashOf(c.vals[0]))
		case n > 1:
			// One placeholder per value: "?, ?, ..., ?".
			valslots := strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
			q = append(q, fmt.Sprintf("%s.%s_hash IN (%s)", l.tableName, c.dir, valslots))
			for _, v := range c.vals {
				vals = append(vals, hashOf(v))
			}
		}
		// A constraint with no values contributes no condition.
	}
	for _, i := range l.nodeIts {
		t := i.it.tableID()
		col := t.dir.String()
		if t.dir == quad.Any {
			// For quad.Any the hash column is named after the tag.
			col = t.tag
		}
		q = append(q, fmt.Sprintf("%s.%s_hash = %s.%s_hash", l.tableName, i.dir, t.table, col))
	}
	for _, i := range l.nodeIts {
		s, v := i.it.buildWhere()
		if s != "" {
			q = append(q, s)
		}
		vals = append(vals, v...)
	}
	return strings.Join(q, " AND "), vals
}
// tableID identifies this iterator for a parent that joins against it: the
// table alias l.tableName with direction quad.Any and no tag name.
func (l *SQLLinkIterator) tableID() tagDir {
	var id tagDir
	id.table = l.tableName
	id.dir = quad.Any
	return id
}
// buildSQL renders the complete SELECT statement for this iterator. It returns
// the statement with the arguments for its "?" placeholders, in placeholder
// order.
//
// With next == true the statement enumerates every matching quad. With
// next == false, val must be a quad.Quad (the type assertion panics
// otherwise), and the statement is further restricted to rows whose subject,
// predicate, object and label hashes equal that quad's.
//
// The first four selected columns are always subject, predicate, object and
// label of this iterator's table; buildResult and sameTopResult rely on that
// order.
func (l *SQLLinkIterator) buildSQL(next bool, val graph.Value) (string, []string) {
	cols := []string{
		fmt.Sprintf("%s.subject", l.tableName),
		fmt.Sprintf("%s.predicate", l.tableName),
		fmt.Sprintf("%s.object", l.tableName),
		fmt.Sprintf("%s.label", l.tableName),
	}
	for _, tag := range l.getTags() {
		cols = append(cols, tag.String())
	}

	// Table definitions may contribute arguments of their own; they are
	// collected first because the FROM list precedes the WHERE clause.
	var from []string
	var values []string
	for _, k := range l.getTables() {
		values = append(values, k.values...)
		from = append(from, fmt.Sprintf("%s as %s", k.table, k.name))
	}

	var where []string
	constraint, wherevalues := l.buildWhere()
	if constraint != "" {
		where = append(where, constraint)
	}
	values = append(values, wherevalues...)
	if !next {
		v := val.(quad.Quad)
		where = append(where,
			fmt.Sprintf("%s.subject_hash = ?", l.tableName),
			fmt.Sprintf("%s.predicate_hash = ?", l.tableName),
			fmt.Sprintf("%s.object_hash = ?", l.tableName),
			fmt.Sprintf("%s.label_hash = ?", l.tableName),
		)
		values = append(values, hashOf(v.Subject), hashOf(v.Predicate), hashOf(v.Object), hashOf(v.Label))
	}

	query := "SELECT " + strings.Join(cols, ", ") + " FROM " + strings.Join(from, ", ")
	if len(where) > 0 {
		query += " WHERE " + strings.Join(where, " AND ")
	}
	query += ";"

	if glog.V(4) {
		// Inline the arguments for the debug log. Splitting the statement
		// once on "?" keeps a "?" that appears inside a substituted value
		// from being mistaken for the next placeholder. Placeholders without
		// a matching argument are left as "?".
		parts := strings.Split(query, "?")
		for i := 0; i < len(parts)-1; i++ {
			if i < len(values) {
				parts[i] += "'" + values[i] + "'"
			} else {
				parts[i] += "?"
			}
		}
		glog.V(4).Infoln(strings.Join(parts, ""))
	}
	return query, values
}
// sameTopResult reports whether two result rows describe the same quad. It
// compares the rows' first four columns (subject, predicate, object, label),
// in that order.
func (l *SQLLinkIterator) sameTopResult(target []string, test []string) bool {
	for col := 0; col < 4; col++ {
		if target[col] != test[col] {
			return false
		}
	}
	return true
}