205 lines
4.9 KiB
Go
205 lines
4.9 KiB
Go
package mockid
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.coolaj86.com/coolaj86/go-mockid/xkeypairs"
|
|
"git.rootprojects.org/root/keypairs"
|
|
//jwt "github.com/dgrijalva/jwt-go"
|
|
)
|
|
|
|
// TestMain will overwrite this
|
|
var rndsrc io.Reader = rand.Reader
|
|
|
|
// PublicJWK is the JSON shape of a public key expressed with the JWK
// members "crv", "x" and "y", plus the optional "kid" and "kty" members.
// The declaration order of the fields fixes the key order produced by
// encoding/json.
type PublicJWK struct {
	// Crv is the JWK "crv" member.
	Crv string `json:"crv"`

	// KeyID is the JWK "kid" member; it is left out of the JSON when empty.
	KeyID string `json:"kid,omitempty"`

	// Kty is the JWK "kty" member; it is left out of the JSON when empty.
	Kty string `json:"kty,omitempty"`

	// X is the JWK "x" member.
	X string `json:"x"`

	// Y is the JWK "y" member.
	Y string `json:"y"`
}
|
|
|
|
type InspectableToken struct {
|
|
Public keypairs.PublicKey `json:"jwk"`
|
|
Protected map[string]interface{} `json:"protected"`
|
|
Payload map[string]interface{} `json:"payload"`
|
|
Signature string `json:"signature"`
|
|
Verified bool `json:"verified"`
|
|
Errors []string `json:"errors"`
|
|
}
|
|
|
|
func (t *InspectableToken) MarshalJSON() ([]byte, error) {
|
|
pub := keypairs.MarshalJWKPublicKey(t.Public)
|
|
header, _ := json.Marshal(t.Protected)
|
|
payload, _ := json.Marshal(t.Payload)
|
|
errs, _ := json.Marshal(t.Errors)
|
|
return []byte(fmt.Sprintf(
|
|
`{"jwk":%s,"protected":%s,"payload":%s,"signature":%q,"verified":%t,"errors":%s}`,
|
|
pub, header, payload, t.Signature, t.Verified, errs,
|
|
)), nil
|
|
}
|
|
|
|
// Package-level state. Init fills defaultFrom, defaultReplyTo and salt from
// the environment; the nonce handlers read and write nonces.
var (
	// defaultFrom holds the MAILER_FROM environment value.
	defaultFrom string

	// defaultReplyTo holds the MAILER_REPLY_TO environment value.
	defaultReplyTo string

	// nonces maps each issued nonce string to the time.Time it was issued.
	// sync.Map is safe for concurrent use by multiple goroutines.
	nonces sync.Map

	// salt holds the decoded SALT environment value.
	salt []byte
)
|
|
|
|
func Init() {
|
|
var err error
|
|
salt64 := os.Getenv("SALT")
|
|
salt, err = base64.RawURLEncoding.DecodeString(salt64)
|
|
if len(salt64) < 22 || nil != err {
|
|
panic("SALT must be set as 22+ character base64")
|
|
}
|
|
defaultFrom = os.Getenv("MAILER_FROM")
|
|
defaultReplyTo = os.Getenv("MAILER_REPLY_TO")
|
|
//nonces = make(map[string]int64)
|
|
//nonCh = make(chan string)
|
|
|
|
/*
|
|
go func() {
|
|
for {
|
|
nonce := <- nonCh
|
|
nonces[nonce] = time.Now().Unix()
|
|
}
|
|
}()
|
|
*/
|
|
}
|
|
|
|
func GenToken(host string, privkey keypairs.PrivateKey, query url.Values) (string, string, string) {
|
|
thumbprint := keypairs.ThumbprintPublicKey(keypairs.NewPublicKey(privkey.Public()))
|
|
// TODO keypairs.Alg(key)
|
|
alg := "ES256"
|
|
switch privkey.(type) {
|
|
case *rsa.PrivateKey:
|
|
alg = "RS256"
|
|
}
|
|
protected := fmt.Sprintf(`{"typ":"JWT","alg":%q,"kid":"%s"}`, alg, thumbprint)
|
|
protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected))
|
|
|
|
exp, err := xkeypairs.ParseDuration(query.Get("exp"))
|
|
if nil != err {
|
|
// cryptic error code
|
|
// TODO propagate error
|
|
exp = 422
|
|
}
|
|
|
|
payload := fmt.Sprintf(
|
|
`{"iss":"%s/","sub":"dummy","exp":%s}`,
|
|
host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10),
|
|
)
|
|
payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload))
|
|
|
|
hash := sha256.Sum256([]byte(fmt.Sprintf(`%s.%s`, protected64, payload64)))
|
|
sig := JOSESign(privkey, hash[:])
|
|
sig64 := base64.RawURLEncoding.EncodeToString(sig)
|
|
token := fmt.Sprintf("%s.%s.%s\n", protected64, payload64, sig64)
|
|
return protected, payload, token
|
|
}
|
|
|
|
// TODO: move to keypairs
|
|
|
|
func JOSEVerify(pubkey keypairs.PublicKey, hash []byte, sig []byte) bool {
|
|
var verified bool
|
|
|
|
switch pub := pubkey.Key().(type) {
|
|
case *rsa.PublicKey:
|
|
// TODO keypairs.Size(key) to detect key size ?
|
|
//alg := "SHA256"
|
|
// TODO: this hasn't been tested yet
|
|
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, hash, sig); nil != err {
|
|
verified = true
|
|
}
|
|
case *ecdsa.PublicKey:
|
|
r := &big.Int{}
|
|
r.SetBytes(sig[0:32])
|
|
s := &big.Int{}
|
|
s.SetBytes(sig[32:])
|
|
fmt.Println("debug: sig len:", len(sig))
|
|
fmt.Println("debug: r, s:", r, s)
|
|
verified = ecdsa.Verify(pub, hash, r, s)
|
|
default:
|
|
panic("impossible condition: non-rsa/non-ecdsa key")
|
|
}
|
|
|
|
return verified
|
|
}
|
|
|
|
func JOSESign(privkey keypairs.PrivateKey, hash []byte) []byte {
|
|
var sig []byte
|
|
|
|
switch k := privkey.(type) {
|
|
case *rsa.PrivateKey:
|
|
panic("TODO: implement rsa sign")
|
|
case *ecdsa.PrivateKey:
|
|
r, s, _ := ecdsa.Sign(rndsrc, k, hash[:])
|
|
rb := r.Bytes()
|
|
fmt.Println("debug:")
|
|
fmt.Println(r, s)
|
|
for len(rb) < 32 {
|
|
rb = append([]byte{0}, rb...)
|
|
}
|
|
sb := s.Bytes()
|
|
for len(rb) < 32 {
|
|
sb = append([]byte{0}, sb...)
|
|
}
|
|
sig = append(rb, sb...)
|
|
}
|
|
return sig
|
|
}
|
|
|
|
func issueNonce(w http.ResponseWriter, r *http.Request) {
|
|
b := make([]byte, 16)
|
|
_, _ = rand.Read(b)
|
|
nonce := base64.RawURLEncoding.EncodeToString(b)
|
|
//nonCh <- nonce
|
|
nonces.Store(nonce, time.Now())
|
|
|
|
w.Header().Set("Replay-Nonce", nonce)
|
|
}
|
|
|
|
func requireNonce(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
nonce := r.Header.Get("Replay-Nonce")
|
|
// TODO expire nonces every so often
|
|
//t := nonces[nonce]
|
|
var t time.Time
|
|
tmp, ok := nonces.Load(nonce)
|
|
if ok {
|
|
t = tmp.(time.Time)
|
|
}
|
|
if !ok || time.Now().Sub(t) > 15*time.Minute {
|
|
http.Error(
|
|
w,
|
|
`{ "error": "invalid or expired nonce", "error_code": "ENONCE" }`,
|
|
http.StatusBadRequest,
|
|
)
|
|
return
|
|
}
|
|
|
|
//delete(nonces, nonce)
|
|
nonces.Delete(nonce)
|
|
issueNonce(w, r)
|
|
|
|
next(w, r)
|
|
}
|
|
}
|