package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"crypto/sha512"
	"encoding/base64"
	"encoding/json"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"math/big"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"
)

// PrivateJWK is an EC private key in JSON Web Key form: the public members
// plus the base64url-encoded private scalar "d".
type PrivateJWK struct {
	PublicJWK
	D string `json:"d"`
}

// PublicJWK holds the public members of an EC JSON Web Key. X and Y are the
// base64url-encoded coordinates exactly as they appear in the JSON.
type PublicJWK struct {
	Crv   string `json:"crv"`
	KeyID string `json:"kid,omitempty"`
	Kty   string `json:"kty,omitempty"`
	X     string `json:"x"`
	Y     string `json:"y"`
}

// jwksPrefix is the directory holding uploaded public keys, one
// "<lowercased kid>.jwk.json" file per key. It is set once in main.
var jwksPrefix string

// main runs a mock token issuer on --port (or $PORT). It signs tokens with a
// fixed, compiled-in P-256 key, lets clients register their own public keys
// under per-thumbprint subdomains, and serves ./public for every other path.
func main() {
	// The signing key is hard-coded and is also served, private part
	// included, at /key.jwk.json: this server is for testing only.
	jwk := &PrivateJWK{
		PublicJWK: PublicJWK{
			Crv: "P-256",
			X:   "ToL2HppsTESXQKvp7ED6NMgV4YnwbMeONexNry3KDNQ",
			Y:   "Tt6Q3rxU37KAinUV9PLMlwosNy1t3Bf2VDg5q955AGc",
		},
		D: "GYAwlBHc2mPsj1lp315HbYOmKNJ7esmO3JAkZVn9nJs",
	}
	priv := parseKey(jwk)
	thumbprint := thumbprintKey(&priv.PublicKey)

	portFlag := flag.Int("port", 0, "Port on which the HTTP server should run")
	urlFlag := flag.String("url", "", "Outward-facing address, such as https://example.com")
	prefixFlag := flag.String("jwkspath", "", "The path to the JWKs storage directory")
	flag.Parse()

	port := *portFlag
	if port < 1 {
		// An unset or non-numeric PORT leaves port at 0, which is rejected below.
		port, _ = strconv.Atoi(os.Getenv("PORT"))
	}
	if port < 1 {
		fmt.Fprintf(os.Stderr, "You must specify --port or PORT\n")
		os.Exit(1)
	}

	host := *urlFlag
	if host == "" {
		host = "http://localhost:" + strconv.Itoa(port)
	}

	jwksPrefix = *prefixFlag
	if jwksPrefix == "" {
		jwksPrefix = "public-jwks"
	}
	if err := os.MkdirAll(jwksPrefix, 0755); err != nil {
		fmt.Fprintf(os.Stderr, "couldn't write %q: %s\n", jwksPrefix, err)
		os.Exit(1)
	}

	http.HandleFunc("/api/jwks", handleUploadJWK)

	http.HandleFunc("/access_token", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		_, _, token := genToken(requestScheme(r)+r.Host, priv, r.URL.Query())
		// Fprint, not Fprintf: the token is data, not a format string.
		fmt.Fprint(w, token)
	})

	http.HandleFunc("/authorization_header", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		query := r.URL.Query()
		header := "Authorization"
		if v := query["header"]; len(v) > 0 {
			header = v[0]
		}
		prefix := "Bearer "
		if v := query["prefix"]; len(v) > 0 {
			prefix = v[0]
		}
		_, _, token := genToken(requestScheme(r)+r.Host, priv, query)
		fmt.Fprintf(w, "%s: %s%s", header, prefix, token)
	})

	http.HandleFunc("/key.jwk.json", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w,
			`{ "kty": "EC" , "crv": %q , "d": %q , "x": %q , "y": %q , "ext": true , "key_ops": ["sign"] }`,
			jwk.Crv, jwk.D, jwk.X, jwk.Y,
		)
	})

	http.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		issuer := requestScheme(r) + r.Host
		// NOTE(review): tokens from genToken carry iss = issuer + "/", but this
		// document advertises the issuer without the trailing slash. Strict
		// OIDC clients compare the two exactly; confirm which form is intended.
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{ "issuer": "%s", "jwks_uri": "%s/.well-known/jwks.json" }`, issuer, issuer)
	})

	http.Handle("/.well-known/jwks.json", jwksHandler(jwk, thumbprint))

	// Every other path is served as a static file from ./public.
	http.Handle("/", http.FileServer(http.Dir("public")))

	fmt.Printf("Serving on port %d\n", port)

	// Marshalling a struct of plain strings cannot fail.
	b, _ := json.Marshal(jwk)
	fmt.Printf("Private Key:\n\t%s\n", string(b))
	b, _ = json.Marshal(jwk.PublicJWK)
	fmt.Printf("Public Key:\n\t%s\n", string(b))

	protected, payload, token := genToken(host, priv, url.Values{})
	fmt.Printf("Protected (Header):\n\t%s\n", protected)
	fmt.Printf("Payload (Claims):\n\t%s\n", payload)
	fmt.Printf("Access Token:\n\t%s\n", token)

	// ListenAndServe only returns on failure.
	log.Fatal(http.ListenAndServe(":"+strconv.Itoa(port), nil))
}

// requestScheme returns "https://" when the request arrived over TLS or a
// reverse proxy reported HTTPS via X-Forwarded-Proto, and "http://" otherwise.
func requestScheme(r *http.Request) string {
	if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
		return "https://"
	}
	return "http://"
}

// handleUploadJWK accepts a POSTed EC public JWK and stores it under
// jwksPrefix as "<lowercased kid>.jwk.json", where kid is the key's
// thumbprint (see ecThumbprint).
//
// The request's Host must start with the lowercased thumbprint followed by a
// dot (e.g. "<kid>.example.com"). Private keys, non-EC keys, unknown curves,
// malformed coordinates and a mismatching "kid" are rejected with 400. On
// success it responds with the issuer and JWKS URL for that host.
func handleUploadJWK(w http.ResponseWriter, r *http.Request) {
	log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path)
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}
	defer r.Body.Close()

	// TODO better, JSON error messages
	tok := make(map[string]interface{})
	if err := json.NewDecoder(r.Body).Decode(&tok); err != nil {
		http.Error(w, "Bad Request: invalid json", http.StatusBadRequest)
		return
	}
	if _, ok := tok["d"]; ok {
		http.Error(w, "Bad Request: private key", http.StatusBadRequest)
		return
	}
	if kty, _ := tok["kty"].(string); kty != "EC" {
		// TODO RSA
		http.Error(w, "Bad Request: only EC keys are supported", http.StatusBadRequest)
		return
	}
	crv, _ := tok["crv"].(string)
	if len(crv) != 5 || !strings.HasPrefix(crv, "P-") {
		http.Error(w, "Bad Request: bad curve", http.StatusBadRequest)
		return
	}
	x, ok := tok["x"].(string)
	if !ok {
		http.Error(w, "Bad Request: missing 'x'", http.StatusBadRequest)
		return
	}
	y, ok := tok["y"].(string)
	if !ok {
		http.Error(w, "Bad Request: missing 'y'", http.StatusBadRequest)
		return
	}
	// JWK coordinates are unpadded base64url. Rejecting anything else also
	// guarantees that the %q-built JSON below (and the thumbprint input) is
	// valid, canonical JSON.
	if _, err := base64.RawURLEncoding.DecodeString(x); err != nil {
		http.Error(w, "Bad Request: invalid 'x'", http.StatusBadRequest)
		return
	}
	if _, err := base64.RawURLEncoding.DecodeString(y); err != nil {
		http.Error(w, "Bad Request: invalid 'y'", http.StatusBadRequest)
		return
	}

	kid, ok := ecThumbprint(crv, x, y)
	if !ok {
		http.Error(w, "Bad Request: bad curve", http.StatusBadRequest)
		return
	}
	if kid2, _ := tok["kid"].(string); kid2 != "" && kid2 != kid {
		http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest)
		return
	}

	// TODO allow posting at the top-level?
	// TODO support a group of keys by PPID
	// (right now it's only by KID)
	if !strings.HasPrefix(r.Host, strings.ToLower(kid)+".") {
		http.Error(w, "Bad Request: prefix should be "+kid, http.StatusBadRequest)
		return
	}

	pub := []byte(fmt.Sprintf(
		`{"crv":%q,"kid":%q,"kty":"EC","x":%q,"y":%q}`,
		crv, kid, x, y,
	))
	path := filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json")
	if err := ioutil.WriteFile(path, pub, 0644); err != nil {
		log.Printf("can't write file %q: %v", path, err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	scheme := requestScheme(r)
	w.Header().Set("Content-Type", "application/json")
	fmt.Fprintf(w, `{ "iss":%q, "jwks_url":%q }`,
		scheme+r.Host+"/",
		scheme+r.Host+"/.well-known/jwks.json",
	)
}

// jwksHandler serves a JWK Set for the key named by the first label of the
// request's Host. A key previously stored by handleUploadJWK is read from
// jwksPrefix. When no such file can be read, the server's own key (jwk,
// identified by thumbprint) is served instead. Every response carries an
// "exp" 15 minutes in the future.
func jwksHandler(jwk *PrivateJWK, thumbprint string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path)
		kid := strings.Split(r.Host, ".")[0]
		exp := strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10)

		key := jwk.PublicJWK
		key.KeyID = thumbprint
		// A read failure (typically: no uploaded key) falls back to the
		// built-in key rather than answering 404.
		b, err := ioutil.ReadFile(filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"))
		if err == nil {
			key = PublicJWK{}
			if err := json.Unmarshal(b, &key); err != nil {
				// TODO delete the bad file?
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
		}

		jwkstr := fmt.Sprintf(
			`{ "keys": [ { "kty": "EC", "crv": %q, "x": %q, "y": %q, "kid": %q, `+
				`"ext": true, "key_ops": ["verify"], "exp": %s } ] }`,
			key.Crv, key.X, key.Y, key.KeyID, exp,
		)
		fmt.Println(jwkstr)
		w.Header().Set("Content-Type", "application/json")
		// Fprint, not Fprintf: jwkstr may contain '%' from stored key data.
		fmt.Fprint(w, jwkstr)
	}
}

// parseExp converts a lifetime such as "30s", "15m", "2h", "1d" or "1w" into
// whole seconds. A number without a unit suffix is taken as seconds, and the
// empty string means "15m". The numeric part must be an integer accepted by
// strconv.Atoi.
func parseExp(exp string) (int, error) {
	if exp == "" {
		exp = "15m"
	}
	mult := 1
	switch exp[len(exp)-1] {
	case 'w':
		mult *= 7
		fallthrough
	case 'd':
		mult *= 24
		fallthrough
	case 'h':
		mult *= 60
		fallthrough
	case 'm':
		mult *= 60
	case 's':
		// already seconds
	default:
		// No recognised unit: treat the whole string as seconds by appending
		// an "s" that the slice below strips again.
		exp += "s"
	}
	num, err := strconv.Atoi(exp[:len(exp)-1])
	if err != nil {
		return 0, err
	}
	return num * mult, nil
}

// leftPad returns b prefixed with zero bytes so that it is size bytes long.
// b is returned unchanged when it is already at least size bytes.
// big.Int.Bytes drops leading zeros, but JWK coordinates and JWS ECDSA
// signature halves are fixed-width big-endian integers.
func leftPad(b []byte, size int) []byte {
	if len(b) >= size {
		return b
	}
	out := make([]byte, size)
	copy(out[size-len(b):], b)
	return out
}

// genToken builds and signs an ES256 JWT with priv.
//
// The "iss" claim is host followed by "/". The "exp" claim is now plus the
// lifetime in query's "exp" parameter (see parseExp; default 15m). If that
// parameter cannot be parsed, the lifetime falls back to 422 seconds.
// It returns the protected header JSON, the payload JSON and the compact
// serialized token. The token is empty if signing fails.
func genToken(host string, priv *ecdsa.PrivateKey, query url.Values) (string, string, string) {
	thumbprint := thumbprintKey(&priv.PublicKey)
	protected := fmt.Sprintf(`{"typ":"JWT","alg":"ES256","kid":"%s"}`, thumbprint)
	protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected))

	exp, err := parseExp(query.Get("exp"))
	if err != nil {
		// cryptic error code
		// TODO propagate error
		exp = 422
	}
	payload := fmt.Sprintf(
		`{"iss":"%s/","sub":"dummy","exp":%s}`,
		host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10),
	)
	payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload))

	hash := sha256.Sum256([]byte(protected64 + "." + payload64))
	r, s, err := ecdsa.Sign(rand.Reader, priv, hash[:])
	if err != nil {
		log.Printf("genToken: signing failed: %v", err)
		return protected, payload, ""
	}
	// An ES256 signature is R || S, each a 32-byte big-endian integer.
	// Both halves must be padded independently.
	sig := append(leftPad(r.Bytes(), 32), leftPad(s.Bytes(), 32)...)
	sig64 := base64.RawURLEncoding.EncodeToString(sig)

	token := protected64 + "." + payload64 + "." + sig64
	return protected, payload, token
}

// parseKey converts an EC JWK into an *ecdsa.PrivateKey on curve P-256;
// jwk.Crv is not consulted. Base64 decoding errors are ignored (the only
// caller passes a compiled-in key), so malformed input silently yields a
// wrong key rather than an error.
func parseKey(jwk *PrivateJWK) *ecdsa.PrivateKey {
	xb, _ := base64.RawURLEncoding.DecodeString(jwk.X)
	yb, _ := base64.RawURLEncoding.DecodeString(jwk.Y)
	db, _ := base64.RawURLEncoding.DecodeString(jwk.D)
	return &ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{
			Curve: elliptic.P256(),
			X:     new(big.Int).SetBytes(xb),
			Y:     new(big.Int).SetBytes(yb),
		},
		D: new(big.Int).SetBytes(db),
	}
}

// ecThumbprint computes the JWK thumbprint (RFC 7638 member selection and
// ordering) of an EC public key, given its curve name and base64url
// coordinates. The digest is SHA-256, SHA-384 or SHA-512 for "P-256",
// "P-384" and "P-521" respectively. ok is false for any other curve.
func ecThumbprint(crv, x, y string) (kid string, ok bool) {
	// Required members only, in lexicographic order, without whitespace.
	canonical := []byte(fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, crv, x, y))
	var sum []byte
	switch crv {
	case "P-256":
		h := sha256.Sum256(canonical)
		sum = h[:]
	case "P-384":
		h := sha512.Sum384(canonical)
		sum = h[:]
	case "P-521":
		h := sha512.Sum512(canonical)
		sum = h[:]
	default:
		return "", false
	}
	return base64.RawURLEncoding.EncodeToString(sum), true
}

// thumbprintKey returns the thumbprint of pub, computed exactly as
// handleUploadJWK computes it for uploaded keys: the coordinates are encoded
// as fixed-width, unpadded base64url. It returns "" when pub's curve is not
// one ecThumbprint supports.
func thumbprintKey(pub *ecdsa.PublicKey) string {
	params := pub.Curve.Params()
	size := (params.BitSize + 7) / 8
	x := base64.RawURLEncoding.EncodeToString(leftPad(pub.X.Bytes(), size))
	y := base64.RawURLEncoding.EncodeToString(leftPad(pub.Y.Bytes(), size))
	kid, _ := ecThumbprint(params.Name, x, y)
	return kid
}