package mockid

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/sha512"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"math/big"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"git.rootprojects.org/root/keypairs"
	"git.rootprojects.org/root/keypairs/keyfetch"

	//jwt "github.com/dgrijalva/jwt-go"
	"github.com/google/uuid"
)

// PublicJWK is the public part of an EC JWK.
type PublicJWK struct {
	Crv   string `json:"crv"`
	KeyID string `json:"kid,omitempty"`
	Kty   string `json:"kty,omitempty"`
	X     string `json:"x"`
	Y     string `json:"y"`
}

// InspectableToken is the decoded form of a compact JWS, as reported by
// the /inspect_token endpoint.
type InspectableToken struct {
	Public    keypairs.PublicKey     `json:"jwk"`
	Protected map[string]interface{} `json:"protected"`
	Payload   map[string]interface{} `json:"payload"`
	Signature string                 `json:"signature"`
	Verified  bool                   `json:"verified"`
	Errors    []string               `json:"errors"`
}

// MarshalJSON renders the token with its public key as a JWK.
// A nil Public key (e.g. when it could not be fetched from the issuer)
// is rendered as null instead of panicking.
func (t *InspectableToken) MarshalJSON() ([]byte, error) {
	pub := []byte("null")
	if nil != t.Public {
		pub = keypairs.MarshalJWKPublicKey(t.Public)
	}
	header, err := json.Marshal(t.Protected)
	if nil != err {
		return nil, err
	}
	payload, err := json.Marshal(t.Payload)
	if nil != err {
		return nil, err
	}
	errs, err := json.Marshal(t.Errors)
	if nil != err {
		return nil, err
	}
	return []byte(fmt.Sprintf(
		`{"jwk":%s,"protected":%s,"payload":%s,"signature":%q,"verified":%t,"errors":%s}`,
		pub, header, payload, t.Signature, t.Verified, errs,
	)), nil
}

var defaultFrom string
var defaultReplyTo string

// nonces maps each issued Replay-Nonce to the time it was issued.
var nonces sync.Map

// salt is mixed into verification-token hashes (see /api/new-account).
var salt []byte

// init loads configuration from the environment.
// It panics unless SALT is set to 22+ characters of URL-safe, unpadded base64.
func init() {
	var err error
	salt64 := os.Getenv("SALT")
	salt, err = base64.RawURLEncoding.DecodeString(salt64)
	if len(salt64) < 22 || nil != err {
		panic("SALT must be set as 22+ character base64")
	}

	defaultFrom = os.Getenv("MAILER_FROM")
	defaultReplyTo = os.Getenv("MAILER_REPLY_TO")
}

