122 lines
2.9 KiB
Go
122 lines
2.9 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"regexp"
|
|
"sort"
|
|
"time"
|
|
|
|
"git.example.com/example/goserv/assets/configfs"
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
// pq injects itself into sql as 'postgres'
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// DB is a concurrency-safe db connection instance
|
|
var DB *sqlx.DB
|
|
var firstDBURL PleaseDoubleCheckTheDatabaseURLDontDropProd
|
|
|
|
// Init initializes the database
|
|
func Init(pgURL string) error {
|
|
// https://godoc.org/github.com/lib/pq
|
|
|
|
firstDBURL = PleaseDoubleCheckTheDatabaseURLDontDropProd(pgURL)
|
|
dbtype := "postgres"
|
|
|
|
ctx, done := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
|
defer done()
|
|
db, err := sql.Open(dbtype, pgURL)
|
|
if err := db.PingContext(ctx); nil != err {
|
|
return err
|
|
}
|
|
|
|
// basic stuff
|
|
f, err := configfs.Assets.Open("./postgres/init.sql")
|
|
if nil != err {
|
|
return err
|
|
}
|
|
sqlBytes, err := ioutil.ReadAll(f)
|
|
if nil != err {
|
|
return err
|
|
}
|
|
if _, err := db.ExecContext(ctx, string(sqlBytes)); nil != err {
|
|
return err
|
|
}
|
|
|
|
// project-specific stuff
|
|
f, err = configfs.Assets.Open("./postgres/tables.sql")
|
|
if nil != err {
|
|
return err
|
|
}
|
|
sqlBytes, err = ioutil.ReadAll(f)
|
|
if nil != err {
|
|
return err
|
|
}
|
|
if _, err := db.ExecContext(ctx, string(sqlBytes)); nil != err {
|
|
return err
|
|
}
|
|
|
|
DB = sqlx.NewDb(db, dbtype)
|
|
|
|
return nil
|
|
}
|
|
|
|
// PleaseDoubleCheckTheDatabaseURLDontDropProd is a database URL wrapped in a
// deliberately loud type name. It is a friendly reminder to use such URLs only
// in test files and never to point them at the production database.
// DropAllTables takes this type, so a string variable has to be converted
// explicitly before it can be passed.
type PleaseDoubleCheckTheDatabaseURLDontDropProd string
|
|
|
|
// DropAllTables runs drop.sql, which is intended only for tests
|
|
func DropAllTables(dbURL PleaseDoubleCheckTheDatabaseURLDontDropProd) error {
|
|
if err := CanDropAllTables(string(dbURL)); nil != err {
|
|
return err
|
|
}
|
|
|
|
// drop stuff
|
|
f, err := configfs.Assets.Open("./postgres/drop.sql")
|
|
if nil != err {
|
|
return err
|
|
}
|
|
sqlBytes, err := ioutil.ReadAll(f)
|
|
if nil != err {
|
|
return err
|
|
}
|
|
ctx, done := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second))
|
|
defer done()
|
|
if _, err := DB.ExecContext(ctx, string(sqlBytes)); nil != err {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// nonAlphaRE matches any character that is not an ASCII letter. It is compiled
// once at package scope instead of on every CanDropAllTables call.
var nonAlphaRE = regexp.MustCompile(`[^a-zA-Z]`)

// CanDropAllTables returns nil if dbURL contains the exact word "test" or
// "demo" at a letter boundary. The word must be a complete run of ASCII
// letters, delimited by non-letters or the ends of the string, as in
// "/test2/db_demo1". The comparison is case-sensitive, so "TEST", "testing"
// and "demos" do not qualify. Otherwise it returns an error that quotes dbURL.
func CanDropAllTables(dbURL string) error {
	words := nonAlphaRE.Split(dbURL, -1)
	sort.Strings(words)

	for _, needle := range []string{"test", "demo"} {
		// SearchStrings yields needle's index if present, otherwise the
		// position it would be inserted at (possibly len(words)).
		if i := sort.SearchStrings(words, needle); i < len(words) && words[i] == needle {
			return nil
		}
	}

	return fmt.Errorf(
		"test and demo database URLs must contain the word 'test' or 'demo' "+
			"separated by a non-alphabet character, such as /test2/db_demo1\n%q\n",
		dbURL,
	)
}
|