package mockid import ( "crypto" "crypto/ecdsa" "crypto/rand" "crypto/rsa" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "io" "math/big" "net/http" "net/url" "os" "strconv" "sync" "time" "git.coolaj86.com/coolaj86/go-mockid/xkeypairs" "git.rootprojects.org/root/keypairs" //jwt "github.com/dgrijalva/jwt-go" ) // TestMain will overwrite this var rndsrc io.Reader = rand.Reader type PublicJWK struct { Crv string `json:"crv"` KeyID string `json:"kid,omitempty"` Kty string `json:"kty,omitempty"` X string `json:"x"` Y string `json:"y"` } type InspectableToken struct { Public keypairs.PublicKey `json:"jwk"` Protected map[string]interface{} `json:"protected"` Payload map[string]interface{} `json:"payload"` Signature string `json:"signature"` Verified bool `json:"verified"` Errors []string `json:"errors"` } func (t *InspectableToken) MarshalJSON() ([]byte, error) { pub := keypairs.MarshalJWKPublicKey(t.Public) header, _ := json.Marshal(t.Protected) payload, _ := json.Marshal(t.Payload) errs, _ := json.Marshal(t.Errors) return []byte(fmt.Sprintf( `{"jwk":%s,"protected":%s,"payload":%s,"signature":%q,"verified":%t,"errors":%s}`, pub, header, payload, t.Signature, t.Verified, errs, )), nil } var defaultFrom string var defaultReplyTo string //var nonces map[string]int64 //var nonCh chan string var nonces sync.Map var salt []byte func Init() { var err error salt64 := os.Getenv("SALT") salt, err = base64.RawURLEncoding.DecodeString(salt64) if len(salt64) < 22 || nil != err { panic("SALT must be set as 22+ character base64") } defaultFrom = os.Getenv("MAILER_FROM") defaultReplyTo = os.Getenv("MAILER_REPLY_TO") //nonces = make(map[string]int64) //nonCh = make(chan string) /* go func() { for { nonce := <- nonCh nonces[nonce] = time.Now().Unix() } }() */ } func GenToken(host string, privkey keypairs.PrivateKey, query url.Values) (string, string, string) { thumbprint := keypairs.ThumbprintPublicKey(keypairs.NewPublicKey(privkey.Public())) // TODO keypairs.Alg(key) alg := "ES256" switch privkey.(type) { case *rsa.PrivateKey: alg = "RS256" } protected := fmt.Sprintf(`{"typ":"JWT","alg":%q,"kid":"%s"}`, alg, thumbprint) protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected)) exp, err := xkeypairs.ParseDuration(query.Get("exp")) if nil != err { // cryptic error code // TODO propagate error exp = 422 } payload := fmt.Sprintf( `{"iss":"%s/","sub":"dummy","exp":%s}`, host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10), ) payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload)) hash := sha256.Sum256([]byte(fmt.Sprintf(`%s.%s`, protected64, payload64))) sig := JOSESign(privkey, hash[:]) sig64 := base64.RawURLEncoding.EncodeToString(sig) token := fmt.Sprintf("%s.%s.%s\n", protected64, payload64, sig64) return protected, payload, token } func JOSESign(privkey keypairs.PrivateKey, hash []byte) []byte { var sig []byte switch k := privkey.(type) { case *rsa.PrivateKey: panic("TODO: implement rsa sign") case *ecdsa.PrivateKey: r, s, _ := ecdsa.Sign(rndsrc, k, hash[:]) rb := r.Bytes() fmt.Println("debug:") fmt.Println(r, s) for len(rb) < 32 { rb = append([]byte{0}, rb...) } sb := s.Bytes() for len(rb) < 32 { sb = append([]byte{0}, sb...) } sig = append(rb, sb...) } return sig } // TODO: move to keypairs func JOSEVerify(pubkey keypairs.PublicKey, hash []byte, sig []byte) bool { switch pub := pubkey.Key().(type) { case *rsa.PublicKey: // TODO keypairs.Size(key) to detect key size ? //alg := "SHA256" // TODO: this hasn't been tested yet if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, hash, sig); nil != err { return false } return true case *ecdsa.PublicKey: r := &big.Int{} r.SetBytes(sig[0:32]) s := &big.Int{} s.SetBytes(sig[32:]) fmt.Println("debug: sig len:", len(sig)) fmt.Println("debug: r, s:", r, s) return ecdsa.Verify(pub, hash, r, s) default: panic("impossible condition: non-rsa/non-ecdsa key") return false } } func issueNonce(w http.ResponseWriter, r *http.Request) { b := make([]byte, 16) _, _ = rand.Read(b) nonce := base64.RawURLEncoding.EncodeToString(b) //nonCh <- nonce nonces.Store(nonce, time.Now()) w.Header().Set("Replay-Nonce", nonce) } func requireNonce(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { nonce := r.Header.Get("Replay-Nonce") // TODO expire nonces every so often //t := nonces[nonce] var t time.Time tmp, ok := nonces.Load(nonce) if ok { t = tmp.(time.Time) } if !ok || time.Now().Sub(t) > 15*time.Minute { http.Error( w, `{ "error": "invalid or expired nonce", "error_code": "ENONCE" }`, http.StatusBadRequest, ) return } //delete(nonces, nonce) nonces.Delete(nonce) issueNonce(w, r) next(w, r) } }