package xkeypairs

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	mathrand "math/rand"
	"time"

	"git.rootprojects.org/root/keypairs"
)

var (
	// RandomReader is the entropy source SignClaims hands to Sign when the
	// header carries no "_seed". It defaults to crypto/rand.Reader and may be
	// overwritten for testing.
	RandomReader io.Reader = rand.Reader
)

// JWS holds both the decoded and the encoded parts of a signed token.
//
// Header and Claims are the JSON objects supplied to SignClaims (which may
// add fields to them while signing). Protected, Payload and Signature are the
// unpadded base64url segments that JWSToJWT joins into the compact form.
type JWS struct {
	Header    Object `json:"header"`    // JSON
	Claims    Object `json:"claims"`    // JSON
	Protected string `json:"protected"` // base64
	Payload   string `json:"payload"`   // base64
	Signature string `json:"signature"` // base64
}

// Object is a generic JSON object, used for both JWS headers and claims.
type Object = map[string]interface{}

// SignClaims signs claims with privkey and returns the resulting JWS.
//
// The header gains `typ`, `alg` and either `kid` (the key's thumbprint) or,
// when header["_jwk"] is true, an embedded `jwk`. The claims must carry a
// `jti` or an `exp` unless `insecure` is true. Both maps may be modified in
// place. Nil maps are replaced with empty ones so the returned JWS shows the
// fields that were added.
//
// A non-zero int64 header["_seed"] replaces RandomReader with a seeded
// math/rand source. That source is NOT cryptographically secure and is meant
// for tests only. "_seed" is left in the header and is therefore serialized
// into the protected segment.
//
// An error is returned if the header or claims are invalid, or if no
// signature could be produced (e.g. an unsupported key type).
func SignClaims(privkey keypairs.PrivateKey, header Object, claims Object) (*JWS, error) {
	var randsrc io.Reader = RandomReader
	if seed, _ := header["_seed"].(int64); 0 != seed {
		randsrc = mathrand.New(mathrand.NewSource(seed))
		//delete(header, "_seed")
	}

	// Allocate here rather than only inside the helpers so the maps returned
	// in the JWS are the same ones the helpers fill in.
	if nil == header {
		header = Object{}
	}
	if nil == claims {
		claims = Object{}
	}

	protected, err := headerToProtected(keypairs.NewPublicKey(privkey.Public()), header)
	if nil != err {
		return nil, err
	}
	protected64 := base64.RawURLEncoding.EncodeToString(protected)

	payload, err := claimsToPayload(claims)
	if nil != err {
		return nil, err
	}
	payload64 := base64.RawURLEncoding.EncodeToString(payload)

	signable := fmt.Sprintf(`%s.%s`, protected64, payload64)
	hash := sha256.Sum256([]byte(signable))

	// Sign reports failure (unsupported key type or signer error) as an empty
	// signature. Never hand back a token that cannot verify.
	sig := Sign(randsrc, privkey, hash[:])
	if 0 == len(sig) {
		return nil, errors.New("could not sign token: unsupported key type or signer error")
	}
	sig64 := base64.RawURLEncoding.EncodeToString(sig)

	return &JWS{
		Header:    header,
		Claims:    claims,
		Protected: protected64,
		Payload:   payload64,
		Signature: sig64,
	}, nil
}

func headerToProtected(pub keypairs.PublicKey, header Object) ([]byte, error) {
	if nil == header {
		header = Object{}
	}

	// Only supporting 2048-bit and P256 keys right now
	// because that's all that's practical and well-supported.
	// No security theatre here.
	alg := "ES256"
	switch pub.Key().(type) {
	case *rsa.PublicKey:
		alg = "RS256"
	}

	if selfSign, _ := header["_jwk"].(bool); selfSign {
		delete(header, "_jwk")
		any := Object{}
		_ = json.Unmarshal(keypairs.MarshalJWKPublicKey(pub), &any)
		header["jwk"] = any
	}

	// TODO what are the acceptable values? JWT. JWS? others?
	header["typ"] = "JWT"
	if _, ok := header["jwk"]; !ok {
		thumbprint := keypairs.ThumbprintPublicKey(pub)
		kid, _ := header["kid"].(string)
		if "" != kid && thumbprint != kid {
			return nil, errors.New("'kid' should be the key's thumbprint")
		}
		header["kid"] = thumbprint
	}
	header["alg"] = alg

	protected, err := json.Marshal(header)
	if nil != err {
		return nil, err
	}
	return protected, nil
}

func claimsToPayload(claims Object) ([]byte, error) {
	if nil == claims {
		claims = Object{}
	}

	jti, _ := claims["jti"].(string)
	exp, _ := claims["exp"].(int64)
	dur, _ := claims["exp"].(string)
	insecure, _ := claims["insecure"].(bool)

	// parse if exp is actually a duration, such as "15m"
	if 0 == exp && "" != dur {
		s, err := time.ParseDuration(dur)
		// TODO s, err := time.ParseDuration(dur)
		if nil != err {
			return nil, err
		}
		exp = time.Now().Add(s * time.Second).Unix()
		claims["exp"] = exp
	}
	if "" == jti && 0 == exp && !insecure {
		return nil, errors.New("token must have jti or exp as to be expirable / cancellable")
	}

	return json.Marshal(claims)
}

// JWSToJWT renders jwt in compact serialization: the protected header,
// payload and signature segments joined by dots.
func JWSToJWT(jwt *JWS) string {
	// All operands are strings, so fmt.Sprint inserts no separators of its own.
	return fmt.Sprint(jwt.Protected, ".", jwt.Payload, ".", jwt.Signature)
}

// Sign signs a 32-byte (SHA-256) digest with privkey, drawing randomness from
// rand.
//
// RSA keys produce a PKCS #1 v1.5 signature (RS256). ECDSA keys produce the
// JWS form R||S, with each integer left-padded with zeros to the curve's byte
// length (ES256). Sign returns nil if the key type is unsupported or the
// underlying signer fails, and it panics if hash is not 32 bytes long.
func Sign(rand io.Reader, privkey keypairs.PrivateKey, hash []byte) []byte {
	if len(hash) != 32 {
		panic("only 256-bit hashes for 2048-bit and 256-bit keys are supported")
	}

	switch k := privkey.(type) {
	case *rsa.PrivateKey:
		sig, err := rsa.SignPKCS1v15(rand, k, crypto.SHA256, hash)
		if nil != err {
			return nil
		}
		return sig
	case *ecdsa.PrivateKey:
		r, s, err := ecdsa.Sign(rand, k, hash)
		if nil != err {
			// r and s are nil on error; do not dereference them.
			return nil
		}
		// R and S are each fixed-width big-endian values. big.Int.Bytes drops
		// leading zeros, so right-align both in a zeroed buffer. r and s are
		// less than the curve order, so they always fit in size bytes.
		size := (k.Curve.Params().BitSize + 7) / 8
		sig := make([]byte, 2*size)
		rb, sb := r.Bytes(), s.Bytes()
		copy(sig[size-len(rb):size], rb)
		copy(sig[2*size-len(sb):], sb)
		return sig
	}
	return nil
}