Compare commits

...

2 Commits

Author SHA1 Message Date
AJ ONeal 557f9085f6 add kv fs store and tests 2020-09-16 22:34:25 +00:00
AJ ONeal 9de2f796db cleanup some xkeypairs functions 2020-09-16 22:32:46 +00:00
7 changed files with 282 additions and 107 deletions

131
kvdb/kvdb.go Normal file
View File

@ -0,0 +1,131 @@
package kvdb
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"strings"
)
type KVDB struct {
Prefix string
Ext string
}
func (kv *KVDB) Load(
keyif interface{},
typ ...interface{},
) (value interface{}, ok bool, err error) {
key, _ := keyif.(string)
if "" == key || strings.Contains(key, "..") || strings.ContainsAny(key, "$#!:| \n") {
return nil, false, nil
}
userFile := filepath.Join(kv.Prefix, key+"."+kv.Ext)
fmt.Println("Debug user file:", userFile)
b, err := ioutil.ReadFile(userFile)
if nil != err {
if os.IsNotExist(err) {
return nil, false, nil
}
fmt.Println("kvdb debug read:", err)
return nil, false, errors.New("database read failed")
}
ok = true
value = b
if 1 == len(typ) {
err := json.Unmarshal(b, typ[0])
if nil != err {
return nil, false, err
}
value = typ[0]
} else if len(b) > 0 && '"' == b[0] {
var str string
err := json.Unmarshal(b, &str)
if nil == err {
value = str
}
}
return value, ok, nil
}
func (kv *KVDB) Store(keyif interface{}, value interface{}) (err error) {
key, _ := keyif.(string)
if "" == key || strings.Contains(key, "..") || strings.ContainsAny(key, "$#! \n") {
return errors.New("invalid key name")
}
keypath := filepath.Join(kv.Prefix, key+"."+kv.Ext)
f, err := os.Open(keypath)
if nil == err {
s, err := f.Stat()
if nil != err {
// if we can open, we should be able to stat
return errors.New("database connection failure")
}
ts := strconv.FormatInt(s.ModTime().Unix(), 10)
bakpath := filepath.Join(kv.Prefix, key+"."+ts+"."+kv.Ext)
if err := os.Rename(keypath, bakpath); nil != err {
// keep the old record as a backup
return errors.New("database write failure")
}
}
var b []byte
switch v := value.(type) {
case []byte:
b = v
case string:
b, _ = json.Marshal(v)
default:
fmt.Println("kvdb: not []byte or string:", v)
jsonb, err := json.Marshal(v)
if nil != err {
return err
}
b = jsonb
}
if err := ioutil.WriteFile(
keypath,
b,
os.FileMode(0600),
); nil != err {
fmt.Println("write failure:", err)
return errors.New("database write failed")
}
return nil
}
func (kv *KVDB) Delete(keyif interface{}) (err error) {
key, _ := keyif.(string)
if "" == key || strings.Contains(key, "..") || strings.ContainsAny(key, "$#! \n") {
return errors.New("invalid key name")
}
keypath := filepath.Join(kv.Prefix, key+"."+kv.Ext)
f, err := os.Open(keypath)
if nil == err {
s, err := f.Stat()
if nil != err {
return errors.New("database connection failure")
}
ts := strconv.FormatInt(s.ModTime().Unix(), 64)
if err := os.Rename(keypath, filepath.Join(kv.Prefix, key+"."+ts+"."+kv.Ext)); nil != err {
return errors.New("database connection failure")
}
}
return nil
}
func (kv *KVDB) Vacuum() (err error) {
return nil
}

63
kvdb/kvdb_test.go Normal file
View File

