From 151c3fea88f454c3cde290e9e4ea6f3ea6da59e3 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Tue, 14 Jul 2015 00:14:48 -0600 Subject: [PATCH] dynamic SNI loading --- serve.go | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 83 insertions(+), 5 deletions(-) diff --git a/serve.go b/serve.go index 58afce1..9d0fc1b 100644 --- a/serve.go +++ b/serve.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strconv" "strings" + "time" ) func usage() { @@ -20,7 +21,14 @@ func usage() { os.Exit(2) } -type myHandler struct{} +type myHandler struct { + certMap map[string]tls.Certificate +} + +type myCert struct { + cert *tls.Certificate + touchedAt time.Time +} func (m *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Print debug info @@ -31,9 +39,13 @@ func (m *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { for k, v := range r.Header { fmt.Println(k, v) } + fmt.Println(r.Body) + fmt.Println() + fmt.Println() // End the request + // TODO serve from hosting directory fmt.Fprintf(w, "Hi there, %s %q? Wow!\n\nWith Love,\n\t%s", r.Method, r.URL.Path[1:], r.Host) } @@ -58,7 +70,7 @@ func main() { fmt.Printf("Loading Certificates %s/%s/{privkey.pem,fullchain.pem}\n", *certsPath, *defaultHost) privkeyPath := filepath.Join(*certsPath, *defaultHost, "privkey.pem") certPath := filepath.Join(*certsPath, *defaultHost, "fullchain.pem") - cert, err := tls.LoadX509KeyPair(certPath, privkeyPath) + defaultCert, err := tls.LoadX509KeyPair(certPath, privkeyPath) if err != nil { fmt.Fprintf(os.Stderr, "Couldn't load default certificates: %s\n", err) os.Exit(1) @@ -72,10 +84,76 @@ func main() { os.Exit(1) } + certMap := make(map[string]myCert) tlsConfig := new(tls.Config) - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.Certificates = []tls.Certificate{defaultCert} tlsConfig.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return &cert, nil + + // Load from memory + // TODO unload untouched certificates every x minutes + if myCert, ok := certMap[clientHello.ServerName]; ok { + myCert.touchedAt = time.Now() + return myCert.cert, nil + } + + privkeyPath := filepath.Join(*certsPath, clientHello.ServerName, "privkey.pem") + certPath := filepath.Join(*certsPath, clientHello.ServerName, "fullchain.pem") + + loadCert := func() *tls.Certificate { + // TODO handle race condition (ask Matt) + // the transaction is idempotent, however, so it shouldn't matter + if _, err := os.Stat(privkeyPath); err == nil { + fmt.Printf("Loading Certificates %s/%s/{privkey.pem,fullchain.pem}\n\n", *certsPath, clientHello.ServerName) + cert, err := tls.LoadX509KeyPair(certPath, privkeyPath) + if nil != err { + return &cert + } + return nil + } + + return nil + } + + if cert := loadCert(); nil != cert { + certMap[clientHello.ServerName] = myCert{ + cert: cert, + touchedAt: time.Now(), + } + return cert, nil + } + + // TODO try to get cert via letsencrypt python client + // TODO check for a hosting directory before attempting this + /* + cmd := exec.Command( + "./venv/bin/letsencrypt", + "--text", + "--agree-eula", + "--email", "coolaj86@gmail.com", + "--authenticator", "standalone", + "--domains", "www.example.com", "example.com", + "--dvsni-port", "65443", + "auth", + ) + err := cmd.Run() + if nil != err { + if cert := loadCert(); nil != cert { + return cert, nil + } + } + */ + + fmt.Fprintf(os.Stderr, "Failed to load certificates for %q.\n", clientHello.ServerName) + fmt.Fprintf(os.Stderr, "\tTried %s/{privkey.pem,fullchain.pem}\n", filepath.Join(*certsPath, clientHello.ServerName)) + //fmt.Fprintf(os.Stderr, "\tand letsencrypt api\n") + fmt.Fprintf(os.Stderr, "\n") + // TODO how to prevent attack and still enable retry? + // perhaps check DNS and hosting directory, wait 5 minutes? + certMap[clientHello.ServerName] = myCert{ + cert: &defaultCert, + touchedAt: time.Now(), + } + return &defaultCert, nil } tlsListener := tls.NewListener(conn, tlsConfig) @@ -83,6 +161,6 @@ func main() { Addr: addr, Handler: &myHandler{}, } - fmt.Printf("Listening on https://%s:%d\n", host, *port) + fmt.Printf("Listening on https://%s:%d\n\n", host, *port) server.Serve(tlsListener) }