Update xorm to latest version and fix correct `user` table referencing in sql (#4473) (#4483)

This commit is contained in:
Lauris BH 2018-07-20 21:48:15 +03:00 committed by GitHub
parent 88d791013b
commit 8a639ade58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 1718 additions and 1102 deletions

10
Gopkg.lock generated
View File

@ -299,12 +299,14 @@
[[projects]]
name = "github.com/go-xorm/builder"
packages = ["."]
revision = "488224409dd8aa2ce7a5baf8d10d55764a913738"
revision = "dc8bf48f58fab2b4da338ffd25191905fd741b8f"
version = "v0.3.0"
[[projects]]
name = "github.com/go-xorm/core"
packages = ["."]
revision = "cb1d0ca71f42d3ee1bf4aba7daa16099bc31a7e9"
revision = "c10e21e7e1cec20e09398f2dfae385e58c8df555"
version = "v0.6.0"
[[projects]]
name = "github.com/go-xorm/tidb"
@ -314,7 +316,7 @@
[[projects]]
name = "github.com/go-xorm/xorm"
packages = ["."]
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03"
revision = "ad69f7d8f0861a29438154bb0a20b60501298480"
[[projects]]
branch = "master"
@ -873,6 +875,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "036b8c882671cf8d2c5e2fdbe53b1bdfbd39f7ebd7765bd50276c7c4ecf16687"
inputs-digest = "3b587a036aaf09514228ead18e7fd93e9ee1d14d4e56715bb2f197d5f27259d1"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -38,7 +38,7 @@ ignored = ["google.golang.org/appengine*"]
[[override]]
name = "github.com/go-xorm/xorm"
#version = "0.6.5"
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03"
revision = "ad69f7d8f0861a29438154bb0a20b60501298480"
[[override]]
name = "github.com/gorilla/mux"

View File

@ -1283,7 +1283,7 @@ func getParticipantsByIssueID(e Engine, issueID int64) ([]*User, error) {
And("`comment`.type = ?", CommentTypeComment).
And("`user`.is_active = ?", true).
And("`user`.prohibit_login = ?", false).
Join("INNER", "user", "`user`.id = `comment`.poster_id").
Join("INNER", "`user`", "`user`.id = `comment`.poster_id").
Distinct("poster_id").
Find(&userIDs); err != nil {
return nil, fmt.Errorf("get poster IDs: %v", err)

View File

@ -166,7 +166,7 @@ func (issues IssueList) loadAssignees(e Engine) error {
var assignees = make(map[int64][]*User, len(issues))
rows, err := e.Table("issue_assignees").
Join("INNER", "user", "`user`.id = `issue_assignees`.assignee_id").
Join("INNER", "`user`", "`user`.id = `issue_assignees`.assignee_id").
In("`issue_assignees`.issue_id", issues.getIssueIDs()).
Rows(new(AssigneeIssue))
if err != nil {

View File

@ -67,7 +67,7 @@ func getIssueWatchers(e Engine, issueID int64) (watches []*IssueWatch, err error
Where("`issue_watch`.issue_id = ?", issueID).
And("`user`.is_active = ?", true).
And("`user`.prohibit_login = ?", false).
Join("INNER", "user", "`user`.id = `issue_watch`.user_id").
Join("INNER", "`user`", "`user`.id = `issue_watch`.user_id").
Find(&watches)
return
}

View File

@ -383,7 +383,7 @@ func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*User, error) {
func GetOrgUsersByUserID(uid int64, all bool) ([]*OrgUser, error) {
ous := make([]*OrgUser, 0, 10)
sess := x.
Join("LEFT", "user", "`org_user`.org_id=`user`.id").
Join("LEFT", "`user`", "`org_user`.org_id=`user`.id").
Where("`org_user`.uid=?", uid)
if !all {
// Only show public organizations
@ -575,7 +575,7 @@ func (org *User) getUserTeams(e Engine, userID int64, cols ...string) ([]*Team,
return teams, e.
Where("`team_user`.org_id = ?", org.ID).
Join("INNER", "team_user", "`team_user`.team_id = team.id").
Join("INNER", "user", "`user`.id=team_user.uid").
Join("INNER", "`user`", "`user`.id=team_user.uid").
And("`team_user`.uid = ?", userID).
Asc("`user`.name").
Cols(cols...).

View File

@ -1954,7 +1954,7 @@ func DeleteRepository(doer *User, uid, repoID int64) error {
func GetRepositoryByOwnerAndName(ownerName, repoName string) (*Repository, error) {
var repo Repository
has, err := x.Select("repository.*").
Join("INNER", "user", "`user`.id = repository.owner_id").
Join("INNER", "`user`", "`user`.id = repository.owner_id").
Where("repository.lower_name = ?", strings.ToLower(repoName)).
And("`user`.lower_name = ?", strings.ToLower(ownerName)).
Get(&repo)

View File

@ -54,7 +54,7 @@ func getWatchers(e Engine, repoID int64) ([]*Watch, error) {
return watches, e.Where("`watch`.repo_id=?", repoID).
And("`user`.is_active=?", true).
And("`user`.prohibit_login=?", false).
Join("INNER", "user", "`user`.id = `watch`.user_id").
Join("INNER", "`user`", "`user`.id = `watch`.user_id").
Find(&watches)
}

View File

@ -374,9 +374,9 @@ func (u *User) GetFollowers(page int) ([]*User, error) {
Limit(ItemsPerPage, (page-1)*ItemsPerPage).
Where("follow.follow_id=?", u.ID)
if setting.UsePostgreSQL {
sess = sess.Join("LEFT", "follow", `"user".id=follow.user_id`)
sess = sess.Join("LEFT", "follow", "`user`.id=follow.user_id")
} else {
sess = sess.Join("LEFT", "follow", "user.id=follow.user_id")
sess = sess.Join("LEFT", "follow", "`user`.id=follow.user_id")
}
return users, sess.Find(&users)
}
@ -393,9 +393,9 @@ func (u *User) GetFollowing(page int) ([]*User, error) {
Limit(ItemsPerPage, (page-1)*ItemsPerPage).
Where("follow.user_id=?", u.ID)
if setting.UsePostgreSQL {
sess = sess.Join("LEFT", "follow", `"user".id=follow.follow_id`)
sess = sess.Join("LEFT", "follow", "`user`.id=follow.follow_id")
} else {
sess = sess.Join("LEFT", "follow", "user.id=follow.follow_id")
sess = sess.Join("LEFT", "follow", "`user`.id=follow.follow_id")
}
return users, sess.Find(&users)
}

View File

@ -4,6 +4,10 @@
package builder
import (
"fmt"
)
type optype byte
const (
@ -29,6 +33,9 @@ type Builder struct {
joins []join
inserts Eq
updates []Eq
orderBy string
groupBy string
having string
}
// Select creates a select Builder
@ -67,6 +74,11 @@ func (b *Builder) From(tableName string) *Builder {
return b
}
// TableName returns the table name
func (b *Builder) TableName() string {
return b.tableName
}
// Into sets insert table name
func (b *Builder) Into(tableName string) *Builder {
b.tableName = tableName
@ -178,6 +190,33 @@ func (b *Builder) ToSQL() (string, []interface{}, error) {
return w.writer.String(), w.args, nil
}
// ConvertPlaceholder replaces ? to $1, $2 ... or :1, :2 ... according prefix
func ConvertPlaceholder(sql, prefix string) (string, error) {
buf := StringBuilder{}
var j, start = 0, 0
for i := 0; i < len(sql); i++ {
if sql[i] == '?' {
_, err := buf.WriteString(sql[start:i])
if err != nil {
return "", err
}
start = i + 1
_, err = buf.WriteString(prefix)
if err != nil {
return "", err
}
j = j + 1
_, err = buf.WriteString(fmt.Sprintf("%d", j))
if err != nil {
return "", err
}
}
}
return buf.String(), nil
}
// ToSQL convert a builder or condtions to SQL and args
func ToSQL(cond interface{}) (string, []interface{}, error) {
switch cond.(type) {

View File

@ -15,7 +15,7 @@ func (b *Builder) insertWriteTo(w Writer) error {
return errors.New("no table indicated")
}
if len(b.inserts) <= 0 {
return errors.New("no column to be update")
return errors.New("no column to be insert")
}
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil {
@ -26,7 +26,9 @@ func (b *Builder) insertWriteTo(w Writer) error {
var bs []byte
var valBuffer = bytes.NewBuffer(bs)
var i = 0
for col, value := range b.inserts {
for _, col := range b.inserts.sortedKeys() {
value := b.inserts[col]
fmt.Fprint(w, col)
if e, ok := value.(expr); ok {
fmt.Fprint(valBuffer, e.sql)

View File

@ -34,24 +34,65 @@ func (b *Builder) selectWriteTo(w Writer) error {
}
}
if _, err := fmt.Fprintf(w, " FROM %s", b.tableName); err != nil {
if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil {
return err
}
for _, v := range b.joins {
fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable)
if _, err := fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable); err != nil {
return err
}
if err := v.joinCond.WriteTo(w); err != nil {
return err
}
}
if !b.cond.IsValid() {
return nil
if b.cond.IsValid() {
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := b.cond.WriteTo(w); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
if len(b.groupBy) > 0 {
if _, err := fmt.Fprint(w, " GROUP BY ", b.groupBy); err != nil {
return err
}
}
return b.cond.WriteTo(w)
if len(b.having) > 0 {
if _, err := fmt.Fprint(w, " HAVING ", b.having); err != nil {
return err
}
}
if len(b.orderBy) > 0 {
if _, err := fmt.Fprint(w, " ORDER BY ", b.orderBy); err != nil {
return err
}
}
return nil
}
// OrderBy orderBy SQL
func (b *Builder) OrderBy(orderBy string) *Builder {
b.orderBy = orderBy
return b
}
// GroupBy groupby SQL
func (b *Builder) GroupBy(groupby string) *Builder {
b.groupBy = groupby
return b
}
// Having having SQL
func (b *Builder) Having(having string) *Builder {
b.having = having
return b
}

View File

@ -5,7 +5,6 @@
package builder
import (
"bytes"
"io"
)
@ -19,15 +18,15 @@ var _ Writer = NewWriter()
// BytesWriter implments Writer and save SQL in bytes.Buffer
type BytesWriter struct {
writer *bytes.Buffer
buffer []byte
writer *StringBuilder
args []interface{}
}
// NewWriter creates a new string writer
func NewWriter() *BytesWriter {
w := &BytesWriter{}
w.writer = bytes.NewBuffer(w.buffer)
w := &BytesWriter{
writer: &StringBuilder{},
}
return w
}

View File

@ -10,7 +10,13 @@ import "fmt"
func WriteMap(w Writer, data map[string]interface{}, op string) error {
var args = make([]interface{}, 0, len(data))
var i = 0
for k, v := range data {
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
for _, k := range keys {
v := data[k]
switch v.(type) {
case expr:
if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil {

View File

@ -4,7 +4,10 @@
package builder
import "fmt"
import (
"fmt"
"sort"
)
// Incr implements a type used by Eq
type Incr int
@ -19,7 +22,8 @@ var _ Cond = Eq{}
func (eq Eq) opWriteTo(op string, w Writer) error {
var i = 0
for k, v := range eq {
for _, k := range eq.sortedKeys() {
v := eq[k]
switch v.(type) {
case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}:
if err := In(k, v).WriteTo(w); err != nil {
@ -94,3 +98,15 @@ func (eq Eq) Or(conds ...Cond) Cond {
func (eq Eq) IsValid() bool {
return len(eq) > 0
}
// sortedKeys returns all keys of this Eq sorted with sort.Strings.
// It is used internally for consistent ordering when generating
// SQL, see https://github.com/go-xorm/builder/issues/10
func (eq Eq) sortedKeys() []string {
keys := make([]string, 0, len(eq))
for key := range eq {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}

View File

@ -16,7 +16,7 @@ func (like Like) WriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil {
return err
}
// FIXME: if use other regular express, this will be failed. but for compitable, keep this
// FIXME: if use other regular express, this will be failed. but for compatible, keep this
if like[1][0] == '%' || like[1][len(like[1])-1] == '%' {
w.Append(like[1])
} else {

View File

@ -4,7 +4,10 @@
package builder
import "fmt"
import (
"fmt"
"sort"
)
// Neq defines not equal conditions
type Neq map[string]interface{}
@ -15,7 +18,8 @@ var _ Cond = Neq{}
func (neq Neq) WriteTo(w Writer) error {
var args = make([]interface{}, 0, len(neq))
var i = 0
for k, v := range neq {
for _, k := range neq.sortedKeys() {
v := neq[k]
switch v.(type) {
case []int, []int64, []string, []int32, []int16, []int8:
if err := NotIn(k, v).WriteTo(w); err != nil {
@ -76,3 +80,15 @@ func (neq Neq) Or(conds ...Cond) Cond {
func (neq Neq) IsValid() bool {
return len(neq) > 0
}
// sortedKeys returns all keys of this Neq sorted with sort.Strings.
// It is used internally for consistent ordering when generating
// SQL, see https://github.com/go-xorm/builder/issues/10
func (neq Neq) sortedKeys() []string {
keys := make([]string, 0, len(neq))
for key := range neq {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}

View File

@ -21,6 +21,18 @@ func (not Not) WriteTo(w Writer) error {
if _, err := fmt.Fprint(w, "("); err != nil {
return err
}
case Eq:
if len(not[0].(Eq)) > 1 {
if _, err := fmt.Fprint(w, "("); err != nil {
return err
}
}
case Neq:
if len(not[0].(Neq)) > 1 {
if _, err := fmt.Fprint(w, "("); err != nil {
return err
}
}
}
if err := not[0].WriteTo(w); err != nil {
@ -32,6 +44,18 @@ func (not Not) WriteTo(w Writer) error {
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
case Eq:
if len(not[0].(Eq)) > 1 {
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
}
case Neq:
if len(not[0].(Neq)) > 1 {
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
}
}
return nil

119
vendor/github.com/go-xorm/builder/strings_builder.go generated vendored Normal file
View File

@ -0,0 +1,119 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"unicode/utf8"
"unsafe"
)
// A StringBuilder is used to efficiently build a string using Write methods.
// It minimizes memory copying. The zero value is ready to use.
// Do not copy a non-zero Builder.
type StringBuilder struct {
addr *StringBuilder // of receiver, to detect copies by value
buf []byte
}
// noescape hides a pointer from escape analysis. noescape is
// the identity function but escape analysis doesn't think the
// output depends on the input. noescape is inlined and currently
// compiles down to zero instructions.
// USE CAREFULLY!
// This was copied from the runtime; see issues 23382 and 7921.
//go:nosplit
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
return unsafe.Pointer(x ^ 0)
}
func (b *StringBuilder) copyCheck() {
if b.addr == nil {
// This hack works around a failing of Go's escape analysis
// that was causing b to escape and be heap allocated.
// See issue 23382.
// TODO: once issue 7921 is fixed, this should be reverted to
// just "b.addr = b".
b.addr = (*StringBuilder)(noescape(unsafe.Pointer(b)))
} else if b.addr != b {
panic("strings: illegal use of non-zero Builder copied by value")
}
}
// String returns the accumulated string.
func (b *StringBuilder) String() string {
return *(*string)(unsafe.Pointer(&b.buf))
}
// Len returns the number of accumulated bytes; b.Len() == len(b.String()).
func (b *StringBuilder) Len() int { return len(b.buf) }
// Reset resets the Builder to be empty.
func (b *StringBuilder) Reset() {
b.addr = nil
b.buf = nil
}
// grow copies the buffer to a new, larger buffer so that there are at least n
// bytes of capacity beyond len(b.buf).
func (b *StringBuilder) grow(n int) {
buf := make([]byte, len(b.buf), 2*cap(b.buf)+n)
copy(buf, b.buf)
b.buf = buf
}
// Grow grows b's capacity, if necessary, to guarantee space for
// another n bytes. After Grow(n), at least n bytes can be written to b
// without another allocation. If n is negative, Grow panics.
func (b *StringBuilder) Grow(n int) {
b.copyCheck()
if n < 0 {
panic("strings.Builder.Grow: negative count")
}
if cap(b.buf)-len(b.buf) < n {
b.grow(n)
}
}
// Write appends the contents of p to b's buffer.
// Write always returns len(p), nil.
func (b *StringBuilder) Write(p []byte) (int, error) {
b.copyCheck()
b.buf = append(b.buf, p...)
return len(p), nil
}
// WriteByte appends the byte c to b's buffer.
// The returned error is always nil.
func (b *StringBuilder) WriteByte(c byte) error {
b.copyCheck()
b.buf = append(b.buf, c)
return nil
}
// WriteRune appends the UTF-8 encoding of Unicode code point r to b's buffer.
// It returns the length of r and a nil error.
func (b *StringBuilder) WriteRune(r rune) (int, error) {
b.copyCheck()
if r < utf8.RuneSelf {
b.buf = append(b.buf, byte(r))
return 1, nil
}
l := len(b.buf)
if cap(b.buf)-l < utf8.UTFMax {
b.grow(utf8.UTFMax)
}
n := utf8.EncodeRune(b.buf[l:l+utf8.UTFMax], r)
b.buf = b.buf[:l+n]
return n, nil
}
// WriteString appends the contents of s to b's buffer.
// It returns the length of s and a nil error.
func (b *StringBuilder) WriteString(s string) (int, error) {
b.copyCheck()
b.buf = append(b.buf, s...)
return len(s), nil
}

View File

@ -147,12 +147,12 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
}
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
} else {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
}
if !fieldValue.IsValid() {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
return &fieldValue, nil

57
vendor/github.com/go-xorm/core/db.go generated vendored
View File

@ -7,6 +7,11 @@ import (
"fmt"
"reflect"
"regexp"
"sync"
)
var (
DefaultCacheSize = 200
)
func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
@ -58,9 +63,16 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error)
return query, args, nil
}
type cacheStruct struct {
value reflect.Value
idx int
}
type DB struct {
*sql.DB
Mapper IMapper
Mapper IMapper
reflectCache map[reflect.Type]*cacheStruct
reflectCacheMutex sync.RWMutex
}
func Open(driverName, dataSourceName string) (*DB, error) {
@ -68,11 +80,32 @@ func Open(driverName, dataSourceName string) (*DB, error) {
if err != nil {
return nil, err
}
return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil
return &DB{
DB: db,
Mapper: NewCacheMapper(&SnakeMapper{}),
reflectCache: make(map[reflect.Type]*cacheStruct),
}, nil
}
func FromDB(db *sql.DB) *DB {
return &DB{db, NewCacheMapper(&SnakeMapper{})}
return &DB{
DB: db,
Mapper: NewCacheMapper(&SnakeMapper{}),
reflectCache: make(map[reflect.Type]*cacheStruct),
}
}
func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
db.reflectCacheMutex.Lock()
defer db.reflectCacheMutex.Unlock()
cs, ok := db.reflectCache[typ]
if !ok || cs.idx+1 > DefaultCacheSize-1 {
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0}
db.reflectCache[typ] = cs
} else {
cs.idx = cs.idx + 1
}
return cs.value.Index(cs.idx).Addr()
}
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
@ -83,7 +116,7 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
}
return nil, err
}
return &Rows{rows, db.Mapper}, nil
return &Rows{rows, db}, nil
}
func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
@ -128,8 +161,8 @@ func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
type Stmt struct {
*sql.Stmt
Mapper IMapper
names map[string]int
db *DB
names map[string]int
}
func (db *DB) Prepare(query string) (*Stmt, error) {
@ -145,7 +178,7 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
if err != nil {
return nil, err
}
return &Stmt{stmt, db.Mapper, names}, nil
return &Stmt{stmt, db, names}, nil
}
func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
@ -179,7 +212,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{rows, s.Mapper}, nil
return &Rows{rows, s.db}, nil
}
func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
@ -274,7 +307,7 @@ func (EmptyScanner) Scan(src interface{}) error {
type Tx struct {
*sql.Tx
Mapper IMapper
db *DB
}
func (db *DB) Begin() (*Tx, error) {
@ -282,7 +315,7 @@ func (db *DB) Begin() (*Tx, error) {
if err != nil {
return nil, err
}
return &Tx{tx, db.Mapper}, nil
return &Tx{tx, db}, nil
}
func (tx *Tx) Prepare(query string) (*Stmt, error) {
@ -298,7 +331,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
if err != nil {
return nil, err
}
return &Stmt{stmt, tx.Mapper, names}, nil
return &Stmt{stmt, tx.db, names}, nil
}
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
@ -327,7 +360,7 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{rows, tx.Mapper}, nil
return &Rows{rows, tx.db}, nil
}
func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {

View File

@ -74,6 +74,7 @@ type Dialect interface {
GetIndexes(tableName string) (map[string]*Index, error)
Filters() []Filter
SetParams(params map[string]string)
}
func OpenDialect(dialect Dialect) (*DB, error) {
@ -148,7 +149,8 @@ func (db *Base) SupportDropIfExists() bool {
}
func (db *Base) DropTableSql(tableName string) string {
return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName)
quote := db.dialect.Quote
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))
}
func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) {
@ -289,6 +291,9 @@ func (b *Base) LogSQL(sql string, args []interface{}) {
}
}
func (b *Base) SetParams(params map[string]string) {
}
var (
dialects = map[string]func() Dialect{}
)

View File

@ -37,9 +37,9 @@ func (q *Quoter) Quote(content string) string {
func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string {
quoter := NewQuoter(dialect)
if table != nil && len(table.PrimaryKeys) == 1 {
sql = strings.Replace(sql, "`(id)`", quoter.Quote(table.PrimaryKeys[0]), -1)
sql = strings.Replace(sql, quoter.Quote("(id)"), quoter.Quote(table.PrimaryKeys[0]), -1)
return strings.Replace(sql, "(id)", quoter.Quote(table.PrimaryKeys[0]), -1)
sql = strings.Replace(sql, " `(id)` ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
sql = strings.Replace(sql, " "+quoter.Quote("(id)")+" ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
return strings.Replace(sql, " (id) ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
}
return sql
}

View File

@ -22,6 +22,8 @@ type Index struct {
func (index *Index) XName(tableName string) string {
if !strings.HasPrefix(index.Name, "UQE_") &&
!strings.HasPrefix(index.Name, "IDX_") {
tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)
if index.Type == UniqueType {
return fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
}

View File

@ -9,7 +9,7 @@ import (
type Rows struct {
*sql.Rows
Mapper IMapper
db *DB
}
func (rs *Rows) ToMapString() ([]map[string]string, error) {
@ -105,7 +105,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error {
newDest := make([]interface{}, len(cols))
var v EmptyScanner
for j, name := range cols {
f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name))
f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name))
if f.IsValid() {
newDest[j] = f.Addr().Interface()
} else {
@ -116,36 +116,6 @@ func (rs *Rows) ScanStructByName(dest interface{}) error {
return rs.Rows.Scan(newDest...)
}
type cacheStruct struct {
value reflect.Value
idx int
}
var (
reflectCache = make(map[reflect.Type]*cacheStruct)
reflectCacheMutex sync.RWMutex
)
func ReflectNew(typ reflect.Type) reflect.Value {
reflectCacheMutex.RLock()
cs, ok := reflectCache[typ]
reflectCacheMutex.RUnlock()
const newSize = 200
if !ok || cs.idx+1 > newSize-1 {
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0}
reflectCacheMutex.Lock()
reflectCache[typ] = cs
reflectCacheMutex.Unlock()
} else {
reflectCacheMutex.Lock()
cs.idx = cs.idx + 1
reflectCacheMutex.Unlock()
}
return cs.value.Index(cs.idx).Addr()
}
// scan data to a slice's pointer, slice's length should equal to columns' number
func (rs *Rows) ScanSlice(dest interface{}) error {
vv := reflect.ValueOf(dest)
@ -197,9 +167,7 @@ func (rs *Rows) ScanMap(dest interface{}) error {
vvv := vv.Elem()
for i, _ := range cols {
newDest[i] = ReflectNew(vvv.Type().Elem()).Interface()
//v := reflect.New(vvv.Type().Elem())
//newDest[i] = v.Interface()
newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface()
}
err = rs.Rows.Scan(newDest...)
@ -215,32 +183,6 @@ func (rs *Rows) ScanMap(dest interface{}) error {
return nil
}
/*func (rs *Rows) ScanMap(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return errors.New("dest should be a map's pointer")
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
err = rs.ScanSlice(newDest)
if err != nil {
return err
}
vvv := vv.Elem()
for i, name := range cols {
vname := reflect.ValueOf(name)
vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
}
return nil
}*/
type Row struct {
rows *Rows
// One of these two will be non-nil:

View File

@ -49,7 +49,6 @@ func NewTable(name string, t reflect.Type) *Table {
}
func (table *Table) columnsByName(name string) []*Column {
n := len(name)
for k := range table.columnsMap {
@ -75,7 +74,6 @@ func (table *Table) GetColumn(name string) *Column {
}
func (table *Table) GetColumnIdx(name string, idx int) *Column {
cols := table.columnsByName(name)
if cols != nil && idx < len(cols) {

View File

@ -69,15 +69,18 @@ var (
Enum = "ENUM"
Set = "SET"
Char = "CHAR"
Varchar = "VARCHAR"
NVarchar = "NVARCHAR"
TinyText = "TINYTEXT"
Text = "TEXT"
Clob = "CLOB"
MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT"
Uuid = "UUID"
Char = "CHAR"
Varchar = "VARCHAR"
NVarchar = "NVARCHAR"
TinyText = "TINYTEXT"
Text = "TEXT"
NText = "NTEXT"
Clob = "CLOB"
MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT"
Uuid = "UUID"
UniqueIdentifier = "UNIQUEIDENTIFIER"
SysName = "SYSNAME"
Date = "DATE"
DateTime = "DATETIME"
@ -128,10 +131,12 @@ var (
NVarchar: TEXT_TYPE,
TinyText: TEXT_TYPE,
Text: TEXT_TYPE,
NText: TEXT_TYPE,
MediumText: TEXT_TYPE,
LongText: TEXT_TYPE,
Uuid: TEXT_TYPE,
Clob: TEXT_TYPE,
SysName: TEXT_TYPE,
Date: TIME_TYPE,
DateTime: TIME_TYPE,
@ -148,11 +153,12 @@ var (
Binary: BLOB_TYPE,
VarBinary: BLOB_TYPE,
TinyBlob: BLOB_TYPE,
Blob: BLOB_TYPE,
MediumBlob: BLOB_TYPE,
LongBlob: BLOB_TYPE,
Bytea: BLOB_TYPE,
TinyBlob: BLOB_TYPE,
Blob: BLOB_TYPE,
MediumBlob: BLOB_TYPE,
LongBlob: BLOB_TYPE,
Bytea: BLOB_TYPE,
UniqueIdentifier: BLOB_TYPE,
Bool: NUMERIC_TYPE,
@ -289,9 +295,9 @@ func SQLType2Type(st SQLType) reflect.Type {
return reflect.TypeOf(float32(1))
case Double:
return reflect.TypeOf(float64(1))
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob:
case Char, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName:
return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier:
return reflect.TypeOf([]byte{})
case Bool:
return reflect.TypeOf(true)

View File

@ -172,12 +172,33 @@ type mysql struct {
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
rowFormat string
}
func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName)
}
func (db *mysql) SetParams(params map[string]string) {
rowFormat, ok := params["rowFormat"]
if ok {
var t = strings.ToUpper(rowFormat)
switch t {
case "COMPACT":
fallthrough
case "REDUNDANT":
fallthrough
case "DYNAMIC":
fallthrough
case "COMPRESSED":
db.rowFormat = t
break
default:
break
}
}
}
func (db *mysql) SqlType(c *core.Column) string {
var res string
switch t := c.SQLType.Name; t {
@ -487,6 +508,59 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
return indexes, nil
}
func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
sql += db.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db)
} else {
sql += col.StringNoPk(db)
}
sql = strings.TrimSpace(sql)
if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
if storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if len(charset) == 0 {
charset = db.URI().Charset
} else if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
if db.rowFormat != "" {
sql += " ROW_FORMAT=" + db.rowFormat
}
return sql
}
func (db *mysql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}}
}

View File

@ -769,14 +769,21 @@ var (
DefaultPostgresSchema = "public"
)
const postgresPublicSchema = "public"
type postgres struct {
core.Base
schema string
}
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
db.schema = DefaultPostgresSchema
return db.Base.Init(d, db, uri, drivername, dataSourceName)
err := db.Base.Init(d, db, uri, drivername, dataSourceName)
if err != nil {
return err
}
if db.Schema == "" {
db.Schema = DefaultPostgresSchema
}
return nil
}
func (db *postgres) SqlType(c *core.Column) string {
@ -873,32 +880,42 @@ func (db *postgres) IndexOnTable() bool {
}
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName}
if len(db.Schema) == 0 {
args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
}
args := []interface{}{db.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` +
`WHERE tablename = ? AND indexname = ?`, args
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
}
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
if len(db.Schema) == 0 {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
}
args := []interface{}{db.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
}
/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName}
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args
}*/
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col))
if len(db.Schema) == 0 {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col))
}
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.Schema, tableName, col.Name, db.SqlType(col))
}
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
//var unique string
quote := db.Quote
idxName := index.Name
tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)
if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType {
@ -907,13 +924,21 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
}
if db.Uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName))
}
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
args := []interface{}{tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
args := []interface{}{db.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3"
if len(db.Schema) == 0 {
args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
}
db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...)
@ -926,8 +951,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
}
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
// FIXME: the schema should be replaced by user custom's
args := []interface{}{tableName, db.schema}
args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -938,7 +962,15 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid