package db

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io/ioutil"
	"regexp"
	"sort"
	"time"

	"git.example.com/example/goserv/assets/configfs"
	"github.com/jmoiron/sqlx"

	// pq injects itself into sql as 'postgres'
	_ "github.com/lib/pq"
)

// DB is the shared connection pool installed by a successful Init.
// The underlying *sql.DB is safe for concurrent use; reassigning DB itself is not synchronized.
var DB *sqlx.DB

// firstDBURL is the URL DB was opened with by the most recent successful Init.
// DropAllTables checks it so that the database actually being dropped is vetted,
// not just the URL the caller claims.
var firstDBURL PleaseDoubleCheckTheDatabaseURLDontDropProd

// Init opens the PostgreSQL database at pgURL, pings it, and runs the embedded
// ./postgres/init.sql (basic setup) and ./postgres/tables.sql (project-specific
// schema) scripts in that order. The ping and both scripts share one 5 second deadline.
//
// On success the connection is stored in DB and pgURL in firstDBURL. On any error the
// new connection is closed and both package variables are left untouched.
//
// See https://godoc.org/github.com/lib/pq for the accepted URL formats.
func Init(pgURL string) error {
	const dbtype = "postgres"

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	db, err := sql.Open(dbtype, pgURL)
	if err != nil {
		return err
	}

	if err := db.PingContext(ctx); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = db.Close()
		return err
	}

	if err := execSQLAssets(ctx, db, "./postgres/init.sql", "./postgres/tables.sql"); err != nil {
		// Best-effort cleanup; the script error is the one worth reporting.
		_ = db.Close()
		return err
	}

	DB = sqlx.NewDb(db, dbtype)
	firstDBURL = PleaseDoubleCheckTheDatabaseURLDontDropProd(pgURL)
	return nil
}

// readSQLAsset returns the full contents of the embedded file at path and
// closes the file before returning.
func readSQLAsset(path string) ([]byte, error) {
	f, err := configfs.Assets.Open(path)
	if err != nil {
		return nil, err
	}
	// Read-only handle: a Close error cannot lose data, so it is ignored.
	defer f.Close()
	return ioutil.ReadAll(f)
}

// execSQLAssets reads each embedded SQL file in paths and executes it against db,
// in order, stopping at the first error.
func execSQLAssets(ctx context.Context, db *sql.DB, paths ...string) error {
	for _, path := range paths {
		sqlBytes, err := readSQLAsset(path)
		if err != nil {
			return err
		}
		if _, err := db.ExecContext(ctx, string(sqlBytes)); err != nil {
			return err
		}
	}
	return nil
}

// PleaseDoubleCheckTheDatabaseURLDontDropProd is a database URL. Its verbose name is
// a deliberate reminder: DropAllTables is meant only for test files and must never be
// pointed at the production database.
type PleaseDoubleCheckTheDatabaseURLDontDropProd string

// DropAllTables runs the embedded ./postgres/drop.sql against DB with a 1 second
// deadline. It is intended only for tests.
//
// It refuses to run, returning CanDropAllTables' error, unless dbURL names a test
// or demo database. When DB was opened by Init, the URL it was opened with must
// pass the same check. It returns an error if DB has not been initialized.
func DropAllTables(dbURL PleaseDoubleCheckTheDatabaseURLDontDropProd) error {
	if err := CanDropAllTables(string(dbURL)); err != nil {
		return err
	}
	if DB == nil {
		return errors.New("db: DropAllTables called before DB was initialized")
	}
	// dbURL is only what the caller claims. Also vet the database DB is really
	// connected to, so a "test" URL cannot be used to drop some other database.
	// firstDBURL is empty when DB was assigned directly rather than via Init.
	if firstDBURL != "" {
		if err := CanDropAllTables(string(firstDBURL)); err != nil {
			return err
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	return execSQLAssets(ctx, DB.DB, "./postgres/drop.sql")
}

// dbURLWordSplitter splits a database URL into runs of ASCII letters.
// It is compiled once here rather than on every CanDropAllTables call.
var dbURLWordSplitter = regexp.MustCompile(`[^a-zA-Z]`)

// CanDropAllTables returns nil if dbURL contains the word "test" or "demo" delimited
// by non-letter characters (or the ends of the string), e.g. "/test2/db_demo1".
// Matching is case-sensitive. Otherwise it returns an error quoting dbURL.
func CanDropAllTables(dbURL string) error {
	words := dbURLWordSplitter.Split(dbURL, -1)
	sort.Strings(words)
	for _, needle := range []string{"test", "demo"} {
		// SearchStrings gives the index where needle is, or would be inserted
		// (possibly len(words)), so an equality check confirms presence.
		if i := sort.SearchStrings(words, needle); i < len(words) && words[i] == needle {
			return nil
		}
	}
	return fmt.Errorf(
		"test and demo database URLs must contain the word 'test' or 'demo' "+
			"separated by a non-alphabet character, such as /test2/db_demo1\n%q\n",
		dbURL,
	)
}