package main import ( "compress/flate" "flag" "fmt" "log" "net/http" "os" "strings" "time" "git.coolaj86.com/coolaj86/goserv/assets" "git.coolaj86.com/coolaj86/goserv/internal/db" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" _ "github.com/joho/godotenv/autoload" ) var ( name = "goserv" version = "0.0.0" date = "0001-01-01T00:00:00Z" commit = "0000000" ) func usage() { ver() fmt.Println("") fmt.Println("Use 'help '") fmt.Println(" help") fmt.Println(" init") fmt.Println(" run") } func ver() { fmt.Printf("%s v%s %s (%s)\n", name, version, commit[:7], date) } type runOptions struct { listen string trustProxy bool compress bool static string } var runFlags *flag.FlagSet var runOpts runOptions var initFlags *flag.FlagSet var dbURL string func init() { runOpts = runOptions{} runFlags = flag.NewFlagSet("run", flag.ExitOnError) runFlags.StringVar(&runOpts.listen, "listen", ":3000", "the address and port on which to listen") runFlags.BoolVar(&runOpts.trustProxy, "trust-proxy", false, "trust X-Forwarded-For header") runFlags.BoolVar(&runOpts.compress, "compress", true, "enable compression for text,html,js,css,etc") runFlags.StringVar(&runOpts.static, "serve-path", "", "path to serve, falls back to built-in web app") runFlags.StringVar( &dbURL, "db-url", "postgres://postgres:postgres@localhost:5432/postgres", "database (postgres) connection url", ) } func main() { args := os.Args[:] if 1 == len(args) { // "run" should be the default args = append(args, "run") } if "help" == args[1] { // top-level help if 2 == len(args) { usage() os.Exit(0) return } // move help to subcommand argument self := args[0] args = append([]string{self}, args[2:]...) args = append(args, "--help") } switch args[1] { case "version": ver() os.Exit(0) return case "init": initFlags.Parse(args[2:]) case "run": runFlags.Parse(args[2:]) serve() default: usage() os.Exit(1) return } } var startedAt = time.Now() var defaultMaxBytes int64 = 1 << 20 func serve() { initDB(dbURL) r := chi.NewRouter() // A good base middleware stack if runOpts.trustProxy { r.Use(middleware.RealIP) } if runOpts.compress { r.Use(middleware.Compress(flate.DefaultCompression)) } r.Use(middleware.Logger) r.Use(middleware.Recoverer) r.Route("/api", func(r chi.Router) { r.Use(limitResponseSize) r.Use(jsonAllTheThings) r.Route("/public", func(r chi.Router) { r.Get("/status", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(fmt.Sprintf( `{ "success": true, "uptime": %.0f }%s`, time.Since(startedAt).Seconds(), "\n", ))) }) }) r.Route("/user", func(r chi.Router) { r.Get("/inspect", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(fmt.Sprintf( `{ "success": false, "error": "not implemented" }%s`, "\n", ))) }) }) }) var staticHandler http.HandlerFunc pub := http.FileServer(assets.Assets) if len(runOpts.static) > 0 { // try the user-provided directory first, then fallback to the built-in devFS := http.Dir(runOpts.static) dev := http.FileServer(devFS) staticHandler = func(w http.ResponseWriter, r *http.Request) { if _, err := devFS.Open(r.URL.Path); nil != err { pub.ServeHTTP(w, r) return } dev.ServeHTTP(w, r) } } else { staticHandler = func(w http.ResponseWriter, r *http.Request) { pub.ServeHTTP(w, r) } } r.Get("/*", staticHandler) fmt.Println("Listening for http (with reasonable timeouts) on", runOpts.listen) srv := &http.Server{ Addr: runOpts.listen, Handler: r, ReadHeaderTimeout: 2 * time.Second, ReadTimeout: 10 * time.Second, WriteTimeout: 20 * time.Second, MaxHeaderBytes: 1024 * 1024, // 1MiB } if err := srv.ListenAndServe(); nil != err { fmt.Fprintf(os.Stderr, "%s", err) os.Exit(1) return } } func jsonAllTheThings(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // just setting a default, other handlers can change this w.Header().Set("Content-Type", "application/json") next.ServeHTTP(w, r) }) } func limitResponseSize(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, defaultMaxBytes) next.ServeHTTP(w, r) }) } func initDB(connStr string) { // TODO url.Parse if strings.Contains(connStr, "@localhost/") || strings.Contains(connStr, "@localhost:") { connStr += "?sslmode=disable" } else { connStr += "?sslmode=required" } err := db.Init(connStr) if nil != err { log.Println("db connection error", err) //log.Fatal("db connection error", err) return } return }