From a1b4ad1202c6880fd28142b2b342a7d12ef71e8e Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Tue, 4 Aug 2020 07:09:43 +0000 Subject: [PATCH] can now self-sign JWS and JWT --- mockid/api/common.go | 7 ++ mockid/api/sign.go | 57 ++++++++++++++ mockid/mockid.go | 46 ++++++------ mockid/mockid_test.go | 70 +++++++++++++++++ mockid/route.go | 4 + xkeypairs/parse.go | 3 + xkeypairs/sign.go | 170 ++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 334 insertions(+), 23 deletions(-) create mode 100644 mockid/api/sign.go create mode 100644 xkeypairs/sign.go diff --git a/mockid/api/common.go b/mockid/api/common.go index b5ef199..c5c5f3c 100644 --- a/mockid/api/common.go +++ b/mockid/api/common.go @@ -13,12 +13,16 @@ import ( "net/http" ) +type Object = map[string]interface{} + // options are the things that we may need to know about a request to fulfill it properly type options struct { Key string `json:"key"` KeyType string `json:"kty"` Seed int64 `json:"-"` SeedStr string `json:"seed"` + Claims Object `json:"claims"` + Header Object `json:"header"` } // this shananigans is only for testing and debug API stuff @@ -55,6 +59,9 @@ func getOpts(r *http.Request) (*options, error) { Key: key, } + opts.Claims, _ = tok["claims"].(Object) + opts.Header, _ = tok["header"].(Object) + var n int if 0 != seed { n = opts.nextReader().(*mathrand.Rand).Intn(2) diff --git a/mockid/api/sign.go b/mockid/api/sign.go new file mode 100644 index 0000000..36753bf --- /dev/null +++ b/mockid/api/sign.go @@ -0,0 +1,57 @@ +package api + +import ( + "encoding/json" + "net/http" + + "git.coolaj86.com/coolaj86/go-mockid/xkeypairs" +) + +// SignJWS will create an uncompressed JWT with the given payload +func SignJWS(w http.ResponseWriter, r *http.Request) { + sign(w, r, false) +} + +// SignJWT will create an compressed JWS (JWT) with the given payload +func SignJWT(w http.ResponseWriter, r *http.Request) { + sign(w, r, true) +} + +func sign(w http.ResponseWriter, r *http.Request, jwt bool) { + if "POST" != r.Method { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + opts, err := getOpts(r) + if nil != err { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + privkey, err := getPrivKey(opts) + if nil != err { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + header := opts.Header + if 0 != opts.Seed { + header["_seed"] = opts.Seed + } + + jws, err := xkeypairs.SignClaims(privkey, header, opts.Claims) + if nil != err { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var b []byte + if jwt { + s := xkeypairs.JWSToJWT(jws) + w.Write(append([]byte(s), '\n')) + return + } + b, _ = json.Marshal(jws) + w.Write(append(b, '\n')) +} diff --git a/mockid/mockid.go b/mockid/mockid.go index 4ff4b96..338bb80 100644 --- a/mockid/mockid.go +++ b/mockid/mockid.go @@ -115,6 +115,29 @@ func GenToken(host string, privkey keypairs.PrivateKey, query url.Values) (strin return protected, payload, token } +func JOSESign(privkey keypairs.PrivateKey, hash []byte) []byte { + var sig []byte + + switch k := privkey.(type) { + case *rsa.PrivateKey: + panic("TODO: implement rsa sign") + case *ecdsa.PrivateKey: + r, s, _ := ecdsa.Sign(rndsrc, k, hash[:]) + rb := r.Bytes() + fmt.Println("debug:") + fmt.Println(r, s) + for len(rb) < 32 { + rb = append([]byte{0}, rb...) + } + sb := s.Bytes() + for len(rb) < 32 { + sb = append([]byte{0}, sb...) + } + sig = append(rb, sb...) + } + return sig +} + // TODO: move to keypairs func JOSEVerify(pubkey keypairs.PublicKey, hash []byte, sig []byte) bool { @@ -143,29 +166,6 @@ func JOSEVerify(pubkey keypairs.PublicKey, hash []byte, sig []byte) bool { return verified } -func JOSESign(privkey keypairs.PrivateKey, hash []byte) []byte { - var sig []byte - - switch k := privkey.(type) { - case *rsa.PrivateKey: - panic("TODO: implement rsa sign") - case *ecdsa.PrivateKey: - r, s, _ := ecdsa.Sign(rndsrc, k, hash[:]) - rb := r.Bytes() - fmt.Println("debug:") - fmt.Println(r, s) - for len(rb) < 32 { - rb = append([]byte{0}, rb...) - } - sb := s.Bytes() - for len(rb) < 32 { - sb = append([]byte{0}, sb...) - } - sig = append(rb, sb...) - } - return sig -} - func issueNonce(w http.ResponseWriter, r *http.Request) { b := make([]byte, 16) _, _ = rand.Read(b) diff --git a/mockid/mockid_test.go b/mockid/mockid_test.go index e40171d..bc672db 100644 --- a/mockid/mockid_test.go +++ b/mockid/mockid_test.go @@ -61,6 +61,40 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +//func TestSelfSignWithoutExp(t *testing.T) +//func TestSelfSignWithJTIWithoutExp(t *testing.T) + +func TestSelfSign(t *testing.T) { + client := srv.Client() + //urlstr, _ := url.Parse(srv.URL + "/jose.jws.json") + urlstr, _ := url.Parse(srv.URL + "/jose.jws.jwt") + + //fmt.Println("URL:", srv.URL, urlstr) + tokenRequest := []byte(`{"seed":"test","header":{"_jwk":true},"claims":{"sub":"bananas","exp":"10m"}}`) + res, err := client.Do(&http.Request{ + Method: "POST", + URL: urlstr, + Body: ioutil.NopCloser(bytes.NewReader(tokenRequest)), + }) + if nil != err { + t.Error(err) + return + } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } + + data, err := ioutil.ReadAll(res.Body) + if nil != err { + t.Error(err) + return + } + + log.Printf("TODO: verify, and verify non-self-signed") + log.Printf(string(data)) +} + func TestGenerateJWK(t *testing.T) { client := srv.Client() urlstr, _ := url.Parse(srv.URL + "/private.jwk.json") @@ -74,6 +108,10 @@ func TestGenerateJWK(t *testing.T) { t.Error(err) return } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } data, err := ioutil.ReadAll(res.Body) if nil != err { @@ -128,6 +166,11 @@ func TestGenWithSeed(t *testing.T) { t.Error(err) return } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } + dataA, err := ioutil.ReadAll(res.Body) if nil != err { //t.Fatal(err) @@ -150,6 +193,11 @@ func TestGenWithSeed(t *testing.T) { t.Error(err) return } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } + dataB, err := ioutil.ReadAll(res.Body) if nil != err { //t.Fatal(err) @@ -180,6 +228,11 @@ func TestGenWithRand(t *testing.T) { t.Error(err) return } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } + dataA, err := ioutil.ReadAll(res.Body) if nil != err { //t.Fatal(err) @@ -200,6 +253,11 @@ func TestGenWithRand(t *testing.T) { t.Error(err) return } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } + dataB, err := ioutil.ReadAll(res.Body) if nil != err { //t.Fatal(err) @@ -226,6 +284,10 @@ func TestGeneratePEM(t *testing.T) { t.Error(err) return } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } data, err := ioutil.ReadAll(res.Body) if nil != err { @@ -266,6 +328,10 @@ func TestPublicJWKWithKey(t *testing.T) { t.Error(err) return } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } data, err := ioutil.ReadAll(res.Body) if nil != err { @@ -319,6 +385,10 @@ func TestPublicPEMWithSeed(t *testing.T) { t.Error(err) return } + if 200 != res.StatusCode { + t.Error(fmt.Errorf("bad status code: %d", res.StatusCode)) + return + } data, err := ioutil.ReadAll(res.Body) if nil != err { diff --git a/mockid/route.go b/mockid/route.go index f3fc36c..d627344 100644 --- a/mockid/route.go +++ b/mockid/route.go @@ -181,6 +181,7 @@ func Route(jwksPrefix string, privkey keypairs.PrivateKey) http.Handler { fmt.Fprintf(w, token) }) + // TODO add /debug prefix http.HandleFunc("/private.jwk.json", api.GeneratePrivateJWK) http.HandleFunc("/priv.der", api.GeneratePrivateDER) http.HandleFunc("/priv.pem", api.GeneratePrivatePEM) @@ -189,6 +190,9 @@ func Route(jwksPrefix string, privkey keypairs.PrivateKey) http.Handler { http.HandleFunc("/pub.der", api.GeneratePublicDER) http.HandleFunc("/pub.pem", api.GeneratePublicPEM) + http.HandleFunc("/jose.jws.json", api.SignJWS) + http.HandleFunc("/jose.jws.jwt", api.SignJWT) + http.HandleFunc("/inspect_token", func(w http.ResponseWriter, r *http.Request) { token := r.Header.Get("Authorization") log.Printf("%s %s %s\n", r.Method, r.URL.Path, token) diff --git a/xkeypairs/parse.go b/xkeypairs/parse.go index d561404..41efa3a 100644 --- a/xkeypairs/parse.go +++ b/xkeypairs/parse.go @@ -16,6 +16,7 @@ func ParseDuration(exp string) (int, error) { if "" == exp { exp = "15m" } + mult := 1 switch exp[len(exp)-1] { case 'w': @@ -37,9 +38,11 @@ func ParseDuration(exp string) (int, error) { exp += "s" } + // 15m => num=15, mult=1*60 num, err := strconv.Atoi(exp[:len(exp)-1]) if nil != err { return 0, err } + return num * mult, nil } diff --git a/xkeypairs/sign.go b/xkeypairs/sign.go new file mode 100644 index 0000000..b4a987b --- /dev/null +++ b/xkeypairs/sign.go @@ -0,0 +1,170 @@ +package xkeypairs + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + mathrand "math/rand" + "time" + + "git.rootprojects.org/root/keypairs" +) + +var RandomReader = rand.Reader + +type JWS struct { + Header Object `json:"header"` // JSON + Claims Object `json:"claims"` // JSON + Protected string `json:"protected"` // base64 + Payload string `json:"payload"` // base64 + Signature string `json:"signature"` // base64 +} + +type Object = map[string]interface{} + +func SignClaims(privkey keypairs.PrivateKey, header Object, claims Object) (*JWS, error) { + var randsrc io.Reader = RandomReader + seed, _ := header["_seed"].(int64) + if 0 != seed { + randsrc = mathrand.New(mathrand.NewSource(seed)) + //delete(header, "_seed") + } + + protected, err := headerToProtected(keypairs.NewPublicKey(privkey.Public()), header) + if nil != err { + return nil, err + } + protected64 := base64.RawURLEncoding.EncodeToString(protected) + + payload, err := claimsToPayload(claims) + if nil != err { + return nil, err + } + payload64 := base64.RawURLEncoding.EncodeToString(payload) + + hash := sha256.Sum256([]byte(fmt.Sprintf( + `%s.%s`, + protected64, + payload64, + ))) + + sig := Sign(randsrc, privkey, hash[:]) + sig64 := base64.RawURLEncoding.EncodeToString(sig) + + return &JWS{ + Header: header, + Claims: claims, + Protected: protected64, + Payload: payload64, + Signature: sig64, + }, nil +} + +func headerToProtected(pub keypairs.PublicKey, header Object) ([]byte, error) { + if nil == header { + header = Object{} + } + + // Only supporting 2048-bit and P256 keys right now + // because that's all that's practical and well-supported. + // No security theatre here. + alg := "ES256" + switch pub.Key().(type) { + case *rsa.PublicKey: + alg = "RS256" + } + + if selfSign, _ := header["_jwk"].(bool); selfSign { + delete(header, "_jwk") + any := Object{} + _ = json.Unmarshal(keypairs.MarshalJWKPublicKey(pub), &any) + header["jwk"] = any + } + + // TODO what are the acceptable values? JWT. JWS? others? + header["typ"] = "JWT" + if _, ok := header["jwk"]; !ok { + thumbprint := keypairs.ThumbprintPublicKey(pub) + kid, _ := header["kid"].(string) + if "" != kid && thumbprint != kid { + return nil, errors.New("'kid' should be the key's thumbprint") + } + header["kid"] = thumbprint + } + header["alg"] = alg + + protected, err := json.Marshal(header) + if nil != err { + return nil, err + } + return protected, nil +} + +func claimsToPayload(claims Object) ([]byte, error) { + if nil == claims { + claims = Object{} + } + + jti, _ := claims["jti"].(string) + exp, _ := claims["exp"].(int64) + dur, _ := claims["exp"].(string) + insecure, _ := claims["insecure"].(bool) + + // parse if exp is actually a duration, such as "15m" + if 0 == exp && "" != dur { + s, err := ParseDuration(dur) + if nil != err { + return nil, err + } + exp = time.Now().Add(time.Duration(s) * time.Second).Unix() + claims["exp"] = exp + } + if "" == jti && 0 == exp && !insecure { + return nil, errors.New("token must have jti or exp as to be expirable / cancellable") + } + + return json.Marshal(claims) +} + +func JWSToJWT(jwt *JWS) string { + return fmt.Sprintf( + "%s.%s.%s", + jwt.Protected, + jwt.Payload, + jwt.Signature, + ) +} + +func Sign(rand io.Reader, privkey keypairs.PrivateKey, hash []byte) []byte { + var sig []byte + + if len(hash) != 32 { + panic("only 256-bit hashes for 2048-bit and 256-bit keys are supported") + } + + switch k := privkey.(type) { + case *rsa.PrivateKey: + sig, _ = rsa.SignPKCS1v15(rand, k, crypto.SHA256, hash) + case *ecdsa.PrivateKey: + r, s, _ := ecdsa.Sign(rand, k, hash[:]) + rb := r.Bytes() + fmt.Println("debug:") + fmt.Println(r, s) + for len(rb) < 32 { + rb = append([]byte{0}, rb...) + } + sb := s.Bytes() + for len(rb) < 32 { + sb = append([]byte{0}, sb...) + } + sig = append(rb, sb...) + } + return sig +}