// Route registers the mock identity-provider handlers on http.DefaultServeMux.
//
// jwksPrefix is the directory where posted public JWKs (and, for now,
// email verification tokens) are stored. privkey signs the tokens issued
// by /access_token and /authorization_header, and its public half is the
// fallback key served from /.well-known/jwks.json.
func Route(jwksPrefix string, privkey keypairs.PrivateKey) {
	// TODO get from main()
	tokPrefix := jwksPrefix
	pubkey := keypairs.NewPublicKey(privkey.Public())

	http.HandleFunc("/api/new-nonce", func(w http.ResponseWriter, r *http.Request) {
		baseURL := getBaseURL(r)
		// TODO use one of the registered domains
		// (e.g. https://acme-staging-v02.api.letsencrypt.org/index)
		indexURL := baseURL + "/index"
		w.Header().Set("Link", "<"+indexURL+">;rel=\"index\"")
		w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store")
		w.Header().Set("Pragma", "no-cache")
		// TODO Strict-Transport-Security: max-age=604800
		w.Header().Set("X-Frame-Options", "DENY")

		issueNonce(w, r)
	})

	http.HandleFunc("/api/new-account", requireNonce(func(w http.ResponseWriter, r *http.Request) {
		// Try to decode the request body into the struct. If there is an error,
		// respond to the client with the error message and a 400 status code.
		data := map[string]string{}
		if err := json.NewDecoder(r.Body).Decode(&data); nil != err {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		// TODO check DNS for MX records
		// Exactly one address is accepted. Any of the separator characters
		// (comma, space, angle brackets, CR/LF, tab) produces more than one
		// field and is rejected. (strings.Split would have matched the whole
		// separator string literally and let "a@x.com, b@y.com" through.)
		parts := strings.FieldsFunc(data["to"], func(c rune) bool {
			return strings.ContainsRune(", <>\n\r\t", c)
		})
		if 1 != len(parts) || !strings.Contains(parts[0], "@") {
			http.Error(w, "invalid email address", http.StatusBadRequest)
			return
		}
		to := parts[0]

		token, err := uuid.NewRandom()
		if nil != err {
			// nothing else to do if we run out of random
			// or are on a platform that doesn't support random
			panic(fmt.Errorf("random bytes read failure: %w", err))
		}
		token64 := base64.RawURLEncoding.EncodeToString(token[:])
		// hash token to prevent fs read timing attacks
		hash := sha1.Sum(append(token[:], salt...))
		tokname := base64.RawURLEncoding.EncodeToString(hash[:])
		if err := ioutil.WriteFile(
			filepath.Join(tokPrefix, tokname+".tok.txt"),
			[]byte(`{"comment":"I have no idea..."}`),
			os.FileMode(0600),
		); nil != err {
			http.Error(w, "database connection failed when writing verification token", http.StatusInternalServerError)
			return
		}

		subject := "Verify New Account"
		// TODO go tpl
		// TODO determine OS and Browser from user agent
		baseURL := getBaseURL(r)
		text := fmt.Sprintf(
			"It looks like you just tried to register a new Pocket ID account.\n\n Verify account: %s/verify/%s\n\nNot you? Just ignore this message.",
			baseURL, token64,
		)
		_, err = SendSimpleMessage(to, defaultFrom, subject, text, defaultReplyTo)
		if nil != err {
			// TODO neuter mailgun output
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		fmt.Fprintf(w, `{ "success": true, "error": "" }%s`, "\n")
	}))

	// TODO use chi
	http.HandleFunc("/verify/", requireNonce(func(w http.ResponseWriter, r *http.Request) {
		parts := strings.Split(r.URL.Path, "/")
		if 3 != len(parts) {
			http.Error(w, "invalid url path", http.StatusBadRequest)
			return
		}
		token, err := base64.RawURLEncoding.DecodeString(parts[2])
		if nil != err || 0 == len(token) {
			http.Error(w, "invalid url path", http.StatusBadRequest)
			return
		}

		// hash token to prevent fs read timing attacks
		hash := sha1.Sum(append(token, salt...))
		tokname := base64.RawURLEncoding.EncodeToString(hash[:])
		tokfile := filepath.Join(tokPrefix, tokname+".tok.txt")
		if _, err := ioutil.ReadFile(tokfile); nil != err {
			// A missing file means an unknown or already-used token, not a storage failure.
			if os.IsNotExist(err) {
				http.Error(w, "invalid or expired verification token", http.StatusNotFound)
				return
			}
			http.Error(w, "database connection failed when reading verification token", http.StatusInternalServerError)
			return
		}
		// Tokens are single-use. If removal fails, the link simply keeps working.
		_ = os.Remove(tokfile)

		fmt.Fprintf(w, `{ "success": true, "error": "" }%s`, "\n")
	}))

	http.HandleFunc("/api/jwks", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path)
		if "POST" != r.Method {
			http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
			return
		}
		defer r.Body.Close()

		tok := make(map[string]interface{})
		if err := json.NewDecoder(r.Body).Decode(&tok); nil != err {
			http.Error(w, "Bad Request: invalid json", http.StatusBadRequest)
			return
		}

		// TODO better, JSON error messages
		if _, ok := tok["d"]; ok {
			http.Error(w, "Bad Request: private key", http.StatusBadRequest)
			return
		}

		kty, _ := tok["kty"].(string)
		switch kty {
		case "EC":
			postEC(jwksPrefix, tok, w, r)
		case "RSA":
			postRSA(jwksPrefix, tok, w, r)
		default:
			http.Error(w, "Bad Request: only EC and RSA keys are supported", http.StatusBadRequest)
		}
	})

	http.HandleFunc("/access_token", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s\n", r.Method, r.URL.Path)
		_, _, token := GenToken(getBaseURL(r), privkey, r.URL.Query())
		// Fprint, not Fprintf: the token is data, not a format string.
		fmt.Fprint(w, token)
	})

	http.HandleFunc("/inspect_token", func(w http.ResponseWriter, r *http.Request) {
		token := r.Header.Get("Authorization")
		log.Printf("%s %s %s\n", r.Method, r.URL.Path, token)

		if "" == token {
			token = r.URL.Query().Get("access_token")
			if "" == token {
				http.Error(w, "Bad Format: missing Authorization header and 'access_token' query", http.StatusBadRequest)
				return
			}
		} else {
			parts := strings.Split(token, " ")
			if 2 != len(parts) {
				http.Error(w, "Bad Format: expected Authorization header to be in the format of 'Bearer '", http.StatusBadRequest)
				return
			}
			token = parts[1]
		}

		parts := strings.Split(token, ".")
		if 3 != len(parts) {
			http.Error(w, "Bad Format: token should be in the format of ..", http.StatusBadRequest)
			return
		}
		protected64 := parts[0]
		payload64 := parts[1]
		signature64 := parts[2]

		protectedB, err := base64.RawURLEncoding.DecodeString(protected64)
		if nil != err {
			http.Error(w, "Bad Format: token's header should be URL-safe base64 encoded", http.StatusBadRequest)
			return
		}
		payloadB, err := base64.RawURLEncoding.DecodeString(payload64)
		if nil != err {
			http.Error(w, "Bad Format: token's payload should be URL-safe base64 encoded", http.StatusBadRequest)
			return
		}
		sig, err := base64.RawURLEncoding.DecodeString(signature64)
		if nil != err {
			http.Error(w, "Bad Format: token's signature should be URL-safe base64 encoded", http.StatusBadRequest)
			return
		}

		errs := []string{}

		protected := map[string]interface{}{}
		if err := json.Unmarshal(protectedB, &protected); nil != err {
			http.Error(w, "Bad Format: token's header should be URL-safe base64-encoded JSON", http.StatusBadRequest)
			return
		}
		kid, kidOK := protected["kid"].(string)
		// TODO parse jwk
		_, jwkOK := protected["jwk"]
		if !kidOK && !jwkOK {
			errs = append(errs, "must have either header.kid or header.jwk")
		}

		data := map[string]interface{}{}
		if err := json.Unmarshal(payloadB, &data); nil != err {
			http.Error(w, "Bad Format: token's payload should be URL-safe base64-encoded JSON", http.StatusBadRequest)
			return
		}
		iss, issOK := data["iss"].(string)
		if !jwkOK && !issOK {
			errs = append(errs, "payload.iss must exist to complement header.kid")
		}

		// pub stays a nil interface when the fetch fails, so the token is
		// still reported (unverified, with the error listed) instead of
		// panicking in JOSEVerify / MarshalJSON.
		var pub keypairs.PublicKey
		var verified bool
		fetched, err := keyfetch.OIDCJWK(kid, iss)
		if nil != err {
			fmt.Println("couldn't fetch pub key:")
			fmt.Println(err)
			errs = append(errs, "couldn't fetch public key: "+err.Error())
		} else {
			pub = fetched
			fmt.Println("fetched pub key:")
			fmt.Println(pub)
			// TODO choose the hash from header.alg; SHA-256 matches ES256/RS256 only
			hash := sha256.Sum256([]byte(protected64 + "." + payload64))
			verified = JOSEVerify(pub, hash[:], sig)
		}

		inspected := &InspectableToken{
			Public:    pub,
			Protected: protected,
			Payload:   data,
			Signature: signature64,
			Verified:  verified,
			Errors:    errs,
		}
		tokenB, err := json.MarshalIndent(inspected, "", "  ")
		if nil != err {
			http.Error(w, "Bad Format: malformed token, or malformed jwk at issuer url", http.StatusInternalServerError)
			return
		}
		// Write, not Fprintf: the payload is caller-controlled and may contain '%'.
		w.Write(tokenB)
	})

	http.HandleFunc("/authorization_header", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s\n", r.Method, r.URL.Path)
		query := r.URL.Query()

		// An explicitly empty ?header= or ?prefix= is honored as-is;
		// only an absent parameter falls back to the default.
		header := "Authorization"
		if headers := query["header"]; 0 != len(headers) {
			header = headers[0]
		}
		prefix := "Bearer "
		if prefixes := query["prefix"]; 0 != len(prefixes) {
			prefix = prefixes[0]
		}

		_, _, token := GenToken(getBaseURL(r), privkey, query)
		fmt.Fprintf(w, "%s: %s%s", header, prefix, token)
	})

	http.HandleFunc("/key.jwk.json", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		// Reformat the compact JWK as
		// { "kty": ..., "crv": ..., "d": ..., "x": ..., "y": ..., "ext": true , "key_ops": ["sign"] }
		jwk := string(MarshalJWKPrivateKey(privkey))
		jwk = strings.Replace(jwk, `{"`, `{ "`, 1)
		jwk = strings.Replace(jwk, `",`, `", `, -1)
		jwk = jwk[0 : len(jwk)-1]
		jwk = jwk + `, "ext": true , "key_ops": ["sign"] }`
		fmt.Fprint(w, jwk)
	})

	http.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
		baseURL := getBaseURL(r)
		log.Printf("%s %s\n", r.Method, r.URL.Path)
		fmt.Fprintf(w, `{ "issuer": "%s", "jwks_uri": "%s/.well-known/jwks.json" }`, baseURL, baseURL)
	})

	http.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path)
		// The first label of the Host names the kid of a key previously
		// POSTed to /api/jwks (see storePublicJWK).
		kid := strings.Split(r.Host, ".")[0]
		exp := time.Now().Add(15 * time.Minute).Unix()

		b, err := ioutil.ReadFile(filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"))
		if nil != err {
			// No stored key for this host: serve this server's own public key.
			jwk := string(keypairs.MarshalJWKPublicKey(pubkey))
			jwk = strings.Replace(jwk, `{"`, `{ "`, 1)
			jwk = strings.Replace(jwk, `",`, `" ,`, -1)
			jwk = jwk[0 : len(jwk)-1]
			jwk = jwk + fmt.Sprintf(`, "ext": true , "key_ops": ["verify"], "exp": %d }`, exp)
			jwkstr := fmt.Sprintf(`{ "keys": [ %s ] }`, jwk)
			fmt.Println(jwkstr)
			fmt.Fprint(w, jwkstr)
			return
		}

		// Stored keys may be EC or RSA, so re-emit whatever public members
		// were saved rather than assuming EC's crv/x/y.
		jwk := map[string]interface{}{}
		if err := json.Unmarshal(b, &jwk); nil != err {
			// TODO delete the bad file?
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		jwk["ext"] = true
		jwk["key_ops"] = []string{"verify"}
		jwk["exp"] = exp
		jwks, err := json.Marshal(map[string]interface{}{"keys": []interface{}{jwk}})
		if nil != err {
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		fmt.Println(string(jwks))
		w.Write(jwks)
	})
}

