package api

import (
	"encoding/base64"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"strings"
	"time"

	"git.coolaj86.com/coolaj86/go-mockid/xkeypairs"
)

// Verify is an HTTP handler that verifies the claims and signature of a JWS.
//
// The token is read from exactly one of two places:
//   - the Authorization header, formatted as
//     "<scheme> <protected>.<payload>.<signature>" (a compact JWT), or
//   - when no Authorization header is present, the request body, decoded as a
//     JSON-encoded xkeypairs.JWS.
//
// Only POST is accepted; any other method gets 405. Malformed input and failed
// verification are answered with 400. On success the response body is
// {"success":true} followed by a newline.
//
// If the query parameter exp=false is given, the "exp" claim is overwritten
// with a timestamp five minutes in the future before verification
// (presumably so an otherwise-expired token can still be checked).
func Verify(w http.ResponseWriter, r *http.Request) {
	if "POST" != r.Method {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	jws := &xkeypairs.JWS{}

	// Pick the token source by whether an Authorization header was sent.
	// (Testing jws for nil cannot work here: it is always allocated above.)
	if authz := r.Header.Get("Authorization"); "" != authz {
		// Expect exactly "<scheme> <token>".
		authzParts := strings.Split(authz, " ")
		if 2 != len(authzParts) {
			http.Error(w, "Bad Request: malformed Authorization header", http.StatusBadRequest)
			return
		}
		// Expect exactly "<protected>.<payload>.<signature>".
		jwsParts := strings.Split(authzParts[1], ".")
		if 3 != len(jwsParts) {
			http.Error(w, "Bad Request: malformed Authorization header", http.StatusBadRequest)
			return
		}
		jws.Protected = jwsParts[0]
		jws.Payload = jwsParts[1]
		jws.Signature = jwsParts[2]
	} else {
		defer r.Body.Close()
		// io.EOF means an empty body; jws is left empty and is rejected by
		// the header decoding below.
		// NOTE(review): if xkeypairs.JWS exposes Header/Claims to JSON, a body
		// could pre-populate them before the signed parts are decoded - confirm.
		err := json.NewDecoder(r.Body).Decode(jws)
		if nil != err && io.EOF != err {
			log.Printf("json decode error: %s", err)
			http.Error(w, "Bad Request: invalid JWS body", http.StatusBadRequest)
			return
		}
	}

	protected, err := base64.RawURLEncoding.DecodeString(jws.Protected)
	if nil != err {
		http.Error(w, "Bad Request: invalid JWS header base64Url encoding", http.StatusBadRequest)
		return
	}
	if err := json.Unmarshal(protected, &jws.Header); nil != err {
		log.Printf("json decode header error: %s", err)
		http.Error(w, "Bad Request: invalid JWS header", http.StatusBadRequest)
		return
	}

	payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
	if nil != err {
		http.Error(w, "Bad Request: invalid JWS payload base64Url encoding", http.StatusBadRequest)
		return
	}
	if err := json.Unmarshal(payload, &jws.Claims); nil != err {
		log.Printf("json decode claims error: %s", err)
		http.Error(w, "Bad Request: invalid JWS claims", http.StatusBadRequest)
		return
	}

	if "false" == r.URL.Query().Get("exp") {
		// A payload of JSON null unmarshals without error but leaves the
		// claims map nil; writing to a nil map would panic.
		if nil == jws.Claims {
			http.Error(w, "Bad Request: invalid JWS claims", http.StatusBadRequest)
			return
		}
		jws.Claims["exp"] = float64(time.Now().Add(5 * time.Minute).Unix())
	}

	// The meaning of the nil first argument is defined by xkeypairs
	// (presumably "use the key carried in the JWS itself" - TODO confirm).
	ok, err := xkeypairs.VerifyClaims(nil, jws)
	if nil != err {
		log.Printf("jws verify error: %s", err)
		http.Error(w, "Bad Request: could not verify JWS claims", http.StatusBadRequest)
		return
	}
	if !ok {
		http.Error(w, "Bad Request: invalid JWS signature", http.StatusBadRequest)
		return
	}

	b := []byte(`{"success":true}`)
	if _, err := w.Write(append(b, '\n')); nil != err {
		log.Printf("write response error: %s", err)
	}
}