package api

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	mathrand "math/rand"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"strings"
	"testing"

	"git.example.com/example/goserv/internal/db"
	"git.rootprojects.org/root/keypairs"
	"git.rootprojects.org/root/keypairs/keyfetch"
	"github.com/go-chi/chi"
)

var (
	// srv is the in-process HTTP server wrapping the API router under test.
	srv *httptest.Server

	// testKey is generated once per run; its public half (testPub) is what
	// the router is initialized with.
	testKey keypairs.PrivateKey
	testPub keypairs.PublicKey

	// testWhitelist is a keyfetch whitelist containing only srv.URL.
	testWhitelist keyfetch.Whitelist
)

func init() {
	// In tests it's nice to get the same "random" values, every time
	RandReader = testReader{}
	mathrand.Seed(0)
}

// TestMain prepares the shared test environment and runs the package's tests.
//
// It connects to the database named by TEST_DATABASE_URL, drops all tables,
// initializes the database again, generates a key pair, starts an httptest
// server around the API router, and builds testWhitelist from that server's
// URL. Any setup failure aborts the run via log.Fatal.
func TestMain(m *testing.M) {
	connStr := testWithSSLMode(needsTestDB(m))

	if err := db.Init(connStr); nil != err {
		log.Fatalf("db connection error: %v", err)
	}
	if err := db.DropAllTables(db.PleaseDoubleCheckTheDatabaseURLDontDropProd(connStr)); nil != err {
		log.Fatal(err)
	}
	// NOTE(review): Init is called a second time after the drop, presumably to
	// recreate the schema — confirm against db.Init.
	if err := db.Init(connStr); nil != err {
		log.Fatalf("db connection error: %v", err)
	}

	testKey = keypairs.NewDefaultPrivateKey()
	testPub = keypairs.NewPublicKey(testKey.Public())

	r := chi.NewRouter()
	srv = httptest.NewServer(Init(testPub, r))

	var err error
	testWhitelist, err = keyfetch.NewWhitelist(nil, []string{srv.URL})
	if nil != err {
		srv.Close()
		log.Fatalf("bad whitelist: %v", err)
	}

	// os.Exit does not run deferred calls, so shut the server down explicitly.
	code := m.Run()
	srv.Close()
	os.Exit(code)
}

// testWithSSLMode appends an sslmode parameter to a Postgres connection URL.
//
// URLs pointing at localhost ("@localhost/" or "@localhost:") get
// "sslmode=disable"; all others get "sslmode=require" (the valid libpq value —
// "required" is rejected by the driver). If connStr already specifies an
// sslmode, it is returned unchanged. "&" is used as the separator when the URL
// already has a query string.
func testWithSSLMode(connStr string) string {
	if strings.Contains(connStr, "sslmode=") {
		return connStr
	}

	mode := "require"
	if strings.Contains(connStr, "@localhost/") || strings.Contains(connStr, "@localhost:") {
		mode = "disable"
	}

	sep := "?"
	if strings.Contains(connStr, "?") {
		sep = "&"
	}
	return connStr + sep + "sslmode=" + mode
}

// public APIs

func Test_Public_Ping(t *testing.T) {
	if err := testPing("public"); nil != err {
		t.Fatal(err)
	}
}

// test types

// testReader is an io.Reader backed by math/rand. init installs it as
// RandReader, and math/rand is seeded with 0 there, so reads are repeatable.
type testReader struct{}

// Read fills p from math/rand.
func (testReader) Read(p []byte) (n int, err error) {
	return mathrand.Read(p)
}

// testPing GETs /api/{which}/ping (expecting HTTP 200) and validates the JSON
// body: "success" must be the boolean true, and "ppid" must be absent, empty,
// or not a string.
func testPing(which string) error {
	urlstr := fmt.Sprintf("/api/%s/ping", which)
	res, err := testReq("GET", urlstr, "", nil, 200)
	if nil != err {
		return err
	}
	defer res.Body.Close()

	data := map[string]interface{}{}
	if err := json.NewDecoder(res.Body).Decode(&data); nil != err {
		return err
	}

	if success, ok := data["success"].(bool); !ok || !success {
		log.Printf("Bad Response\n\tURL:%s\n\tBody:\n%#v", urlstr, data)
		return errors.New("bad response: missing success")
	}

	// NOTE(review): the message mentions the effective user ID while the field
	// checked is "ppid" — confirm which value the endpoint is meant to report.
	if ppid, _ := data["ppid"].(string); "" != ppid {
		return fmt.Errorf("the effective user ID isn't what it should be: %q != %q", ppid, "")
	}
	return nil
}

// testReq sends a request for pathname to srv.
//
// An empty method defaults to "GET". A non-empty jwt is sent as a Bearer
// Authorization header, and payload (may be nil) is sent as the body. When
// expectedStatus > 0 and the response status differs, the response body is
// logged, the body is closed, and an error is returned. On success the caller
// owns the returned response and must close res.Body.
func testReq(method, pathname string, jwt string, payload []byte, expectedStatus int) (*http.Response, error) {
	client := srv.Client()

	urlstr, err := url.Parse(srv.URL + pathname)
	if nil != err {
		return nil, err
	}
	if "" == method {
		method = "GET"
	}

	req := &http.Request{
		Method:        method,
		URL:           urlstr,
		Body:          ioutil.NopCloser(bytes.NewReader(payload)),
		ContentLength: int64(len(payload)),
		Header:        http.Header{},
	}
	if len(jwt) > 0 {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwt))
	}

	res, err := client.Do(req)
	if nil != err {
		return nil, err
	}

	if expectedStatus > 0 && expectedStatus != res.StatusCode {
		// Best effort: the body is read only for the log line, so a read error
		// is not worth reporting over the status mismatch itself.
		data, _ := ioutil.ReadAll(res.Body)
		res.Body.Close()
		log.Printf("Bad Response: %d\n\tURL:%s\n\tBody:\n%s", res.StatusCode, urlstr, string(data))
		return nil, fmt.Errorf("bad status code: %d (expected %d)", res.StatusCode, expectedStatus)
	}
	return res, nil
}

// needsTestDB returns the TEST_DATABASE_URL environment variable. If it is
// unset or empty, it exits via log.Fatal with setup instructions. m is
// currently unused.
func needsTestDB(m *testing.M) string {
	connStr := os.Getenv("TEST_DATABASE_URL")
	if "" == connStr {
		log.Fatal(`no connection string defined

You must set TEST_DATABASE_URL to run db tests.

You may find this helpful:

    psql 'postgres://postgres:postgres@localhost:5432/postgres'

    DROP DATABASE IF EXISTS postgres_test;
    CREATE DATABASE postgres_test;
    \q

Then your test database URL will be

    export TEST_DATABASE_URL=postgres://postgres:postgres@localhost:5432/postgres_test
`)
	}
	return connStr
}