// parseExp converts a lifetime such as "30s", "15m", "2h", "1d" or "1w"
// into seconds. A value without a recognized unit suffix is taken as
// seconds, and "" defaults to "15m".
func parseExp(exp string) (int, error) {
	if "" == exp {
		exp = "15m"
	}

	mult := 1
	switch exp[len(exp)-1] {
	case 'w':
		mult *= 7
		fallthrough
	case 'd':
		mult *= 24
		fallthrough
	case 'h':
		mult *= 60
		fallthrough
	case 'm':
		mult *= 60
		fallthrough
	case 's':
		// no fallthrough
	default:
		// no unit suffix: treat the whole string as seconds
		exp += "s"
	}

	num, err := strconv.Atoi(exp[:len(exp)-1])
	if nil != err {
		return 0, err
	}
	return num * mult, nil
}

// postEC validates a posted public EC JWK, derives its kid from a hash of
// the key's canonical members (hash size chosen by curve), and stores it
// via storePublicJWK.
func postEC(jwksPrefix string, tok map[string]interface{}, w http.ResponseWriter, r *http.Request) {
	crv, _ := tok["crv"].(string)
	if 5 != len(crv) || "P-" != crv[:2] {
		http.Error(w, "Bad Request: bad curve", http.StatusBadRequest)
		return
	}
	x, ok := tok["x"].(string)
	if !ok {
		http.Error(w, "Bad Request: missing 'x'", http.StatusBadRequest)
		return
	}
	y, ok := tok["y"].(string)
	if !ok {
		http.Error(w, "Bad Request: missing 'y'", http.StatusBadRequest)
		return
	}

	thumbprintable := []byte(
		fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, crv, x, y),
	)

	var thumb []byte
	switch crv[2:] {
	case "256":
		hash := sha256.Sum256(thumbprintable)
		thumb = hash[:]
	case "384":
		hash := sha512.Sum384(thumbprintable)
		thumb = hash[:]
	case "521", "512":
		hash := sha512.Sum512(thumbprintable)
		thumb = hash[:]
	default:
		http.Error(w, "Bad Request: bad key length or curve", http.StatusBadRequest)
		return
	}

	kid := base64.RawURLEncoding.EncodeToString(thumb)
	if kid2, _ := tok["kid"].(string); "" != kid2 && kid != kid2 {
		http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest)
		return
	}

	pub := []byte(fmt.Sprintf(
		`{"crv":%q,"kid":%q,"kty":"EC","x":%q,"y":%q}`, crv, kid, x, y,
	))
	storePublicJWK(jwksPrefix, kid, pub, w, r)
}