@ -0,0 +1,63 @@
package kvdb
import (
"strings"
"testing"
)
type TestEntry struct {
Email string `json:"email"`
Subjects []string `json:"subjects"`
}
var email = "john@example.com"
var sub = "id123"
var dbPrefix = "../testdb"
var testKV = &KVDB{
Prefix: dbPrefix + "/test-entries",
Ext: "eml.json",
}
func TestStore(t *testing.T) {
entry := &TestEntry{
Email: email,
Subjects: []string{sub},
}
if err := testKV.Store(email, entry); nil != err {
t.Fatal(err)
return
}
value, ok, err := testKV.Load(email, &(TestEntry{}))
if nil != err {
t.Fatal(err)
return
}
if !ok {
t.Fatal("test entry not found")
}
v, ok := value.(*TestEntry)
if !ok {
t.Fatal("test entry not of type TestEntry")
}
if email != v.Email || sub != strings.Join(v.Subjects, ",") {
t.Fatalf("value: %#v", v)
}
}
func TestNoExist(t *testing.T) {
value, ok, err := testKV.Load("not"+email, &(TestEntry{}))
if nil != err {
t.Fatal(err)
return
}
if ok {
t.Fatal("found entry that doesn't exist")
}
if value != nil {
t.Fatal("had value for entry that doesn't exist")
}
}

View File

@ -11,35 +11,17 @@ import (
"math/rand" "math/rand"
mathrand "math/rand" mathrand "math/rand"
"net/http" "net/http"
"git.coolaj86.com/coolaj86/go-mockid/xkeypairs"
) )
type Object = map[string]interface{}
// options are the things that we may need to know about a request to fulfill it properly
type options struct {
Key string `json:"key"`
KeyType string `json:"kty"`
Seed int64 `json:"-"`
SeedStr string `json:"seed"`
Claims Object `json:"claims"`
Header Object `json:"header"`
}
// this shananigans is only for testing and debug API stuff
func (o *options) nextReader() io.Reader {
if 0 == o.Seed {
return RandomReader
}
return rand.New(rand.NewSource(o.Seed))
}
/* /*
func getJWS(r *http.Request) (*options, error) { func getJWS(r *http.Request) (*xkeypairs.KeyOptions, error) {
} }
*/ */
func getOpts(r *http.Request) (*options, error) { func getOpts(r *http.Request) (*xkeypairs.KeyOptions, error) {
tok := make(map[string]interface{}) tok := make(map[string]interface{})
decoder := json.NewDecoder(r.Body) decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&tok) err := decoder.Decode(&tok)
@ -60,17 +42,17 @@ func getOpts(r *http.Request) (*options, error) {
} }
key, _ := tok["key"].(string) key, _ := tok["key"].(string)
opts := &options{ opts := &xkeypairs.KeyOptions{
Seed: seed, Seed: seed,
Key: key, Key: key,
} }
opts.Claims, _ = tok["claims"].(Object) opts.Claims, _ = tok["claims"].(xkeypairs.Object)
opts.Header, _ = tok["header"].(Object) opts.Header, _ = tok["header"].(xkeypairs.Object)
var n int var n int
if 0 != seed { if 0 != seed {
n = opts.nextReader().(*mathrand.Rand).Intn(2) n = opts.MyFooNextReader().(*mathrand.Rand).Intn(2)
} else { } else {
n = rand.Intn(2) n = rand.Intn(2)
} }

View File

