165 lines
3.6 KiB
Go
165 lines
3.6 KiB
Go
|
package api
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io/ioutil"
|
||
|
"log"
|
||
|
mathrand "math/rand"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"net/url"
|
||
|
"os"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
|
||
|
"git.example.com/example/goserv/internal/db"
|
||
|
"git.rootprojects.org/root/keypairs"
|
||
|
"git.rootprojects.org/root/keypairs/keyfetch"
|
||
|
|
||
|
"github.com/go-chi/chi"
|
||
|
)
|
||
|
|
||
|
var srv *httptest.Server
|
||
|
|
||
|
var testKey keypairs.PrivateKey
|
||
|
var testPub keypairs.PublicKey
|
||
|
var testWhitelist keyfetch.Whitelist
|
||
|
|
||
|
func init() {
|
||
|
// In tests it's nice to get the same "random" values, every time
|
||
|
RandReader = testReader{}
|
||
|
mathrand.Seed(0)
|
||
|
}
|
||
|
|
||
|
func TestMain(m *testing.M) {
|
||
|
connStr := needsTestDB(m)
|
||
|
if strings.Contains(connStr, "@localhost/") || strings.Contains(connStr, "@localhost:") {
|
||
|
connStr += "?sslmode=disable"
|
||
|
} else {
|
||
|
connStr += "?sslmode=required"
|
||
|
}
|
||
|
|
||
|
if err := db.Init(connStr); nil != err {
|
||
|
log.Fatal("db connection error", err)
|
||
|
return
|
||
|
}
|
||
|
if err := db.DropAllTables(db.PleaseDoubleCheckTheDatabaseURLDontDropProd(connStr)); nil != err {
|
||
|
log.Fatal(err)
|
||
|
}
|
||
|
if err := db.Init(connStr); nil != err {
|
||
|
log.Fatal("db connection error", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var err error
|
||
|
testKey = keypairs.NewDefaultPrivateKey()
|
||
|
testPub = keypairs.NewPublicKey(testKey.Public())
|
||
|
r := chi.NewRouter()
|
||
|
srv = httptest.NewServer(Init(testPub, r))
|
||
|
testWhitelist, err = keyfetch.NewWhitelist(nil, []string{srv.URL})
|
||
|
if nil != err {
|
||
|
log.Fatal("bad whitelist", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
os.Exit(m.Run())
|
||
|
}
|
||
|
|
||
|
// public APIs
|
||
|
|
||
|
func Test_Public_Ping(t *testing.T) {
|
||
|
if err := testPing("public"); nil != err {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// test types
|
||
|
|
||
|
// testReader is a Read implementation backed by the global math/rand source.
// Because init seeds that source with a fixed value, the bytes it produces
// repeat from run to run.
type testReader struct{}

// Read fills p with pseudo-random bytes from math/rand and returns the
// results of mathrand.Read unchanged.
func (testReader) Read(buf []byte) (int, error) {
	return mathrand.Read(buf)
}
|
||
|
|
||
|
func testPing(which string) error {
|
||
|
urlstr := fmt.Sprintf("/api/%s/ping", which)
|
||
|
res, err := testReq("GET", urlstr, "", nil, 200)
|
||
|
if nil != err {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
data := map[string]interface{}{}
|
||
|
if err := json.NewDecoder(res.Body).Decode(&data); nil != err {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if success, ok := data["success"].(bool); !ok || !success {
|
||
|
log.Printf("Bad Response\n\tURL:%s\n\tBody:\n%#v", urlstr, data)
|
||
|
return errors.New("bad response: missing success")
|
||
|
}
|
||
|
|
||
|
if ppid, _ := data["ppid"].(string); "" != ppid {
|
||
|
return fmt.Errorf("the effective user ID isn't what it should be: %q != %q", ppid, "")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func testReq(method, pathname string, jwt string, payload []byte, expectedStatus int) (*http.Response, error) {
|
||
|
client := srv.Client()
|
||
|
urlstr, _ := url.Parse(srv.URL + pathname)
|
||
|
|
||
|
if "" == method {
|
||
|
method = "GET"
|
||
|
}
|
||
|
|
||
|
req := &http.Request{
|
||
|
Method: method,
|
||
|
URL: urlstr,
|
||
|
Body: ioutil.NopCloser(bytes.NewReader(payload)),
|
||
|
Header: http.Header{},
|
||
|
}
|
||
|
|
||
|
if len(jwt) > 0 {
|
||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwt))
|
||
|
}
|
||
|
res, err := client.Do(req)
|
||
|
if nil != err {
|
||
|
return nil, err
|
||
|
}
|
||
|
if expectedStatus > 0 {
|
||
|
if expectedStatus != res.StatusCode {
|
||
|
data, _ := ioutil.ReadAll(res.Body)
|
||
|
log.Printf("Bad Response: %d\n\tURL:%s\n\tBody:\n%s", res.StatusCode, urlstr, string(data))
|
||
|
return nil, fmt.Errorf("bad status code: %d", res.StatusCode)
|
||
|
}
|
||
|
}
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
// needsTestDB returns the database URL from the TEST_DATABASE_URL environment
// variable. If the variable is unset or empty, it exits the process via
// log.Fatal with instructions for creating a local test database.
//
// The *testing.M argument is not used.
func needsTestDB(m *testing.M) string {
	const envName = "TEST_DATABASE_URL"

	connStr := os.Getenv(envName)
	if connStr != "" {
		return connStr
	}

	log.Fatal(`no connection string defined

You must set TEST_DATABASE_URL to run db tests.

You may find this helpful:

psql 'postgres://postgres:postgres@localhost:5432/postgres'

DROP DATABASE IF EXISTS postgres_test;
CREATE DATABASE postgres_test;
\q

Then your test database URL will be

export TEST_DATABASE_URL=postgres://postgres:postgres@localhost:5432/postgres_test
`)
	// Not reached: log.Fatal exits the process.
	return connStr
}
|