// postRSA validates a posted public RSA JWK (2048, 3072 or 4096 bits),
// derives its kid from a hash of the key's canonical members (hash size
// chosen by modulus size), and stores it via storePublicJWK.
func postRSA(jwksPrefix string, tok map[string]interface{}, w http.ResponseWriter, r *http.Request) {
	e, ok := tok["e"].(string)
	if !ok {
		http.Error(w, "Bad Request: missing 'e'", http.StatusBadRequest)
		return
	}
	n, ok := tok["n"].(string)
	if !ok {
		http.Error(w, "Bad Request: missing 'n'", http.StatusBadRequest)
		return
	}
	// Size the modulus from its decoded bytes. The old 3*(len(n)/4)
	// integer arithmetic gave 255 for 2048-bit and 510 for 4096-bit keys.
	modulus, err := base64.RawURLEncoding.DecodeString(n)
	if nil != err {
		http.Error(w, "Bad Request: 'n' should be URL-safe base64 encoded", http.StatusBadRequest)
		return
	}

	thumbprintable := []byte(
		fmt.Sprintf(`{"e":%q,"kty":"RSA","n":%q}`, e, n),
	)

	var thumb []byte
	// TODO handle bit lengths well
	switch len(modulus) {
	case 256:
		hash := sha256.Sum256(thumbprintable)
		thumb = hash[:]
	case 384:
		hash := sha512.Sum384(thumbprintable)
		thumb = hash[:]
	case 512:
		hash := sha512.Sum512(thumbprintable)
		thumb = hash[:]
	default:
		http.Error(w, "Bad Request: only standard RSA key lengths (2048, 3072, 4096) are supported", http.StatusBadRequest)
		return
	}

	kid := base64.RawURLEncoding.EncodeToString(thumb)
	if kid2, _ := tok["kid"].(string); "" != kid2 && kid != kid2 {
		http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest)
		return
	}

	pub := []byte(fmt.Sprintf(
		`{"e":%q,"kid":%q,"kty":"RSA","n":%q}`, e, kid, n,
	))
	storePublicJWK(jwksPrefix, kid, pub, w, r)
}