@ -1,21 +1,12 @@
package api package api
import ( import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"io"
"log"
"net/http" "net/http"
"git.coolaj86.com/coolaj86/go-mockid/xkeypairs" "git.coolaj86.com/coolaj86/go-mockid/xkeypairs"
"git.rootprojects.org/root/keypairs" "git.rootprojects.org/root/keypairs"
) )
// RandomReader may be overwritten for testing
var RandomReader io.Reader = rand.Reader
// GeneratePublicJWK will create a new private key in JWK format // GeneratePublicJWK will create a new private key in JWK format
func GeneratePublicJWK(w http.ResponseWriter, r *http.Request) { func GeneratePublicJWK(w http.ResponseWriter, r *http.Request) {
if "POST" != r.Method { if "POST" != r.Method {
@ -52,7 +43,7 @@ func GeneratePrivateJWK(w http.ResponseWriter, r *http.Request) {
return return
} }
privkey := genPrivKey(opts) privkey := xkeypairs.GenPrivKey(opts)
jwk := xkeypairs.MarshalJWKPrivateKey(privkey) jwk := xkeypairs.MarshalJWKPrivateKey(privkey)
w.Write(append(jwk, '\n')) w.Write(append(jwk, '\n'))
@ -95,7 +86,7 @@ func GeneratePrivateDER(w http.ResponseWriter, r *http.Request) {
return return
} }
privkey := genPrivKey(opts) privkey := xkeypairs.GenPrivKey(opts)
der, _ := xkeypairs.MarshalDERPrivateKey(privkey) der, _ := xkeypairs.MarshalDERPrivateKey(privkey)
w.Write(der) w.Write(der)
@ -138,7 +129,7 @@ func GeneratePrivatePEM(w http.ResponseWriter, r *http.Request) {
return return
} }
privkey := genPrivKey(opts) privkey := xkeypairs.GenPrivKey(opts)
privpem, _ := xkeypairs.MarshalPEMPrivateKey(privkey) privpem, _ := xkeypairs.MarshalPEMPrivateKey(privkey)
w.Write(privpem) w.Write(privpem)
@ -146,39 +137,9 @@ func GeneratePrivatePEM(w http.ResponseWriter, r *http.Request) {
const maxRetry = 16 const maxRetry = 16
func getPrivKey(opts *options) (keypairs.PrivateKey, error) { func getPrivKey(opts *xkeypairs.KeyOptions) (keypairs.PrivateKey, error) {
if "" != opts.Key { if "" != opts.Key {
return keypairs.ParsePrivateKey([]byte(opts.Key)) return keypairs.ParsePrivateKey([]byte(opts.Key))
} }
return genPrivKey(opts), nil return xkeypairs.GenPrivKey(opts), nil
}
func genPrivKey(opts *options) keypairs.PrivateKey {
var privkey keypairs.PrivateKey
if "RSA" == opts.KeyType {
keylen := 2048
privkey, _ = rsa.GenerateKey(opts.nextReader(), keylen)
if 0 != opts.Seed {
for i := 0; i < maxRetry; i++ {
otherkey, _ := rsa.GenerateKey(opts.nextReader(), keylen)
otherCmp := otherkey.D.Cmp(privkey.(*rsa.PrivateKey).D)
if 0 != otherCmp {
// There are two possible keys, choose the lesser D value
// See https://github.com/square/go-jose/issues/189
if otherCmp < 0 {
privkey = otherkey
}
break
}
if maxRetry == i-1 {
log.Printf("error: coinflip landed on heads %d times", maxRetry)
}
}
}
} else {
// TODO: EC keys may also suffer the same random problems in the future
privkey, _ = ecdsa.GenerateKey(elliptic.P256(), opts.nextReader())
}
return privkey
} }

61
xkeypairs/generate.go Normal file
View File

@ -0,0 +1,61 @@
package xkeypairs
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"io"
"log"
"math/rand"
"git.rootprojects.org/root/keypairs"
)
// KeyOptions are the things that we may need to know about a request to fulfill it properly
type KeyOptions struct {
Key string `json:"key"`
KeyType string `json:"kty"`
Seed int64 `json:"-"`
SeedStr string `json:"seed"`
Claims Object `json:"claims"`
Header Object `json:"header"`
}
// this shananigans is only for testing and debug API stuff
func (o *KeyOptions) MyFooNextReader() io.Reader {
if 0 == o.Seed {
return RandomReader
}
return rand.New(rand.NewSource(o.Seed))
}
// GenPrivKey generates a 256-bit entropy RSA or ECDSA private key
func GenPrivKey(opts *KeyOptions) keypairs.PrivateKey {
var privkey keypairs.PrivateKey
if "RSA" == opts.KeyType {
keylen := 2048
privkey, _ = rsa.GenerateKey(opts.MyFooNextReader(), keylen)
if 0 != opts.Seed {
for i := 0; i < maxRetry; i++ {
otherkey, _ := rsa.GenerateKey(opts.MyFooNextReader(), keylen)
otherCmp := otherkey.D.Cmp(privkey.(*rsa.PrivateKey).D)
if 0 != otherCmp {
// There are two possible keys, choose the lesser D value
// See https://github.com/square/go-jose/issues/189
if otherCmp < 0 {
privkey = otherkey
}
break
}
if maxRetry == i-1 {
log.Printf("error: coinflip landed on heads %d times", maxRetry)
}
}
}
} else {
// TODO: EC keys may also suffer the same random problems in the future
privkey, _ = ecdsa.GenerateKey(elliptic.P256(), opts.MyFooNextReader())
}
return privkey
}

View File

@ -1,7 +1,7 @@
package xkeypairs package xkeypairs
import ( import (
"strconv" "io/ioutil"
"git.rootprojects.org/root/keypairs" "git.rootprojects.org/root/keypairs"
) )
@ -12,37 +12,11 @@ func ParsePEMPrivateKey(block []byte) (keypairs.PrivateKey, error) {
return keypairs.ParsePrivateKey(block) return keypairs.ParsePrivateKey(block)
} }
func ParseDuration(exp string) (int, error) { // ParsePrivateKeyFile returns the private key from the given file path, if available
if "" == exp { func ParsePrivateKeyFile(pathname string) (keypairs.PrivateKey, error) {
exp = "15m" block, err := ioutil.ReadFile(pathname)
}
mult := 1
switch exp[len(exp)-1] {
case 'w':
mult *= 7
fallthrough
case 'd':
mult *= 24
fallthrough
case 'h':
mult *= 60
fallthrough
case 'm':
mult *= 60
fallthrough
case 's':
// no fallthrough
default:
// could be 'k' or 'z', but we assume its empty
exp += "s"
}
// 15m => num=15, mult=1*60
num, err := strconv.Atoi(exp[:len(exp)-1])
if nil != err { if nil != err {
return 0, err return nil, err
} }
return keypairs.ParsePrivateKey(block)
return num * mult, nil
} }

