go-mockid/mockid.go

188 lines
5.1 KiB
Go
Raw Normal View History

2019-08-01 06:21:32 +00:00
package main
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"flag"
"fmt"
"log"
"math/big"
"net/http"
"os"
"strconv"
"time"
)
// PrivateJWK is an elliptic-curve private key in JSON Web Key form: the
// public members (embedded, so encoding/json flattens them into the same
// object) plus the private scalar "d".
type PrivateJWK struct {
	PublicJWK
	// D is the private scalar; parseKey decodes it with
	// base64.RawURLEncoding (unpadded base64url).
	D string `json:"d"`
}
// PublicJWK holds the public members of an elliptic-curve JSON Web Key.
// Note that no "kty" member is part of this struct, so it is not emitted
// when the value is marshalled.
type PublicJWK struct {
	// Crv is the curve name, e.g. "P-256".
	Crv string `json:"crv"`
	// X and Y are the point coordinates; parseKey decodes them with
	// base64.RawURLEncoding (unpadded base64url).
	X string `json:"x"`
	Y string `json:"y"`
}
// main runs a mock identity provider that signs ES256 JWTs with a fixed,
// publicly known P-256 key. It serves:
//
//	/access_token                      a freshly signed token; its issuer is derived from the request
//	/key.jwk.json                      the PRIVATE signing key (deliberate: this is a mock server)
//	/.well-known/openid-configuration  a minimal discovery document (issuer, jwks_uri)
//	/.well-known/jwks.json             the public verification key, with kid set to its thumbprint
//	/                                  static files from the ./public directory
//
// The listen port comes from --port, or from the PORT environment variable
// when --port is unset. --url sets the issuer of the sample token printed at
// startup (default http://localhost:<port>).
//
// NOTE(review): genToken appends "/" to the issuer it is given, so a token's
// "iss" ends in "/" while the discovery "issuer" below does not. Strict OIDC
// clients compare these exactly; confirm which form consumers expect.
func main() {
	// Fixed demo key. It is intentionally published via /key.jwk.json, so
	// tokens signed with it must never be trusted outside of testing.
	jwk := &PrivateJWK{
		PublicJWK: PublicJWK{
			Crv: "P-256",
			X:   "ToL2HppsTESXQKvp7ED6NMgV4YnwbMeONexNry3KDNQ",
			Y:   "Tt6Q3rxU37KAinUV9PLMlwosNy1t3Bf2VDg5q955AGc",
		},
		D: "GYAwlBHc2mPsj1lp315HbYOmKNJ7esmO3JAkZVn9nJs",
	}
	priv := parseKey(jwk)
	thumbprint := thumbprintKey(&priv.PublicKey)

	portFlag := flag.Int("port", 0, "Port on which the HTTP server should run")
	urlFlag := flag.String("url", "", "Outward-facing address, such as https://example.com")
	flag.Parse()

	port := *portFlag
	if port <= 0 {
		// An unset or unparsable PORT leaves port at 0, which is rejected below.
		port, _ = strconv.Atoi(os.Getenv("PORT"))
	}
	if port < 1 {
		fmt.Fprintf(os.Stderr, "You must specify --port or PORT\n")
		os.Exit(1)
	}

	host := *urlFlag
	if host == "" {
		host = "http://localhost:" + strconv.Itoa(port)
	}

	// origin rebuilds the externally visible scheme://host of a request. It
	// honours X-Forwarded-Proto so the server reports https behind a TLS proxy.
	origin := func(r *http.Request) string {
		if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
			return "https://" + r.Host
		}
		return "http://" + r.Host
	}

	http.HandleFunc("/access_token", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		_, _, token := genToken(origin(r), priv)
		// Fprint, not Fprintf: the token is data, not a format string.
		fmt.Fprint(w, token)
	})

	http.HandleFunc("/key.jwk.json", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{ "kty": "EC" , "crv": %q , "d": %q , "x": %q , "y": %q , "ext": true , "key_ops": ["sign"] }`, jwk.Crv, jwk.D, jwk.X, jwk.Y)
	})

	http.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		issuer := origin(r)
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{ "issuer": "%s", "jwks_uri": "%s/.well-known/jwks.json" }`, issuer, issuer)
	})

	http.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s %s", r.Method, r.URL.Path)
		jwkstr := fmt.Sprintf(
			`{ "keys": [ { "kty": "EC" , "crv": %q , "x": %q , "y": %q , "kid": %q , "ext": true , "key_ops": ["verify"] , "exp": %s } ] }`,
			jwk.Crv, jwk.X, jwk.Y, thumbprint, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10),
		)
		fmt.Println(jwkstr)
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, jwkstr)
	})

	http.Handle("/", http.FileServer(http.Dir("public")))

	fmt.Printf("Serving on port %d\n", port)

	// Marshalling structs that contain only string fields cannot fail.
	b, _ := json.Marshal(jwk)
	fmt.Printf("Private Key:\n\t%s\n", string(b))
	b, _ = json.Marshal(jwk.PublicJWK)
	fmt.Printf("Public Key:\n\t%s\n", string(b))

	protected, payload, token := genToken(host, priv)
	fmt.Printf("Protected (Header):\n\t%s\n", protected)
	fmt.Printf("Payload (Claims):\n\t%s\n", payload)
	fmt.Printf("Access Token:\n\t%s\n", token)

	// A nil Handler means http.DefaultServeMux, where the routes above live.
	// ReadHeaderTimeout stops slow clients from pinning connections forever.
	srv := &http.Server{
		Addr:              ":" + strconv.Itoa(port),
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Fatal(srv.ListenAndServe())
}
// genToken builds and signs a compact-serialized ES256 JWT.
//
// The header is {"typ":"JWT","alg":"ES256","kid":<thumbprint>}, where the
// thumbprint comes from thumbprintKey. The claims carry iss = host + "/",
// the fixed subject "dummy", and an expiry 15 minutes from now.
//
// It returns the header JSON, the claims JSON, and the complete token
// base64url(header) "." base64url(claims) "." base64url(signature).
// It panics if signing fails. There is no error return to report the
// failure, and a token without a valid signature is useless.
func genToken(host string, priv *ecdsa.PrivateKey) (string, string, string) {
	thumbprint := thumbprintKey(&priv.PublicKey)
	protected := fmt.Sprintf(`{"typ":"JWT","alg":"ES256","kid":"%s"}`, thumbprint)
	protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected))
	payload := fmt.Sprintf(
		`{"iss":"%s/","sub":"dummy","exp":%s}`,
		host, strconv.FormatInt(time.Now().Add(15*time.Minute).Unix(), 10),
	)
	payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload))

	signingInput := protected64 + "." + payload64
	hash := sha256.Sum256([]byte(signingInput))
	r, s, err := ecdsa.Sign(rand.Reader, priv, hash[:])
	if err != nil {
		panic(fmt.Sprintf("genToken: signing JWT: %v", err))
	}

	// ES256 requires R and S as exactly 32 big-endian bytes each (64 total).
	// big.Int.Bytes drops leading zero bytes, so right-align each value in
	// its half. Previously only R was padded, so any S with a leading zero
	// byte produced an invalid 63-byte signature.
	sig := make([]byte, 64)
	rb, sb := r.Bytes(), s.Bytes()
	copy(sig[32-len(rb):32], rb)
	copy(sig[64-len(sb):], sb)

	token := signingInput + "." + base64.RawURLEncoding.EncodeToString(sig)
	return protected, payload, token
}
// parseKey converts a private JWK into an *ecdsa.PrivateKey on P-256.
//
// X, Y and D are decoded as unpadded base64url big-endian integers.
// jwk.Crv is not consulted; the curve is always elliptic.P256().
// parseKey panics if any member is not valid base64url. Ignoring the error
// would yield a silently truncated or empty value and a key that signs
// garbage.
func parseKey(jwk *PrivateJWK) *ecdsa.PrivateKey {
	decode := func(name, value string) *big.Int {
		b, err := base64.RawURLEncoding.DecodeString(value)
		if err != nil {
			panic(fmt.Sprintf("parseKey: decoding JWK member %q: %v", name, err))
		}
		return new(big.Int).SetBytes(b)
	}
	return &ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{
			Curve: elliptic.P256(),
			X:     decode("x", jwk.X),
			Y:     decode("y", jwk.Y),
		},
		D: decode("d", jwk.D),
	}
}
func thumbprintKey(pub *ecdsa.PublicKey) string {
minpub := []byte(fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, "P-256", pub.X, pub.Y))
sha := sha256.Sum256(minpub)
return base64.RawURLEncoding.EncodeToString(sha[:])
}