// storePublicJWK writes pub to <jwksPrefix>/<lowercased kid>.jwk.json and
// replies with the issuer and JWKS URLs. The request Host must start with
// "<lowercased kid>.", the subdomain from which /.well-known/jwks.json
// will later serve this key.
func storePublicJWK(jwksPrefix, kid string, pub []byte, w http.ResponseWriter, r *http.Request) {
	// TODO allow posting at the top-level?
	// TODO support a group of keys by PPID
	// (right now it's only by KID)
	if !strings.HasPrefix(r.Host, strings.ToLower(kid)+".") {
		http.Error(w, "Bad Request: prefix should be "+kid, http.StatusBadRequest)
		return
	}

	if err := ioutil.WriteFile(
		filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"),
		pub,
		0644,
	); nil != err {
		fmt.Println("can't write file:", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	baseURL := getBaseURL(r)
	w.Write([]byte(fmt.Sprintf(
		`{ "iss":%q, "jwks_url":%q }`, baseURL+"/", baseURL+"/.well-known/jwks.json",
	)))
}

// GenToken issues a signed compact JWS for a dummy subject with issuer
// host + "/". The lifetime comes from the "exp" query value (see parseExp).
// An unparsable value yields 422 seconds.
//
// It returns the protected header JSON, the payload JSON, and the compact
// token. The token is followed by a trailing newline.
func GenToken(host string, privkey keypairs.PrivateKey, query url.Values) (string, string, string) {
	thumbprint := keypairs.ThumbprintPublicKey(keypairs.NewPublicKey(privkey.Public()))
	// TODO keypairs.Alg(key)
	alg := "ES256"
	if _, ok := privkey.(*rsa.PrivateKey); ok {
		alg = "RS256"
	}
	protected := fmt.Sprintf(`{"typ":"JWT","alg":%q,"kid":"%s"}`, alg, thumbprint)
	protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected))

	exp, err := parseExp(query.Get("exp"))
	if nil != err {
		// cryptic error code
		// TODO propagate error
		exp = 422
	}

	payload := fmt.Sprintf(
		`{"iss":"%s/","sub":"dummy","exp":%s}`,
		host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10),
	)
	payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload))

	hash := sha256.Sum256([]byte(fmt.Sprintf(`%s.%s`, protected64, payload64)))
	sig := JOSESign(privkey, hash[:])
	sig64 := base64.RawURLEncoding.EncodeToString(sig)
	token := fmt.Sprintf("%s.%s.%s\n", protected64, payload64, sig64)

	return protected, payload, token
}