View File

@ -17,7 +17,10 @@ import (
"git.rootprojects.org/root/keypairs" "git.rootprojects.org/root/keypairs"
) )
var RandomReader = rand.Reader // RandomReader may be overwritten for testing
var RandomReader io.Reader = rand.Reader
//var RandomReader = rand.Reader
type JWS struct { type JWS struct {
Header Object `json:"header"` // JSON Header Object `json:"header"` // JSON
@ -29,6 +32,7 @@ type JWS struct {
type Object = map[string]interface{} type Object = map[string]interface{}
// SignClaims adds `typ`, `kid` (or `jwk`), and `alg` in the header and expects claims for `jti`, `exp`, `iss`, and `iat`
func SignClaims(privkey keypairs.PrivateKey, header Object, claims Object) (*JWS, error) { func SignClaims(privkey keypairs.PrivateKey, header Object, claims Object) (*JWS, error) {
var randsrc io.Reader = RandomReader var randsrc io.Reader = RandomReader
seed, _ := header["_seed"].(int64) seed, _ := header["_seed"].(int64)
@ -119,11 +123,12 @@ func claimsToPayload(claims Object) ([]byte, error) {
// parse if exp is actually a duration, such as "15m" // parse if exp is actually a duration, such as "15m"
if 0 == exp && "" != dur { if 0 == exp && "" != dur {
s, err := ParseDuration(dur) s, err := time.ParseDuration(dur)
// TODO s, err := time.ParseDuration(dur)
if nil != err { if nil != err {
return nil, err return nil, err
} }
exp = time.Now().Add(time.Duration(s) * time.Second).Unix() exp = time.Now().Add(s * time.Second).Unix()
claims["exp"] = exp claims["exp"] = exp
} }
if "" == jti && 0 == exp && !insecure { if "" == jti && 0 == exp && !insecure {
@ -155,8 +160,6 @@ func Sign(rand io.Reader, privkey keypairs.PrivateKey, hash []byte) []byte {
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
r, s, _ := ecdsa.Sign(rand, k, hash[:]) r, s, _ := ecdsa.Sign(rand, k, hash[:])
rb := r.Bytes() rb := r.Bytes()
fmt.Println("debug:")
fmt.Println(r, s)
for len(rb) < 32 { for len(rb) < 32 {
rb = append([]byte{0}, rb...) rb = append([]byte{0}, rb...)
} }