package testfixtures import ( "database/sql" "fmt" ) // SQLServer is the helper for SQL Server for this package. // SQL Server >= 2008 is required. type SQLServer struct { baseHelper tables []string } func (h *SQLServer) init(db *sql.DB) error { var err error h.tables, err = h.getTables(db) if err != nil { return err } return nil } func (*SQLServer) paramType() int { return paramTypeQuestion } func (*SQLServer) quoteKeyword(str string) string { return fmt.Sprintf("[%s]", str) } func (*SQLServer) databaseName(db *sql.DB) (dbname string) { db.QueryRow("SELECT DB_NAME()").Scan(&dbname) return } func (*SQLServer) getTables(db *sql.DB) ([]string, error) { rows, err := db.Query("SELECT table_name FROM information_schema.tables") if err != nil { return nil, err } tables := make([]string, 0) defer rows.Close() for rows.Next() { var table string rows.Scan(&table) tables = append(tables, table) } return tables, nil } func (*SQLServer) tableHasIdentityColumn(tx *sql.Tx, tableName string) bool { sql := ` SELECT COUNT(*) FROM SYS.IDENTITY_COLUMNS WHERE OBJECT_NAME(OBJECT_ID) = ? ` var count int tx.QueryRow(sql, tableName).Scan(&count) return count > 0 } func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) error { if h.tableHasIdentityColumn(tx, tableName) { defer tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName))) _, err := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", h.quoteKeyword(tableName))) if err != nil { return err } } return fn() } func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error { // ensure the triggers are re-enable after all defer func() { sql := "" for _, table := range h.tables { sql += fmt.Sprintf("ALTER TABLE %s WITH CHECK CHECK CONSTRAINT ALL;", h.quoteKeyword(table)) } if _, err := db.Exec(sql); err != nil { fmt.Printf("Error on re-enabling constraints: %v\n", err) } }() sql := "" for _, table := range h.tables { sql += fmt.Sprintf("ALTER TABLE %s NOCHECK CONSTRAINT ALL;", h.quoteKeyword(table)) } if _, err := db.Exec(sql); err != nil { return err } tx, err := db.Begin() if err != nil { return err } if err = loadFn(tx); err != nil { tx.Rollback() return err } return tx.Commit() }