package xkeypairs

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rsa"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log"
	"math/big"
	mathrand "math/rand"
	"time"

	"git.rootprojects.org/root/keypairs"
)

// VerifyClaims verifies the signature and expiry of jws.
//
// The key used to check the signature is chosen in this order:
//
//   - pubkey, when it is non-nil. A caller-supplied key is always preferred,
//     so a key embedded in the (untrusted) token header can never replace it.
//   - the "jwk" object in the header, when present.
//   - the debug API: a key regenerated from the "_seed" and "_kty" header
//     fields.
//
// The header must contain either "kid" or "jwk". The claims must contain an
// "exp" that has not yet passed or, when "exp" is absent or zero, a "jti".
//
// A non-nil error means the token was rejected before its signature was
// checked; otherwise the boolean reports whether the signature is valid.
func VerifyClaims(pubkey keypairs.PublicKey, jws *JWS) (bool, error) {
	kid, _ := jws.Header["kid"].(string)
	jwkmap, hasJWK := jws.Header["jwk"].(Object)
	if !hasJWK && "" == kid {
		return false, errors.New("token should have 'kid' or 'jwk' in header")
	}

	var pub keypairs.PublicKey
	switch {
	case nil != pubkey:
		// Never let a header-supplied key override the caller's trusted key.
		pub = pubkey
	case hasJWK:
		log.Println("Security TODO: did not check jws.Claims[\"sub\"] against 'jwk' thumbprint")
		log.Println("Security TODO: did not check jws.Claims[\"iss\"]")
		k, err := publicKeyFromHeaderJWK(jwkmap)
		if nil != err {
			return false, err
		}
		pub = k
	default:
		k, err := publicKeyFromDebugSeed(jws)
		if nil != err {
			return false, err
		}
		pub = k
	}
	if !hasJWK {
		log.Println("Security TODO: did not check jws.Claims[\"kid\"] against thumbprint")
	}

	// A token without an expiry must at least be identifiable by "jti".
	exp := numericToInt64(jws.Claims["exp"])
	if 0 == exp {
		if jti, _ := jws.Claims["jti"].(string); "" == jti {
			return false, errors.New("one of 'jti' or 'exp' must exist for token expiry")
		}
	} else if time.Now().Unix() > exp {
		return false, fmt.Errorf("token expired at %d (%s)", exp, time.Unix(exp, 0))
	}

	signable := fmt.Sprintf("%s.%s", jws.Protected, jws.Payload)
	hash := sha256.Sum256([]byte(signable))
	sig, err := base64.RawURLEncoding.DecodeString(jws.Signature)
	if nil != err {
		return false, err
	}

	return Verify(pub, hash[:], sig), nil
}

// publicKeyFromHeaderJWK rebuilds the public key carried in a JWS header's
// "jwk" member. A "kty" of "RSA" selects the RSA members ("e", "n"); any other
// value is treated as an EC key ("crv", "x", "y"). Errors from re-encoding or
// parsing the key are returned unchanged.
func publicKeyFromHeaderJWK(jwkmap Object) (keypairs.PublicKey, error) {
	if kty, _ := jwkmap["kty"].(string); "RSA" == kty {
		e, _ := jwkmap["e"].(string)
		n, _ := jwkmap["n"].(string)
		k, err := (&RSAJWK{
			Exp: e,
			N:   n,
		}).marshalJWK()
		if nil != err {
			return nil, err
		}
		return keypairs.ParseJWKPublicKey(k)
	}

	crv, _ := jwkmap["crv"].(string)
	x, _ := jwkmap["x"].(string)
	y, _ := jwkmap["y"].(string)
	k, err := (&ECJWK{
		Curve: crv,
		X:     x,
		Y:     y,
	}).marshalJWK()
	if nil != err {
		return nil, err
	}
	return keypairs.ParseJWKPublicKey(k)
}

// publicKeyFromDebugSeed implements the debug API: it regenerates the key pair
// identified by the "_seed" and "_kty" header fields (via genPrivKey) and
// returns its public half. Both fields are required.
func publicKeyFromDebugSeed(jws *JWS) (keypairs.PublicKey, error) {
	seed := numericToInt64(jws.Header["_seed"])
	if 0 == seed {
		return nil, errors.New("the debug API requires '_seed' to accompany 'kid'")
	}
	kty, _ := jws.Header["_kty"].(string)
	if "" == kty {
		return nil, errors.New("the debug API requires '_kty' to accompany '_seed'")
	}
	privkey := genPrivKey(seed, kty)
	return keypairs.NewPublicKey(privkey.Public()), nil
}

