From 190ee93da8a3d4ae27d93bbcffe6531aca59331c Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Fri, 10 Apr 2020 19:29:01 +0000 Subject: [PATCH] refactor --- .gitignore | 1 + go.mod | 5 + go.sum | 2 + mockid.go | 385 +---------------------------------- mockid/mockid.go | 460 ++++++++++++++++++++++++++++++++++++++++++ mockid/mockid_test.go | 13 ++ 6 files changed, 491 insertions(+), 375 deletions(-) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 mockid/mockid.go create mode 100644 mockid/mockid_test.go diff --git a/.gitignore b/.gitignore index 8275cf7..2ef6f89 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +/public-jwks /go-mockid # ---> Go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6f9f1a4 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.coolaj86.com/coolaj86/go-mockid + +go 1.12 + +require github.com/joho/godotenv v1.3.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ead7071 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= diff --git a/mockid.go b/mockid.go index c13125a..aece92b 100644 --- a/mockid.go +++ b/mockid.go @@ -1,46 +1,20 @@ package main import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/sha256" - "crypto/sha512" - "encoding/base64" "encoding/json" "flag" "fmt" - "io/ioutil" "log" - "math/big" "net/http" "net/url" "os" - "path/filepath" "strconv" - "strings" - "time" + + "git.coolaj86.com/coolaj86/go-mockid/mockid" + + _ "github.com/joho/godotenv/autoload" ) -type PrivateJWK struct { - PublicJWK - D string `json:"d"` -} -type PublicJWK struct { - Crv string `json:"crv"` - KeyID string `json:"kid,omitempty"` - Kty string `json:"kty,omitempty"` - X string `json:"x"` - Y string `json:"y"` -} - -var nonces map[string]int64 -var jwksPrefix string - -func init() { - nonces = make(map[string]int64) -} - func main() { done := make(chan bool) var port int @@ -52,17 +26,15 @@ func main() { "x": "ToL2HppsTESXQKvp7ED6NMgV4YnwbMeONexNry3KDNQ", "y": "Tt6Q3rxU37KAinUV9PLMlwosNy1t3Bf2VDg5q955AGc", } - jwk := &PrivateJWK{ - PublicJWK: PublicJWK{ + jwk := &mockid.PrivateJWK{ + PublicJWK: mockid.PublicJWK{ Crv: jwkm["crv"], X: jwkm["x"], Y: jwkm["y"], }, D: jwkm["d"], } - priv := parseKey(jwk) - pub := &priv.PublicKey - thumbprint := thumbprintKey(pub) + priv := mockid.ParseKey(jwk) portFlag := flag.Int("port", 0, "Port on which the HTTP server should run") urlFlag := flag.String("url", "", "Outward-facing address, such as https://example.com") @@ -86,6 +58,7 @@ func main() { host = "http://localhost:" + strconv.Itoa(port) } +var jwksPrefix string if nil != prefixFlag && "" != *prefixFlag { jwksPrefix = *prefixFlag } else { @@ -97,206 +70,7 @@ func main() { os.Exit(1) } - http.HandleFunc("/api/new-nonce", func(w http.ResponseWriter, r *http.Request) { - baseURL := getBaseURL(r) - /* - res.statusCode = 200; - res.setHeader("Cache-Control", "max-age=0, no-cache, no-store"); - // TODO - //res.setHeader("Date", "Sun, 10 Mar 2019 08:04:45 GMT"); - // is this the expiration of the nonce itself? methinks maybe so - //res.setHeader("Expires", "Sun, 10 Mar 2019 08:04:45 GMT"); - // TODO use one of the registered domains - //var indexUrl = "https://acme-staging-v02.api.letsencrypt.org/index" - */ - //var port = (state.config.ipc && state.config.ipc.port || state._ipc.port || undefined); - //var indexUrl = "http://localhost:" + port + "/index"; - indexUrl := baseURL + "/index"; - w.Header().Set("Link", "<" + indexUrl + ">;rel=\"index\""); - w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store"); - w.Header().Set("Pragma", "no-cache"); - //res.setHeader("Strict-Transport-Security", "max-age=604800"); - - w.Header().Set("X-Frame-Options", "DENY") - issueNonce(w, r) - }) - - http.HandleFunc("/api/new-account", requireNonce(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Not Implemented", http.StatusNotImplemented) - })) - - http.HandleFunc("/api/jwks", func(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path) - if "POST" != r.Method { - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return - } - - tok := make(map[string]interface{}) - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&tok) - if nil != err { - http.Error(w, "Bad Request: invalid json", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // TODO better, JSON error messages - if _, ok := tok["d"]; ok { - http.Error(w, "Bad Request: private key", http.StatusBadRequest) - return - } - - kty, _ := tok["kty"].(string) - if "EC" != kty { - http.Error(w, "Bad Request: only EC keys are supported", http.StatusBadRequest) - return - } - - crv, ok := tok["crv"].(string) - if 5 != len(crv) || "P-" != crv[:2] { - http.Error(w, "Bad Request: bad curve", http.StatusBadRequest) - return - } - x, ok := tok["x"].(string) - if !ok { - http.Error(w, "Bad Request: missing 'x'", http.StatusBadRequest) - return - } - y, ok := tok["y"].(string) - if !ok { - http.Error(w, "Bad Request: missing 'y'", http.StatusBadRequest) - return - } - - // TODO RSA - thumbprintable := []byte( - fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, crv, x, y), - ) - - var thumb []byte - switch crv[2:] { - case "256": - hash := sha256.Sum256(thumbprintable) - thumb = hash[:] - case "384": - hash := sha512.Sum384(thumbprintable) - thumb = hash[:] - case "521": - hash := sha512.Sum512(thumbprintable) - thumb = hash[:] - default: - http.Error(w, "Bad Request: bad curve", http.StatusBadRequest) - return - } - - kid := base64.RawURLEncoding.EncodeToString(thumb) - if kid2, _ := tok["kid"].(string); "" != kid2 && kid != kid2 { - http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest) - return - } - - // TODO allow posting at the top-level? - // TODO support a group of keys by PPID - // (right now it's only by KID) - if !strings.HasPrefix(r.Host, strings.ToLower(kid)+".") { - http.Error(w, "Bad Request: prefix should be "+kid, http.StatusBadRequest) - return - } - - pub := []byte(fmt.Sprintf( - `{"crv":%q,"kid":%q,"kty":"EC","x":%q,"y":%q}`, crv, kid, x, y, - )) - err = ioutil.WriteFile( - filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"), - pub, - 0644, - ) - if nil != err { - fmt.Println("can't write file") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - baseURL := getBaseURL(r) - w.Write([]byte(fmt.Sprintf( - `{ "iss":%q, "jwks_url":%q }`, baseURL+"/", baseURL+"/.well-known/jwks.json", - ))) - }) - - http.HandleFunc("/access_token", func(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s\n", r.Method, r.URL.Path) - _, _, token := genToken(getBaseURL(r), priv, r.URL.Query()) - fmt.Fprintf(w, token) - }) - - http.HandleFunc("/authorization_header", func(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s\n", r.Method, r.URL.Path) - - var header string - headers, _ := r.URL.Query()["header"] - if 0 == len(headers) { - header = "Authorization" - } else { - header = headers[0] - } - - var prefix string - prefixes, _ := r.URL.Query()["prefix"] - if 0 == len(prefixes) { - prefix = "Bearer " - } else { - prefix = prefixes[0] - } - - _, _, token := genToken(getBaseURL(r), priv, r.URL.Query()) - fmt.Fprintf(w, "%s: %s%s", header, prefix, token) - }) - - http.HandleFunc("/key.jwk.json", func(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s", r.Method, r.URL.Path) - fmt.Fprintf(w, `{ "kty": "EC" , "crv": %q , "d": %q , "x": %q , "y": %q , "ext": true , "key_ops": ["sign"] }`, jwk.Crv, jwk.D, jwk.X, jwk.Y) - }) - - http.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { - baseURL := getBaseURL(r) - log.Printf("%s %s\n", r.Method, r.URL.Path) - fmt.Fprintf(w, `{ "issuer": "%s", "jwks_uri": "%s/.well-known/jwks.json" }`, baseURL, baseURL) - }) - - http.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path) - parts := strings.Split(r.Host, ".") - kid := parts[0] - - b, err := ioutil.ReadFile(filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json")) - if nil != err { - //http.Error(w, "Not Found", http.StatusNotFound) - jwkstr := fmt.Sprintf( - `{ "keys": [ { "kty": "EC" , "crv": %q , "x": %q , "y": %q , "kid": %q , "ext": true , "key_ops": ["verify"] , "exp": %s } ] }`, - jwk.Crv, jwk.X, jwk.Y, thumbprint, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10), - ) - fmt.Println(jwkstr) - fmt.Fprintf(w, jwkstr) - return - } - - tok := &PublicJWK{} - err = json.Unmarshal(b, tok) - if nil != err { - // TODO delete the bad file? - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - jwkstr := fmt.Sprintf( - `{ "keys": [ { "kty": "EC", "crv": %q, "x": %q, "y": %q, "kid": %q,`+ - ` "ext": true, "key_ops": ["verify"], "exp": %s } ] }`, - tok.Crv, tok.X, tok.Y, tok.KeyID, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10), - ) - fmt.Println(jwkstr) - fmt.Fprintf(w, jwkstr) - }) + mockid.Route(jwksPrefix, priv, jwk) fs := http.FileServer(http.Dir("public")) http.Handle("/", fs) @@ -317,149 +91,10 @@ func main() { fmt.Printf("Private Key:\n\t%s\n", string(b)) b, _ = json.Marshal(jwk.PublicJWK) fmt.Printf("Public Key:\n\t%s\n", string(b)) - protected, payload, token := genToken(host, priv, url.Values{}) + protected, payload, token := mockid.GenToken(host, priv, url.Values{}) fmt.Printf("Protected (Header):\n\t%s\n", protected) fmt.Printf("Payload (Claims):\n\t%s\n", payload) fmt.Printf("Access Token:\n\t%s\n", token) <-done } - -func parseExp(exp string) (int, error) { - if "" == exp { - exp = "15m" - } - mult := 1 - switch exp[len(exp)-1] { - case 'w': - mult *= 7 - fallthrough - case 'd': - mult *= 24 - fallthrough - case 'h': - mult *= 60 - fallthrough - case 'm': - mult *= 60 - fallthrough - case 's': - // no fallthrough - default: - // could be 'k' or 'z', but we assume its empty - exp += "s" - } - - num, err := strconv.Atoi(exp[:len(exp)-1]) - if nil != err { - return 0, err - } - return num * mult, nil -} - -func genToken(host string, priv *ecdsa.PrivateKey, query url.Values) (string, string, string) { - thumbprint := thumbprintKey(&priv.PublicKey) - protected := fmt.Sprintf(`{"typ":"JWT","alg":"ES256","kid":"%s"}`, thumbprint) - protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected)) - - exp, err := parseExp(query.Get("exp")) - if nil != err { - // cryptic error code - // TODO propagate error - exp = 422 - } - - payload := fmt.Sprintf( - `{"iss":"%s/","sub":"dummy","exp":%s}`, - host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10), - ) - payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload)) - - hash := sha256.Sum256([]byte(fmt.Sprintf(`%s.%s`, protected64, payload64))) - r, s, _ := ecdsa.Sign(rand.Reader, priv, hash[:]) - rb := r.Bytes() - for len(rb) < 32 { - rb = append([]byte{0}, rb...) - } - sb := s.Bytes() - for len(rb) < 32 { - sb = append([]byte{0}, sb...) - } - sig64 := base64.RawURLEncoding.EncodeToString(append(rb, sb...)) - token := fmt.Sprintf(`%s.%s.%s`, protected64, payload64, sig64) - return protected, payload, token -} - -func parseKey(jwk *PrivateJWK) *ecdsa.PrivateKey { - xb, _ := base64.RawURLEncoding.DecodeString(jwk.X) - xi := &big.Int{} - xi.SetBytes(xb) - yb, _ := base64.RawURLEncoding.DecodeString(jwk.Y) - yi := &big.Int{} - yi.SetBytes(yb) - pub := &ecdsa.PublicKey{ - Curve: elliptic.P256(), - X: xi, - Y: yi, - } - - db, _ := base64.RawURLEncoding.DecodeString(jwk.D) - di := &big.Int{} - di.SetBytes(db) - priv := &ecdsa.PrivateKey{ - PublicKey: *pub, - D: di, - } - return priv -} - -func thumbprintKey(pub *ecdsa.PublicKey) string { - minpub := []byte(fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, "P-256", pub.X, pub.Y)) - sha := sha256.Sum256(minpub) - return base64.RawURLEncoding.EncodeToString(sha[:]) -} - -func issueNonce(w http.ResponseWriter, r *http.Request) { - b := make([]byte, 16) - _, _ = rand.Read(b) - nonce := base64.RawURLEncoding.EncodeToString(b); - nonces[nonce] = time.Now().Unix() - - w.Header().Set("Replay-Nonce", nonce); -} - - -func requireNonce(next http.HandlerFunc) http.HandlerFunc { - return func (w http.ResponseWriter, r *http.Request) { - nonce := r.Header.Get("Replay-Nonce") - // TODO expire nonces every so often - t := nonces[nonce] - if 0 == t { - http.Error( - w, - `{ "error": "invalid or expired nonce", "error_code": "ENONCE" }`, - http.StatusBadRequest, - ) - return - } - - delete(nonces, nonce) - issueNonce(w, r) - - next(w, r); - } -} - -func getBaseURL(r *http.Request) string { - var scheme string - if nil != r.TLS || "https" == r.Header.Get("X-Forwarded-Proto") { - scheme = "https:" - } else { - scheme = "http:" - } - return fmt.Sprintf( - "%s//%s", - scheme, - r.Host, - ) -} diff --git a/mockid/mockid.go b/mockid/mockid.go new file mode 100644 index 0000000..1e0548d --- /dev/null +++ b/mockid/mockid.go @@ -0,0 +1,460 @@ +package mockid + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "math/big" + "net/http" + "net/url" + "path/filepath" + "strconv" + "strings" + "time" +) + +type PrivateJWK struct { + PublicJWK + D string `json:"d"` +} + +type PublicJWK struct { + Crv string `json:"crv"` + KeyID string `json:"kid,omitempty"` + Kty string `json:"kty,omitempty"` + X string `json:"x"` + Y string `json:"y"` +} + +var nonces map[string]int64 + +func init() { + nonces = make(map[string]int64) +} + +func Route(jwksPrefix string , priv *ecdsa.PrivateKey, jwk *PrivateJWK) { + pub := &priv.PublicKey + thumbprint := thumbprintKey(pub) + + http.HandleFunc("/api/new-nonce", func(w http.ResponseWriter, r *http.Request) { + baseURL := getBaseURL(r) + /* + res.statusCode = 200; + res.setHeader("Cache-Control", "max-age=0, no-cache, no-store"); + // TODO + //res.setHeader("Date", "Sun, 10 Mar 2019 08:04:45 GMT"); + // is this the expiration of the nonce itself? methinks maybe so + //res.setHeader("Expires", "Sun, 10 Mar 2019 08:04:45 GMT"); + // TODO use one of the registered domains + //var indexUrl = "https://acme-staging-v02.api.letsencrypt.org/index" + */ + //var port = (state.config.ipc && state.config.ipc.port || state._ipc.port || undefined); + //var indexUrl = "http://localhost:" + port + "/index"; + indexUrl := baseURL + "/index" + w.Header().Set("Link", "<"+indexUrl+">;rel=\"index\"") + w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store") + w.Header().Set("Pragma", "no-cache") + //res.setHeader("Strict-Transport-Security", "max-age=604800"); + + w.Header().Set("X-Frame-Options", "DENY") + issueNonce(w, r) + }) + + http.HandleFunc("/api/new-account", requireNonce(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Not Implemented", http.StatusNotImplemented) + })) + + http.HandleFunc("/api/jwks", func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path) + if "POST" != r.Method { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + tok := make(map[string]interface{}) + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&tok) + if nil != err { + http.Error(w, "Bad Request: invalid json", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // TODO better, JSON error messages + if _, ok := tok["d"]; ok { + http.Error(w, "Bad Request: private key", http.StatusBadRequest) + return + } + + kty, _ := tok["kty"].(string) + switch kty { + case "EC": + postEC(jwksPrefix, tok, w, r) + case "RSA": + postRSA(jwksPrefix, tok, w, r) + default: + http.Error(w, "Bad Request: only EC and RSA keys are supported", http.StatusBadRequest) + return + } + }) + + http.HandleFunc("/access_token", func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s %s\n", r.Method, r.URL.Path) + _, _, token := GenToken(getBaseURL(r), priv, r.URL.Query()) + fmt.Fprintf(w, token) + }) + + http.HandleFunc("/authorization_header", func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s %s\n", r.Method, r.URL.Path) + + var header string + headers, _ := r.URL.Query()["header"] + if 0 == len(headers) { + header = "Authorization" + } else { + header = headers[0] + } + + var prefix string + prefixes, _ := r.URL.Query()["prefix"] + if 0 == len(prefixes) { + prefix = "Bearer " + } else { + prefix = prefixes[0] + } + + _, _, token := GenToken(getBaseURL(r), priv, r.URL.Query()) + fmt.Fprintf(w, "%s: %s%s", header, prefix, token) + }) + + http.HandleFunc("/key.jwk.json", func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s %s", r.Method, r.URL.Path) + fmt.Fprintf(w, `{ "kty": "EC" , "crv": %q , "d": %q , "x": %q , "y": %q , "ext": true , "key_ops": ["sign"] }`, jwk.Crv, jwk.D, jwk.X, jwk.Y) + }) + + http.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + baseURL := getBaseURL(r) + log.Printf("%s %s\n", r.Method, r.URL.Path) + fmt.Fprintf(w, `{ "issuer": "%s", "jwks_uri": "%s/.well-known/jwks.json" }`, baseURL, baseURL) + }) + + http.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path) + parts := strings.Split(r.Host, ".") + kid := parts[0] + + b, err := ioutil.ReadFile(filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json")) + if nil != err { + //http.Error(w, "Not Found", http.StatusNotFound) + jwkstr := fmt.Sprintf( + `{ "keys": [ { "kty": "EC" , "crv": %q , "x": %q , "y": %q , "kid": %q , "ext": true , "key_ops": ["verify"] , "exp": %s } ] }`, + jwk.Crv, jwk.X, jwk.Y, thumbprint, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10), + ) + fmt.Println(jwkstr) + fmt.Fprintf(w, jwkstr) + return + } + + tok := &PublicJWK{} + err = json.Unmarshal(b, tok) + if nil != err { + // TODO delete the bad file? + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + jwkstr := fmt.Sprintf( + `{ "keys": [ { "kty": "EC", "crv": %q, "x": %q, "y": %q, "kid": %q,`+ + ` "ext": true, "key_ops": ["verify"], "exp": %s } ] }`, + tok.Crv, tok.X, tok.Y, tok.KeyID, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10), + ) + fmt.Println(jwkstr) + fmt.Fprintf(w, jwkstr) + }) +} + +func parseExp(exp string) (int, error) { + if "" == exp { + exp = "15m" + } + mult := 1 + switch exp[len(exp)-1] { + case 'w': + mult *= 7 + fallthrough + case 'd': + mult *= 24 + fallthrough + case 'h': + mult *= 60 + fallthrough + case 'm': + mult *= 60 + fallthrough + case 's': + // no fallthrough + default: + // could be 'k' or 'z', but we assume its empty + exp += "s" + } + + num, err := strconv.Atoi(exp[:len(exp)-1]) + if nil != err { + return 0, err + } + return num * mult, nil +} + +func postEC(jwksPrefix string, tok map[string]interface{}, w http.ResponseWriter, r *http.Request) { + crv, ok := tok["crv"].(string) + if 5 != len(crv) || "P-" != crv[:2] { + http.Error(w, "Bad Request: bad curve", http.StatusBadRequest) + return + } + x, ok := tok["x"].(string) + if !ok { + http.Error(w, "Bad Request: missing 'x'", http.StatusBadRequest) + return + } + y, ok := tok["y"].(string) + if !ok { + http.Error(w, "Bad Request: missing 'y'", http.StatusBadRequest) + return + } + + thumbprintable := []byte( + fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, crv, x, y), + ) + alg := crv[2:] + + var thumb []byte + switch alg { + case "256": + hash := sha256.Sum256(thumbprintable) + thumb = hash[:] + case "384": + hash := sha512.Sum384(thumbprintable) + thumb = hash[:] + case "521": + fallthrough + case "512": + hash := sha512.Sum512(thumbprintable) + thumb = hash[:] + default: + http.Error(w, "Bad Request: bad key length or curve", http.StatusBadRequest) + return + } + + kid := base64.RawURLEncoding.EncodeToString(thumb) + if kid2, _ := tok["kid"].(string); "" != kid2 && kid != kid2 { + http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest) + return + } + + pub := []byte(fmt.Sprintf( + `{"crv":%q,"kid":%q,"kty":"EC","x":%q,"y":%q}`, crv, kid, x, y, + )) + + // TODO allow posting at the top-level? + // TODO support a group of keys by PPID + // (right now it's only by KID) + if !strings.HasPrefix(r.Host, strings.ToLower(kid)+".") { + http.Error(w, "Bad Request: prefix should be "+kid, http.StatusBadRequest) + return + } + + if err := ioutil.WriteFile( + filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"), + pub, + 0644, + ); nil != err { + fmt.Println("can't write file") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + baseURL := getBaseURL(r) + w.Write([]byte(fmt.Sprintf( + `{ "iss":%q, "jwks_url":%q }`, baseURL+"/", baseURL+"/.well-known/jwks.json", + ))) +} + +func postRSA(jwksPrefix string, tok map[string]interface{}, w http.ResponseWriter, r *http.Request) { + e, ok := tok["e"].(string) + if !ok { + http.Error(w, "Bad Request: missing 'e'", http.StatusBadRequest) + return + } + n, ok := tok["n"].(string) + if !ok { + http.Error(w, "Bad Request: missing 'n'", http.StatusBadRequest) + return + } + + thumbprintable := []byte( + fmt.Sprintf(`{"e":%q,"kty":"RSA","n":%q}`, e, n), + ) + + var thumb []byte + // TODO handle bit lengths well + switch 3 * (len(n) / 4.0) { + case 256: + hash := sha256.Sum256(thumbprintable) + thumb = hash[:] + case 384: + hash := sha512.Sum384(thumbprintable) + thumb = hash[:] + case 512: + hash := sha512.Sum512(thumbprintable) + thumb = hash[:] + default: + http.Error(w, "Bad Request: only standard RSA key lengths (2048, 3072, 4096) are supported", http.StatusBadRequest) + return + } + + kid := base64.RawURLEncoding.EncodeToString(thumb) + if kid2, _ := tok["kid"].(string); "" != kid2 && kid != kid2 { + http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest) + return + } + + pub := []byte(fmt.Sprintf( + `{"e":%q,"kid":%q,"kty":"EC","n":%q}`, e, kid, n, + )) + + // TODO allow posting at the top-level? + // TODO support a group of keys by PPID + // (right now it's only by KID) + if !strings.HasPrefix(r.Host, strings.ToLower(kid)+".") { + http.Error(w, "Bad Request: prefix should be "+kid, http.StatusBadRequest) + return + } + + if err := ioutil.WriteFile( + filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"), + pub, + 0644, + ); nil != err { + fmt.Println("can't write file") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + baseURL := getBaseURL(r) + w.Write([]byte(fmt.Sprintf( + `{ "iss":%q, "jwks_url":%q }`, baseURL+"/", baseURL+"/.well-known/jwks.json", + ))) +} + +func GenToken(host string, priv *ecdsa.PrivateKey, query url.Values) (string, string, string) { + thumbprint := thumbprintKey(&priv.PublicKey) + protected := fmt.Sprintf(`{"typ":"JWT","alg":"ES256","kid":"%s"}`, thumbprint) + protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected)) + + exp, err := parseExp(query.Get("exp")) + if nil != err { + // cryptic error code + // TODO propagate error + exp = 422 + } + + payload := fmt.Sprintf( + `{"iss":"%s/","sub":"dummy","exp":%s}`, + host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10), + ) + payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload)) + + hash := sha256.Sum256([]byte(fmt.Sprintf(`%s.%s`, protected64, payload64))) + r, s, _ := ecdsa.Sign(rand.Reader, priv, hash[:]) + rb := r.Bytes() + for len(rb) < 32 { + rb = append([]byte{0}, rb...) + } + sb := s.Bytes() + for len(rb) < 32 { + sb = append([]byte{0}, sb...) + } + sig64 := base64.RawURLEncoding.EncodeToString(append(rb, sb...)) + token := fmt.Sprintf(`%s.%s.%s`, protected64, payload64, sig64) + return protected, payload, token +} + +func ParseKey(jwk *PrivateJWK) *ecdsa.PrivateKey { + xb, _ := base64.RawURLEncoding.DecodeString(jwk.X) + xi := &big.Int{} + xi.SetBytes(xb) + yb, _ := base64.RawURLEncoding.DecodeString(jwk.Y) + yi := &big.Int{} + yi.SetBytes(yb) + pub := &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: xi, + Y: yi, + } + + db, _ := base64.RawURLEncoding.DecodeString(jwk.D) + di := &big.Int{} + di.SetBytes(db) + priv := &ecdsa.PrivateKey{ + PublicKey: *pub, + D: di, + } + return priv +} + +func thumbprintKey(pub *ecdsa.PublicKey) string { + minpub := []byte(fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, "P-256", pub.X, pub.Y)) + sha := sha256.Sum256(minpub) + return base64.RawURLEncoding.EncodeToString(sha[:]) +} + +func issueNonce(w http.ResponseWriter, r *http.Request) { + b := make([]byte, 16) + _, _ = rand.Read(b) + nonce := base64.RawURLEncoding.EncodeToString(b) + nonces[nonce] = time.Now().Unix() + + w.Header().Set("Replay-Nonce", nonce) +} + +func requireNonce(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + nonce := r.Header.Get("Replay-Nonce") + // TODO expire nonces every so often + t := nonces[nonce] + if 0 == t { + http.Error( + w, + `{ "error": "invalid or expired nonce", "error_code": "ENONCE" }`, + http.StatusBadRequest, + ) + return + } + + delete(nonces, nonce) + issueNonce(w, r) + + next(w, r) + } +} + +func getBaseURL(r *http.Request) string { + var scheme string + if nil != r.TLS || "https" == r.Header.Get("X-Forwarded-Proto") { + scheme = "https:" + } else { + scheme = "http:" + } + return fmt.Sprintf( + "%s//%s", + scheme, + r.Host, + ) +} diff --git a/mockid/mockid_test.go b/mockid/mockid_test.go new file mode 100644 index 0000000..c62ee94 --- /dev/null +++ b/mockid/mockid_test.go @@ -0,0 +1,13 @@ +package mockid + +import ( + "testing" + "time" + + //keypairs "github.com/big-squid/go-keypairs" + //"github.com/big-squid/go-keypairs/keyfetch/uncached" +) + +func TestTest(t *testing.T) { + t.Fatal("no test") +}