refactor
This commit is contained in:
parent
d6f5027480
commit
190ee93da8
|
@ -1,3 +1,4 @@
|
|||
/public-jwks
|
||||
/go-mockid
|
||||
|
||||
# ---> Go
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
module git.coolaj86.com/coolaj86/go-mockid
|
||||
|
||||
go 1.12
|
||||
|
||||
require github.com/joho/godotenv v1.3.0
|
|
@ -0,0 +1,2 @@
|
|||
github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
|
||||
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
|
385
mockid.go
385
mockid.go
|
@ -1,46 +1,20 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.coolaj86.com/coolaj86/go-mockid/mockid"
|
||||
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
)
|
||||
|
||||
type PrivateJWK struct {
|
||||
PublicJWK
|
||||
D string `json:"d"`
|
||||
}
|
||||
type PublicJWK struct {
|
||||
Crv string `json:"crv"`
|
||||
KeyID string `json:"kid,omitempty"`
|
||||
Kty string `json:"kty,omitempty"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
var nonces map[string]int64
|
||||
var jwksPrefix string
|
||||
|
||||
func init() {
|
||||
nonces = make(map[string]int64)
|
||||
}
|
||||
|
||||
func main() {
|
||||
done := make(chan bool)
|
||||
var port int
|
||||
|
@ -52,17 +26,15 @@ func main() {
|
|||
"x": "ToL2HppsTESXQKvp7ED6NMgV4YnwbMeONexNry3KDNQ",
|
||||
"y": "Tt6Q3rxU37KAinUV9PLMlwosNy1t3Bf2VDg5q955AGc",
|
||||
}
|
||||
jwk := &PrivateJWK{
|
||||
PublicJWK: PublicJWK{
|
||||
jwk := &mockid.PrivateJWK{
|
||||
PublicJWK: mockid.PublicJWK{
|
||||
Crv: jwkm["crv"],
|
||||
X: jwkm["x"],
|
||||
Y: jwkm["y"],
|
||||
},
|
||||
D: jwkm["d"],
|
||||
}
|
||||
priv := parseKey(jwk)
|
||||
pub := &priv.PublicKey
|
||||
thumbprint := thumbprintKey(pub)
|
||||
priv := mockid.ParseKey(jwk)
|
||||
|
||||
portFlag := flag.Int("port", 0, "Port on which the HTTP server should run")
|
||||
urlFlag := flag.String("url", "", "Outward-facing address, such as https://example.com")
|
||||
|
@ -86,6 +58,7 @@ func main() {
|
|||
host = "http://localhost:" + strconv.Itoa(port)
|
||||
}
|
||||
|
||||
var jwksPrefix string
|
||||
if nil != prefixFlag && "" != *prefixFlag {
|
||||
jwksPrefix = *prefixFlag
|
||||
} else {
|
||||
|
@ -97,206 +70,7 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
http.HandleFunc("/api/new-nonce", func(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := getBaseURL(r)
|
||||
/*
|
||||
res.statusCode = 200;
|
||||
res.setHeader("Cache-Control", "max-age=0, no-cache, no-store");
|
||||
// TODO
|
||||
//res.setHeader("Date", "Sun, 10 Mar 2019 08:04:45 GMT");
|
||||
// is this the expiration of the nonce itself? methinks maybe so
|
||||
//res.setHeader("Expires", "Sun, 10 Mar 2019 08:04:45 GMT");
|
||||
// TODO use one of the registered domains
|
||||
//var indexUrl = "https://acme-staging-v02.api.letsencrypt.org/index"
|
||||
*/
|
||||
//var port = (state.config.ipc && state.config.ipc.port || state._ipc.port || undefined);
|
||||
//var indexUrl = "http://localhost:" + port + "/index";
|
||||
indexUrl := baseURL + "/index";
|
||||
w.Header().Set("Link", "<" + indexUrl + ">;rel=\"index\"");
|
||||
w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store");
|
||||
w.Header().Set("Pragma", "no-cache");
|
||||
//res.setHeader("Strict-Transport-Security", "max-age=604800");
|
||||
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
issueNonce(w, r)
|
||||
})
|
||||
|
||||
http.HandleFunc("/api/new-account", requireNonce(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Not Implemented", http.StatusNotImplemented)
|
||||
}))
|
||||
|
||||
http.HandleFunc("/api/jwks", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path)
|
||||
if "POST" != r.Method {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
tok := make(map[string]interface{})
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
err := decoder.Decode(&tok)
|
||||
if nil != err {
|
||||
http.Error(w, "Bad Request: invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// TODO better, JSON error messages
|
||||
if _, ok := tok["d"]; ok {
|
||||
http.Error(w, "Bad Request: private key", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
kty, _ := tok["kty"].(string)
|
||||
if "EC" != kty {
|
||||
http.Error(w, "Bad Request: only EC keys are supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
crv, ok := tok["crv"].(string)
|
||||
if 5 != len(crv) || "P-" != crv[:2] {
|
||||
http.Error(w, "Bad Request: bad curve", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
x, ok := tok["x"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Bad Request: missing 'x'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
y, ok := tok["y"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Bad Request: missing 'y'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO RSA
|
||||
thumbprintable := []byte(
|
||||
fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, crv, x, y),
|
||||
)
|
||||
|
||||
var thumb []byte
|
||||
switch crv[2:] {
|
||||
case "256":
|
||||
hash := sha256.Sum256(thumbprintable)
|
||||
thumb = hash[:]
|
||||
case "384":
|
||||
hash := sha512.Sum384(thumbprintable)
|
||||
thumb = hash[:]
|
||||
case "521":
|
||||
hash := sha512.Sum512(thumbprintable)
|
||||
thumb = hash[:]
|
||||
default:
|
||||
http.Error(w, "Bad Request: bad curve", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
kid := base64.RawURLEncoding.EncodeToString(thumb)
|
||||
if kid2, _ := tok["kid"].(string); "" != kid2 && kid != kid2 {
|
||||
http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO allow posting at the top-level?
|
||||
// TODO support a group of keys by PPID
|
||||
// (right now it's only by KID)
|
||||
if !strings.HasPrefix(r.Host, strings.ToLower(kid)+".") {
|
||||
http.Error(w, "Bad Request: prefix should be "+kid, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
pub := []byte(fmt.Sprintf(
|
||||
`{"crv":%q,"kid":%q,"kty":"EC","x":%q,"y":%q}`, crv, kid, x, y,
|
||||
))
|
||||
err = ioutil.WriteFile(
|
||||
filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"),
|
||||
pub,
|
||||
0644,
|
||||
)
|
||||
if nil != err {
|
||||
fmt.Println("can't write file")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := getBaseURL(r)
|
||||
w.Write([]byte(fmt.Sprintf(
|
||||
`{ "iss":%q, "jwks_url":%q }`, baseURL+"/", baseURL+"/.well-known/jwks.json",
|
||||
)))
|
||||
})
|
||||
|
||||
http.HandleFunc("/access_token", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s\n", r.Method, r.URL.Path)
|
||||
_, _, token := genToken(getBaseURL(r), priv, r.URL.Query())
|
||||
fmt.Fprintf(w, token)
|
||||
})
|
||||
|
||||
http.HandleFunc("/authorization_header", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s\n", r.Method, r.URL.Path)
|
||||
|
||||
var header string
|
||||
headers, _ := r.URL.Query()["header"]
|
||||
if 0 == len(headers) {
|
||||
header = "Authorization"
|
||||
} else {
|
||||
header = headers[0]
|
||||
}
|
||||
|
||||
var prefix string
|
||||
prefixes, _ := r.URL.Query()["prefix"]
|
||||
if 0 == len(prefixes) {
|
||||
prefix = "Bearer "
|
||||
} else {
|
||||
prefix = prefixes[0]
|
||||
}
|
||||
|
||||
_, _, token := genToken(getBaseURL(r), priv, r.URL.Query())
|
||||
fmt.Fprintf(w, "%s: %s%s", header, prefix, token)
|
||||
})
|
||||
|
||||
http.HandleFunc("/key.jwk.json", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s", r.Method, r.URL.Path)
|
||||
fmt.Fprintf(w, `{ "kty": "EC" , "crv": %q , "d": %q , "x": %q , "y": %q , "ext": true , "key_ops": ["sign"] }`, jwk.Crv, jwk.D, jwk.X, jwk.Y)
|
||||
})
|
||||
|
||||
http.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := getBaseURL(r)
|
||||
log.Printf("%s %s\n", r.Method, r.URL.Path)
|
||||
fmt.Fprintf(w, `{ "issuer": "%s", "jwks_uri": "%s/.well-known/jwks.json" }`, baseURL, baseURL)
|
||||
})
|
||||
|
||||
http.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path)
|
||||
parts := strings.Split(r.Host, ".")
|
||||
kid := parts[0]
|
||||
|
||||
b, err := ioutil.ReadFile(filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"))
|
||||
if nil != err {
|
||||
//http.Error(w, "Not Found", http.StatusNotFound)
|
||||
jwkstr := fmt.Sprintf(
|
||||
`{ "keys": [ { "kty": "EC" , "crv": %q , "x": %q , "y": %q , "kid": %q , "ext": true , "key_ops": ["verify"] , "exp": %s } ] }`,
|
||||
jwk.Crv, jwk.X, jwk.Y, thumbprint, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10),
|
||||
)
|
||||
fmt.Println(jwkstr)
|
||||
fmt.Fprintf(w, jwkstr)
|
||||
return
|
||||
}
|
||||
|
||||
tok := &PublicJWK{}
|
||||
err = json.Unmarshal(b, tok)
|
||||
if nil != err {
|
||||
// TODO delete the bad file?
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
jwkstr := fmt.Sprintf(
|
||||
`{ "keys": [ { "kty": "EC", "crv": %q, "x": %q, "y": %q, "kid": %q,`+
|
||||
` "ext": true, "key_ops": ["verify"], "exp": %s } ] }`,
|
||||
tok.Crv, tok.X, tok.Y, tok.KeyID, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10),
|
||||
)
|
||||
fmt.Println(jwkstr)
|
||||
fmt.Fprintf(w, jwkstr)
|
||||
})
|
||||
mockid.Route(jwksPrefix, priv, jwk)
|
||||
|
||||
fs := http.FileServer(http.Dir("public"))
|
||||
http.Handle("/", fs)
|
||||
|
@ -317,149 +91,10 @@ func main() {
|
|||
fmt.Printf("Private Key:\n\t%s\n", string(b))
|
||||
b, _ = json.Marshal(jwk.PublicJWK)
|
||||
fmt.Printf("Public Key:\n\t%s\n", string(b))
|
||||
protected, payload, token := genToken(host, priv, url.Values{})
|
||||
protected, payload, token := mockid.GenToken(host, priv, url.Values{})
|
||||
fmt.Printf("Protected (Header):\n\t%s\n", protected)
|
||||
fmt.Printf("Payload (Claims):\n\t%s\n", payload)
|
||||
fmt.Printf("Access Token:\n\t%s\n", token)
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
func parseExp(exp string) (int, error) {
|
||||
if "" == exp {
|
||||
exp = "15m"
|
||||
}
|
||||
mult := 1
|
||||
switch exp[len(exp)-1] {
|
||||
case 'w':
|
||||
mult *= 7
|
||||
fallthrough
|
||||
case 'd':
|
||||
mult *= 24
|
||||
fallthrough
|
||||
case 'h':
|
||||
mult *= 60
|
||||
fallthrough
|
||||
case 'm':
|
||||
mult *= 60
|
||||
fallthrough
|
||||
case 's':
|
||||
// no fallthrough
|
||||
default:
|
||||
// could be 'k' or 'z', but we assume its empty
|
||||
exp += "s"
|
||||
}
|
||||
|
||||
num, err := strconv.Atoi(exp[:len(exp)-1])
|
||||
if nil != err {
|
||||
return 0, err
|
||||
}
|
||||
return num * mult, nil
|
||||
}
|
||||
|
||||
func genToken(host string, priv *ecdsa.PrivateKey, query url.Values) (string, string, string) {
|
||||
thumbprint := thumbprintKey(&priv.PublicKey)
|
||||
protected := fmt.Sprintf(`{"typ":"JWT","alg":"ES256","kid":"%s"}`, thumbprint)
|
||||
protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected))
|
||||
|
||||
exp, err := parseExp(query.Get("exp"))
|
||||
if nil != err {
|
||||
// cryptic error code
|
||||
// TODO propagate error
|
||||
exp = 422
|
||||
}
|
||||
|
||||
payload := fmt.Sprintf(
|
||||
`{"iss":"%s/","sub":"dummy","exp":%s}`,
|
||||
host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10),
|
||||
)
|
||||
payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||
|
||||
hash := sha256.Sum256([]byte(fmt.Sprintf(`%s.%s`, protected64, payload64)))
|
||||
r, s, _ := ecdsa.Sign(rand.Reader, priv, hash[:])
|
||||
rb := r.Bytes()
|
||||
for len(rb) < 32 {
|
||||
rb = append([]byte{0}, rb...)
|
||||
}
|
||||
sb := s.Bytes()
|
||||
for len(rb) < 32 {
|
||||
sb = append([]byte{0}, sb...)
|
||||
}
|
||||
sig64 := base64.RawURLEncoding.EncodeToString(append(rb, sb...))
|
||||
token := fmt.Sprintf(`%s.%s.%s`, protected64, payload64, sig64)
|
||||
return protected, payload, token
|
||||
}
|
||||
|
||||
func parseKey(jwk *PrivateJWK) *ecdsa.PrivateKey {
|
||||
xb, _ := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
xi := &big.Int{}
|
||||
xi.SetBytes(xb)
|
||||
yb, _ := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||
yi := &big.Int{}
|
||||
yi.SetBytes(yb)
|
||||
pub := &ecdsa.PublicKey{
|
||||
Curve: elliptic.P256(),
|
||||
X: xi,
|
||||
Y: yi,
|
||||
}
|
||||
|
||||
db, _ := base64.RawURLEncoding.DecodeString(jwk.D)
|
||||
di := &big.Int{}
|
||||
di.SetBytes(db)
|
||||
priv := &ecdsa.PrivateKey{
|
||||
PublicKey: *pub,
|
||||
D: di,
|
||||
}
|
||||
return priv
|
||||
}
|
||||
|
||||
func thumbprintKey(pub *ecdsa.PublicKey) string {
|
||||
minpub := []byte(fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, "P-256", pub.X, pub.Y))
|
||||
sha := sha256.Sum256(minpub)
|
||||
return base64.RawURLEncoding.EncodeToString(sha[:])
|
||||
}
|
||||
|
||||
func issueNonce(w http.ResponseWriter, r *http.Request) {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b)
|
||||
nonce := base64.RawURLEncoding.EncodeToString(b);
|
||||
nonces[nonce] = time.Now().Unix()
|
||||
|
||||
w.Header().Set("Replay-Nonce", nonce);
|
||||
}
|
||||
|
||||
|
||||
func requireNonce(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func (w http.ResponseWriter, r *http.Request) {
|
||||
nonce := r.Header.Get("Replay-Nonce")
|
||||
// TODO expire nonces every so often
|
||||
t := nonces[nonce]
|
||||
if 0 == t {
|
||||
http.Error(
|
||||
w,
|
||||
`{ "error": "invalid or expired nonce", "error_code": "ENONCE" }`,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
delete(nonces, nonce)
|
||||
issueNonce(w, r)
|
||||
|
||||
next(w, r);
|
||||
}
|
||||
}
|
||||
|
||||
func getBaseURL(r *http.Request) string {
|
||||
var scheme string
|
||||
if nil != r.TLS || "https" == r.Header.Get("X-Forwarded-Proto") {
|
||||
scheme = "https:"
|
||||
} else {
|
||||
scheme = "http:"
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"%s//%s",
|
||||
scheme,
|
||||
r.Host,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,460 @@
|
|||
package mockid
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PrivateJWK struct {
|
||||
PublicJWK
|
||||
D string `json:"d"`
|
||||
}
|
||||
|
||||
type PublicJWK struct {
|
||||
Crv string `json:"crv"`
|
||||
KeyID string `json:"kid,omitempty"`
|
||||
Kty string `json:"kty,omitempty"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
var nonces map[string]int64
|
||||
|
||||
func init() {
|
||||
nonces = make(map[string]int64)
|
||||
}
|
||||
|
||||
func Route(jwksPrefix string , priv *ecdsa.PrivateKey, jwk *PrivateJWK) {
|
||||
pub := &priv.PublicKey
|
||||
thumbprint := thumbprintKey(pub)
|
||||
|
||||
http.HandleFunc("/api/new-nonce", func(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := getBaseURL(r)
|
||||
/*
|
||||
res.statusCode = 200;
|
||||
res.setHeader("Cache-Control", "max-age=0, no-cache, no-store");
|
||||
// TODO
|
||||
//res.setHeader("Date", "Sun, 10 Mar 2019 08:04:45 GMT");
|
||||
// is this the expiration of the nonce itself? methinks maybe so
|
||||
//res.setHeader("Expires", "Sun, 10 Mar 2019 08:04:45 GMT");
|
||||
// TODO use one of the registered domains
|
||||
//var indexUrl = "https://acme-staging-v02.api.letsencrypt.org/index"
|
||||
*/
|
||||
//var port = (state.config.ipc && state.config.ipc.port || state._ipc.port || undefined);
|
||||
//var indexUrl = "http://localhost:" + port + "/index";
|
||||
indexUrl := baseURL + "/index"
|
||||
w.Header().Set("Link", "<"+indexUrl+">;rel=\"index\"")
|
||||
w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
//res.setHeader("Strict-Transport-Security", "max-age=604800");
|
||||
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
issueNonce(w, r)
|
||||
})
|
||||
|
||||
http.HandleFunc("/api/new-account", requireNonce(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Not Implemented", http.StatusNotImplemented)
|
||||
}))
|
||||
|
||||
http.HandleFunc("/api/jwks", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path)
|
||||
if "POST" != r.Method {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
tok := make(map[string]interface{})
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
err := decoder.Decode(&tok)
|
||||
if nil != err {
|
||||
http.Error(w, "Bad Request: invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// TODO better, JSON error messages
|
||||
if _, ok := tok["d"]; ok {
|
||||
http.Error(w, "Bad Request: private key", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
kty, _ := tok["kty"].(string)
|
||||
switch kty {
|
||||
case "EC":
|
||||
postEC(jwksPrefix, tok, w, r)
|
||||
case "RSA":
|
||||
postRSA(jwksPrefix, tok, w, r)
|
||||
default:
|
||||
http.Error(w, "Bad Request: only EC and RSA keys are supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
http.HandleFunc("/access_token", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s\n", r.Method, r.URL.Path)
|
||||
_, _, token := GenToken(getBaseURL(r), priv, r.URL.Query())
|
||||
fmt.Fprintf(w, token)
|
||||
})
|
||||
|
||||
http.HandleFunc("/authorization_header", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s\n", r.Method, r.URL.Path)
|
||||
|
||||
var header string
|
||||
headers, _ := r.URL.Query()["header"]
|
||||
if 0 == len(headers) {
|
||||
header = "Authorization"
|
||||
} else {
|
||||
header = headers[0]
|
||||
}
|
||||
|
||||
var prefix string
|
||||
prefixes, _ := r.URL.Query()["prefix"]
|
||||
if 0 == len(prefixes) {
|
||||
prefix = "Bearer "
|
||||
} else {
|
||||
prefix = prefixes[0]
|
||||
}
|
||||
|
||||
_, _, token := GenToken(getBaseURL(r), priv, r.URL.Query())
|
||||
fmt.Fprintf(w, "%s: %s%s", header, prefix, token)
|
||||
})
|
||||
|
||||
http.HandleFunc("/key.jwk.json", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s", r.Method, r.URL.Path)
|
||||
fmt.Fprintf(w, `{ "kty": "EC" , "crv": %q , "d": %q , "x": %q , "y": %q , "ext": true , "key_ops": ["sign"] }`, jwk.Crv, jwk.D, jwk.X, jwk.Y)
|
||||
})
|
||||
|
||||
http.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := getBaseURL(r)
|
||||
log.Printf("%s %s\n", r.Method, r.URL.Path)
|
||||
fmt.Fprintf(w, `{ "issuer": "%s", "jwks_uri": "%s/.well-known/jwks.json" }`, baseURL, baseURL)
|
||||
})
|
||||
|
||||
http.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%s %s %s", r.Method, r.Host, r.URL.Path)
|
||||
parts := strings.Split(r.Host, ".")
|
||||
kid := parts[0]
|
||||
|
||||
b, err := ioutil.ReadFile(filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"))
|
||||
if nil != err {
|
||||
//http.Error(w, "Not Found", http.StatusNotFound)
|
||||
jwkstr := fmt.Sprintf(
|
||||
`{ "keys": [ { "kty": "EC" , "crv": %q , "x": %q , "y": %q , "kid": %q , "ext": true , "key_ops": ["verify"] , "exp": %s } ] }`,
|
||||
jwk.Crv, jwk.X, jwk.Y, thumbprint, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10),
|
||||
)
|
||||
fmt.Println(jwkstr)
|
||||
fmt.Fprintf(w, jwkstr)
|
||||
return
|
||||
}
|
||||
|
||||
tok := &PublicJWK{}
|
||||
err = json.Unmarshal(b, tok)
|
||||
if nil != err {
|
||||
// TODO delete the bad file?
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
jwkstr := fmt.Sprintf(
|
||||
`{ "keys": [ { "kty": "EC", "crv": %q, "x": %q, "y": %q, "kid": %q,`+
|
||||
` "ext": true, "key_ops": ["verify"], "exp": %s } ] }`,
|
||||
tok.Crv, tok.X, tok.Y, tok.KeyID, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10),
|
||||
)
|
||||
fmt.Println(jwkstr)
|
||||
fmt.Fprintf(w, jwkstr)
|
||||
})
|
||||
}
|
||||
|
||||
func parseExp(exp string) (int, error) {
|
||||
if "" == exp {
|
||||
exp = "15m"
|
||||
}
|
||||
mult := 1
|
||||
switch exp[len(exp)-1] {
|
||||
case 'w':
|
||||
mult *= 7
|
||||
fallthrough
|
||||
case 'd':
|
||||
mult *= 24
|
||||
fallthrough
|
||||
case 'h':
|
||||
mult *= 60
|
||||
fallthrough
|
||||
case 'm':
|
||||
mult *= 60
|
||||
fallthrough
|
||||
case 's':
|
||||
// no fallthrough
|
||||
default:
|
||||
// could be 'k' or 'z', but we assume its empty
|
||||
exp += "s"
|
||||
}
|
||||
|
||||
num, err := strconv.Atoi(exp[:len(exp)-1])
|
||||
if nil != err {
|
||||
return 0, err
|
||||
}
|
||||
return num * mult, nil
|
||||
}
|
||||
|
||||
func postEC(jwksPrefix string, tok map[string]interface{}, w http.ResponseWriter, r *http.Request) {
|
||||
crv, ok := tok["crv"].(string)
|
||||
if 5 != len(crv) || "P-" != crv[:2] {
|
||||
http.Error(w, "Bad Request: bad curve", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
x, ok := tok["x"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Bad Request: missing 'x'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
y, ok := tok["y"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Bad Request: missing 'y'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
thumbprintable := []byte(
|
||||
fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, crv, x, y),
|
||||
)
|
||||
alg := crv[2:]
|
||||
|
||||
var thumb []byte
|
||||
switch alg {
|
||||
case "256":
|
||||
hash := sha256.Sum256(thumbprintable)
|
||||
thumb = hash[:]
|
||||
case "384":
|
||||
hash := sha512.Sum384(thumbprintable)
|
||||
thumb = hash[:]
|
||||
case "521":
|
||||
fallthrough
|
||||
case "512":
|
||||
hash := sha512.Sum512(thumbprintable)
|
||||
thumb = hash[:]
|
||||
default:
|
||||
http.Error(w, "Bad Request: bad key length or curve", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
kid := base64.RawURLEncoding.EncodeToString(thumb)
|
||||
if kid2, _ := tok["kid"].(string); "" != kid2 && kid != kid2 {
|
||||
http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
pub := []byte(fmt.Sprintf(
|
||||
`{"crv":%q,"kid":%q,"kty":"EC","x":%q,"y":%q}`, crv, kid, x, y,
|
||||
))
|
||||
|
||||
// TODO allow posting at the top-level?
|
||||
// TODO support a group of keys by PPID
|
||||
// (right now it's only by KID)
|
||||
if !strings.HasPrefix(r.Host, strings.ToLower(kid)+".") {
|
||||
http.Error(w, "Bad Request: prefix should be "+kid, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(
|
||||
filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"),
|
||||
pub,
|
||||
0644,
|
||||
); nil != err {
|
||||
fmt.Println("can't write file")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := getBaseURL(r)
|
||||
w.Write([]byte(fmt.Sprintf(
|
||||
`{ "iss":%q, "jwks_url":%q }`, baseURL+"/", baseURL+"/.well-known/jwks.json",
|
||||
)))
|
||||
}
|
||||
|
||||
func postRSA(jwksPrefix string, tok map[string]interface{}, w http.ResponseWriter, r *http.Request) {
|
||||
e, ok := tok["e"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Bad Request: missing 'e'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
n, ok := tok["n"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Bad Request: missing 'n'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
thumbprintable := []byte(
|
||||
fmt.Sprintf(`{"e":%q,"kty":"RSA","n":%q}`, e, n),
|
||||
)
|
||||
|
||||
var thumb []byte
|
||||
// TODO handle bit lengths well
|
||||
switch 3 * (len(n) / 4.0) {
|
||||
case 256:
|
||||
hash := sha256.Sum256(thumbprintable)
|
||||
thumb = hash[:]
|
||||
case 384:
|
||||
hash := sha512.Sum384(thumbprintable)
|
||||
thumb = hash[:]
|
||||
case 512:
|
||||
hash := sha512.Sum512(thumbprintable)
|
||||
thumb = hash[:]
|
||||
default:
|
||||
http.Error(w, "Bad Request: only standard RSA key lengths (2048, 3072, 4096) are supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
kid := base64.RawURLEncoding.EncodeToString(thumb)
|
||||
if kid2, _ := tok["kid"].(string); "" != kid2 && kid != kid2 {
|
||||
http.Error(w, "Bad Request: kid should be "+kid, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
pub := []byte(fmt.Sprintf(
|
||||
`{"e":%q,"kid":%q,"kty":"EC","n":%q}`, e, kid, n,
|
||||
))
|
||||
|
||||
// TODO allow posting at the top-level?
|
||||
// TODO support a group of keys by PPID
|
||||
// (right now it's only by KID)
|
||||
if !strings.HasPrefix(r.Host, strings.ToLower(kid)+".") {
|
||||
http.Error(w, "Bad Request: prefix should be "+kid, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(
|
||||
filepath.Join(jwksPrefix, strings.ToLower(kid)+".jwk.json"),
|
||||
pub,
|
||||
0644,
|
||||
); nil != err {
|
||||
fmt.Println("can't write file")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := getBaseURL(r)
|
||||
w.Write([]byte(fmt.Sprintf(
|
||||
`{ "iss":%q, "jwks_url":%q }`, baseURL+"/", baseURL+"/.well-known/jwks.json",
|
||||
)))
|
||||
}
|
||||
|
||||
func GenToken(host string, priv *ecdsa.PrivateKey, query url.Values) (string, string, string) {
|
||||
thumbprint := thumbprintKey(&priv.PublicKey)
|
||||
protected := fmt.Sprintf(`{"typ":"JWT","alg":"ES256","kid":"%s"}`, thumbprint)
|
||||
protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected))
|
||||
|
||||
exp, err := parseExp(query.Get("exp"))
|
||||
if nil != err {
|
||||
// cryptic error code
|
||||
// TODO propagate error
|
||||
exp = 422
|
||||
}
|
||||
|
||||
payload := fmt.Sprintf(
|
||||
`{"iss":"%s/","sub":"dummy","exp":%s}`,
|
||||
host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10),
|
||||
)
|
||||
payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||
|
||||
hash := sha256.Sum256([]byte(fmt.Sprintf(`%s.%s`, protected64, payload64)))
|
||||
r, s, _ := ecdsa.Sign(rand.Reader, priv, hash[:])
|
||||
rb := r.Bytes()
|
||||
for len(rb) < 32 {
|
||||
rb = append([]byte{0}, rb...)
|
||||
}
|
||||
sb := s.Bytes()
|
||||
for len(rb) < 32 {
|
||||
sb = append([]byte{0}, sb...)
|
||||
}
|
||||
sig64 := base64.RawURLEncoding.EncodeToString(append(rb, sb...))
|
||||
token := fmt.Sprintf(`%s.%s.%s`, protected64, payload64, sig64)
|
||||
return protected, payload, token
|
||||
}
|
||||
|
||||
func ParseKey(jwk *PrivateJWK) *ecdsa.PrivateKey {
|
||||
xb, _ := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
xi := &big.Int{}
|
||||
xi.SetBytes(xb)
|
||||
yb, _ := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||
yi := &big.Int{}
|
||||
yi.SetBytes(yb)
|
||||
pub := &ecdsa.PublicKey{
|
||||
Curve: elliptic.P256(),
|
||||
X: xi,
|
||||
Y: yi,
|
||||
}
|
||||
|
||||
db, _ := base64.RawURLEncoding.DecodeString(jwk.D)
|
||||
di := &big.Int{}
|
||||
di.SetBytes(db)
|
||||
priv := &ecdsa.PrivateKey{
|
||||
PublicKey: *pub,
|
||||
D: di,
|
||||
}
|
||||
return priv
|
||||
}
|
||||
|
||||
func thumbprintKey(pub *ecdsa.PublicKey) string {
|
||||
minpub := []byte(fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, "P-256", pub.X, pub.Y))
|
||||
sha := sha256.Sum256(minpub)
|
||||
return base64.RawURLEncoding.EncodeToString(sha[:])
|
||||
}
|
||||
|
||||
func issueNonce(w http.ResponseWriter, r *http.Request) {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b)
|
||||
nonce := base64.RawURLEncoding.EncodeToString(b)
|
||||
nonces[nonce] = time.Now().Unix()
|
||||
|
||||
w.Header().Set("Replay-Nonce", nonce)
|
||||
}
|
||||
|
||||
func requireNonce(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
nonce := r.Header.Get("Replay-Nonce")
|
||||
// TODO expire nonces every so often
|
||||
t := nonces[nonce]
|
||||
if 0 == t {
|
||||
http.Error(
|
||||
w,
|
||||
`{ "error": "invalid or expired nonce", "error_code": "ENONCE" }`,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
delete(nonces, nonce)
|
||||
issueNonce(w, r)
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func getBaseURL(r *http.Request) string {
|
||||
var scheme string
|
||||
if nil != r.TLS || "https" == r.Header.Get("X-Forwarded-Proto") {
|
||||
scheme = "https:"
|
||||
} else {
|
||||
scheme = "http:"
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"%s//%s",
|
||||
scheme,
|
||||
r.Host,
|
||||
)
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package mockid
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
//keypairs "github.com/big-squid/go-keypairs"
|
||||
//"github.com/big-squid/go-keypairs/keyfetch/uncached"
|
||||
)
|
||||
|
||||
func TestTest(t *testing.T) {
|
||||
t.Fatal("no test")
|
||||
}
|
Loading…
Reference in New Issue