// numericToInt64 converts a numeric header/claim value to int64.
// encoding/json decodes JSON numbers into interface{} as float64, while values
// set in-process may be int64 or int; all three are accepted. Any other value
// (including a missing one) yields 0.
func numericToInt64(v interface{}) int64 {
	switch n := v.(type) {
	case int64:
		return n
	case int:
		return int64(n)
	case float64:
		return int64(n)
	default:
		return 0
	}
}

// Verify reports whether sig is a valid signature of hash under pubkey.
//
// For RSA keys, sig is checked as a PKCS #1 v1.5 signature with hash treated
// as a SHA-256 digest.
//
// For ECDSA keys, sig must use the fixed-width JWS encoding (RFC 7518 §3.4):
// R followed by S, each big-endian and exactly as wide as the curve's byte
// size (64 bytes in total for P-256). A signature of any other length is
// rejected instead of being mis-split or causing an out-of-range panic.
//
// It panics if pubkey wraps a key that is neither RSA nor ECDSA.
func Verify(pubkey keypairs.PublicKey, hash []byte, sig []byte) bool {
	switch pub := pubkey.Key().(type) {
	case *rsa.PublicKey:
		// TODO keypairs.Size(key) to detect key size ?
		// TODO: this hasn't been tested yet
		return nil == rsa.VerifyPKCS1v15(pub, crypto.SHA256, hash, sig)
	case *ecdsa.PublicKey:
		// sig is attacker-controlled: enforce the exact R||S width for this
		// curve before slicing.
		size := (pub.Curve.Params().BitSize + 7) / 8
		if len(sig) != 2*size {
			return false
		}
		r := new(big.Int).SetBytes(sig[:size])
		s := new(big.Int).SetBytes(sig[size:])
		return ecdsa.Verify(pub, hash, r, s)
	default:
		panic("impossible condition: non-rsa/non-ecdsa key")
	}
}

// maxRetry bounds how many additional RSA keys genPrivKey generates from the
// same seed while looking for one that differs from the first.
const maxRetry = 16

// genPrivKey generates a private key for the debug API.
//
// kty "RSA" yields a 2048-bit RSA key; any other kty yields a P-256 ECDSA key.
// A zero seed draws from RandomReader; a non-zero seed draws from a math/rand
// stream seeded with it (see nextReader).
//
// For seeded RSA keys there are two possible keys for the same seed (see the
// go-jose issue linked below). The key is regenerated up to maxRetry times
// until a differing key appears, and the one with the lesser D is kept. If
// every retry yields the same key, an error is logged and that key is
// returned. Retries that fail to generate are skipped.
//
// It panics if the initial key cannot be generated. The function has no error
// return, and a nil key would only crash later, further from the cause.
func genPrivKey(seed int64, kty string) keypairs.PrivateKey {
	if "RSA" != kty {
		// TODO: EC keys may also suffer the same random problems in the future
		privkey, err := ecdsa.GenerateKey(elliptic.P256(), nextReader(seed))
		if nil != err {
			panic(fmt.Errorf("genPrivKey: generating EC key: %w", err))
		}
		return privkey
	}

	const keylen = 2048
	privkey, err := rsa.GenerateKey(nextReader(seed), keylen)
	if nil != err {
		panic(fmt.Errorf("genPrivKey: generating RSA key: %w", err))
	}
	if 0 == seed {
		// A truly random key has no sibling to reconcile with.
		return privkey
	}

	for i := 0; i < maxRetry; i++ {
		otherkey, err := rsa.GenerateKey(nextReader(seed), keylen)
		if nil != err {
			continue
		}
		if otherCmp := otherkey.D.Cmp(privkey.D); 0 != otherCmp {
			// There are two possible keys, choose the lesser D value
			// See https://github.com/square/go-jose/issues/189
			if otherCmp < 0 {
				privkey = otherkey
			}
			return privkey
		}
	}

	log.Printf("error: coinflip landed on heads %d times", maxRetry)
	// TODO return random / retry error
	return privkey
}

// nextReader returns the byte source used for key generation.
//
// A zero seed selects RandomReader. Any non-zero seed yields a freshly seeded
// math/rand generator, so every call with the same seed replays the same byte
// stream. These shenanigans exist only for testing and the debug API:
// math/rand is not a cryptographically secure source.
func nextReader(seed int64) io.Reader {
	if seed != 0 {
		src := mathrand.NewSource(seed)
		return mathrand.New(src)
	}
	return RandomReader
}