// ecFieldBytes is the byte width of a curve coordinate / JOSE r or s value.
func ecFieldBytes(bitSize int) int {
	return (bitSize + 7) / 8
}

// padBigEndian left-pads a big-endian integer encoding with zeros to size bytes.
func padBigEndian(b []byte, size int) []byte {
	if len(b) >= size {
		return b
	}
	out := make([]byte, size)
	copy(out[size-len(b):], b)
	return out
}

// JOSEVerify reports whether sig is a valid signature over hash for pubkey.
//
// RSA keys are checked with PKCS#1 v1.5 over a SHA-256 digest. ECDSA
// signatures must be in JOSE form: r || s, each a big-endian integer as
// wide as the curve's field. A nil key, a wrongly sized ECDSA signature,
// or an unsupported key type yields false.
// TODO: move to keypairs
func JOSEVerify(pubkey keypairs.PublicKey, hash []byte, sig []byte) bool {
	if nil == pubkey {
		return false
	}
	switch pub := pubkey.Key().(type) {
	case *rsa.PublicKey:
		// TODO keypairs.Size(key) to detect key size ?
		// TODO: this hasn't been tested yet
		return nil == rsa.VerifyPKCS1v15(pub, crypto.SHA256, hash, sig)
	case *ecdsa.PublicKey:
		size := ecFieldBytes(pub.Params().BitSize)
		if 2*size != len(sig) {
			return false
		}
		r := new(big.Int).SetBytes(sig[:size])
		s := new(big.Int).SetBytes(sig[size:])
		return ecdsa.Verify(pub, hash, r, s)
	default:
		return false
	}
}

// JOSESign signs hash with privkey.
//
// RSA keys produce a PKCS#1 v1.5 signature over a SHA-256 digest (RS256).
// ECDSA keys produce JOSE form r || s, each left-padded to the curve's
// field width (64 bytes for P-256). It returns nil for other key types and
// panics if the signer fails.
func JOSESign(privkey keypairs.PrivateKey, hash []byte) []byte {
	switch k := privkey.(type) {
	case *rsa.PrivateKey:
		sig, err := rsa.SignPKCS1v15(rand.Reader, k, crypto.SHA256, hash)
		if nil != err {
			panic(fmt.Errorf("rsa sign failure: %w", err))
		}
		return sig
	case *ecdsa.PrivateKey:
		r, s, err := ecdsa.Sign(rand.Reader, k, hash)
		if nil != err {
			// nothing else to do if we run out of random
			panic(fmt.Errorf("ecdsa sign failure: %w", err))
		}
		size := ecFieldBytes(k.Params().BitSize)
		return append(padBigEndian(r.Bytes(), size), padBigEndian(s.Bytes(), size)...)
	}
	return nil
}

// issueNonce generates a random single-use nonce, records its issue time,
// and returns it to the client in the Replay-Nonce response header.
func issueNonce(w http.ResponseWriter, r *http.Request) {
	b := make([]byte, 16)
	if _, err := rand.Read(b); nil != err {
		panic(fmt.Errorf("random bytes read failure: %w", err))
	}
	nonce := base64.RawURLEncoding.EncodeToString(b)

	nonces.Store(nonce, time.Now())
	w.Header().Set("Replay-Nonce", nonce)
}

// requireNonce wraps next so it only runs when the request's Replay-Nonce
// header carries a nonce issued in the last 15 minutes. The nonce is
// consumed either way, and a fresh one is issued before next runs.
func requireNonce(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		nonce := r.Header.Get("Replay-Nonce")
		// TODO expire nonces every so often
		issued, ok := nonces.Load(nonce)
		// Delete stale nonces too, so they don't accumulate forever.
		nonces.Delete(nonce)
		if !ok || time.Since(issued.(time.Time)) > 15*time.Minute {
			http.Error(
				w,
				`{ "error": "invalid or expired nonce", "error_code": "ENONCE" }`,
				http.StatusBadRequest,
			)
			return
		}

		issueNonce(w, r)
		next(w, r)
	}
}

// getBaseURL returns scheme://host for the request, treating it as https
// when served over TLS or when X-Forwarded-Proto says so.
func getBaseURL(r *http.Request) string {
	var scheme string
	if nil != r.TLS || "https" == r.Header.Get("X-Forwarded-Proto") {
		scheme = "https:"
	} else {
		scheme = "http:"
	}
	return fmt.Sprintf(
		"%s//%s",
		scheme,
		r.Host,
	)
}

// MarshalJWKPrivateKey outputs the given private key as JWK
func MarshalJWKPrivateKey(privkey keypairs.PrivateKey) []byte {
	// thumbprint keys are alphabetically sorted and only include the necessary public parts
	switch k := privkey.(type) {
	case *rsa.PrivateKey:
		return MarshalRSAPrivateKey(k)
	case *ecdsa.PrivateKey:
		return MarshalECPrivateKey(k)
	default:
		// this is unreachable because we know the types that we pass in
		log.Printf("keytype: %T, %+v\n", privkey, privkey)
		panic(keypairs.ErrInvalidPublicKey)
	}
}

// MarshalECPrivateKey will output the given private key as JWK.
// d, x and y are left-padded to the curve's field width, because
// big.Int.Bytes drops leading zero bytes.
func MarshalECPrivateKey(k *ecdsa.PrivateKey) []byte {
	params := k.Params()
	size := ecFieldBytes(params.BitSize)
	crv := params.Name
	d := base64.RawURLEncoding.EncodeToString(padBigEndian(k.D.Bytes(), size))
	x := base64.RawURLEncoding.EncodeToString(padBigEndian(k.X.Bytes(), size))
	y := base64.RawURLEncoding.EncodeToString(padBigEndian(k.Y.Bytes(), size))
	return []byte(fmt.Sprintf(
		`{"crv":%q,"d":%q,"kty":"EC","x":%q,"y":%q}`,
		crv, d, x, y,
	))
}

// MarshalRSAPrivateKey will output the given private key as JWK.
// Only the first two primes are emitted. CRT values are derived locally
// when the key was never precomputed.
func MarshalRSAPrivateKey(pk *rsa.PrivateKey) []byte {
	p, q := pk.Primes[0], pk.Primes[1]
	dp, dq, qi := pk.Precomputed.Dp, pk.Precomputed.Dq, pk.Precomputed.Qinv
	if nil == dp || nil == dq || nil == qi {
		one := big.NewInt(1)
		dp = new(big.Int).Mod(pk.D, new(big.Int).Sub(p, one))
		dq = new(big.Int).Mod(pk.D, new(big.Int).Sub(q, one))
		qi = new(big.Int).ModInverse(q, p)
	}

	enc := base64.RawURLEncoding.EncodeToString
	return []byte(fmt.Sprintf(
		`{"d":%q,"dp":%q,"dq":%q,"e":%q,"kty":"RSA","n":%q,"p":%q,"q":%q,"qi":%q}`,
		enc(pk.D.Bytes()), enc(dp.Bytes()), enc(dq.Bytes()),
		enc(big.NewInt(int64(pk.E)).Bytes()), enc(pk.N.Bytes()),
		enc(p.Bytes()), enc(q.Bytes()), enc(qi.Bytes()),
	))
}