490 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			490 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package main
 | 
						|
 | 
						|
// TODO learn more about chan chans (channels of channels)
 | 
						|
// http://marcio.io/2015/07/handling-1-million-requests-per-minute-with-golang/
 | 
						|
 | 
						|
import (
 | 
						|
	"bufio"
 | 
						|
	"bytes"
 | 
						|
	"crypto/rand"
 | 
						|
	"encoding/base64"
 | 
						|
	"flag"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"net"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"os"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/emicklei/go-restful"
 | 
						|
	"gopkg.in/yaml.v2"
 | 
						|
)
 | 
						|
 | 
						|
// Configuration types decoded (via yaml.Unmarshal in main) from the file
// named by the -conf flag.
//
// The mailer settings are a separate named type rather than an anonymous
// nested struct because I wasn't sure how to pass nested structs around.
// TODO Learn if passing nested structs is desirable?
type (
	// Conf is the top-level server configuration.
	Conf struct {
		Addr     string     `yaml:"addr,omitempty"`      // host/IP to bind; empty means all interfaces
		Port     uint       `yaml:"port,omitempty"`      // TCP port; overridden by a non-zero -port flag
		Mailer   ConfMailer `yaml:"mailer"`              // settings used by sendAuthCode
		RootPath string     `yaml:"root_path,omitempty"` // static file root; main defaults it to ./public
	}

	// ConfMailer holds what sendAuthCode needs to POST an auth-code email:
	// the endpoint URL, the key sent as the basic-auth password for user
	// "api", and the From address.
	ConfMailer struct {
		Url    string `yaml:"url,omitempty"`
		ApiKey string `yaml:"api_key,omitempty"`
		From   string `yaml:"from,omitempty"`
	}
)
 | 
						|
 | 
						|
// bufferedConn pairs a net.Conn with a bufio.Reader so incoming bytes can be
// inspected with Peek (which net.Conn cannot do natively) before deciding how
// to handle the connection. Reads go through the bufio.Reader so peeked bytes
// are not lost; Write, Close, RemoteAddr, etc. come from the embedded net.Conn.
// https://stackoverflow.com/questions/51472020/how-to-get-the-size-of-available-tcp-data
type bufferedConn struct {
	r *bufio.Reader
	// TODO: an alternate reader to swap in after peeking, as in
	// https://github.com/polvi/sni/blob/master/sni.go#L135
	net.Conn
}

// newBufferedConn wraps c in a bufferedConn backed by a default-size
// bufio.Reader reading from c.
func newBufferedConn(c net.Conn) bufferedConn {
	reader := bufio.NewReader(c)
	return bufferedConn{
		r:    reader,
		Conn: c,
	}
}

// Peek returns the next n bytes without consuming them (see bufio.Reader.Peek).
func (bc bufferedConn) Peek(n int) ([]byte, error) {
	peeked, err := bc.r.Peek(n)
	return peeked, err
}

// Buffered reports how many bytes are currently held in the read buffer.
func (bc bufferedConn) Buffered() int {
	count := bc.r.Buffered()
	return count
}

// Read serves buffered (possibly previously peeked) bytes first, then reads
// from the underlying connection.
func (bc bufferedConn) Read(p []byte) (int, error) {
	n, err := bc.r.Read(p)
	return n, err
}
 | 
						|
 | 
						|
type (
	// chatMsg is one chat line. Exported fields are what the JSON API sees.
	chatMsg struct {
		// sender is the connection the message came from, or nil for
		// system-generated messages. Unexported, so never serialized.
		sender     net.Conn
		Message    string    `json:"message"`
		ReceivedAt time.Time `json:"received_at"`
		Channel    string    `json:"channel"`
		User       string    `json:"user"`
	}

	// chatHist is a Poor-Man's container/ring (circular buffer) of recent
	// messages: msgs is the fixed-size backing store, i is the slot the next
	// message will be written to, and c counts filled slots (it stops growing
	// once the buffer is full).
	chatHist struct {
		msgs []*chatMsg
		i    int // current index
		c    int // current count (number of elements)
	}
)
 | 
						|
 | 
						|
// Multi-use
 | 
						|
var config Conf
 | 
						|
var virginConns chan net.Conn
 | 
						|
var gotClientHello chan bufferedConn
 | 
						|
var myChatHist chatHist
 | 
						|
var broadcastMsg chan chatMsg
 | 
						|
 | 
						|
// Telnet
 | 
						|
var wantsServerHello chan bufferedConn
 | 
						|
var authTelnet chan telnetUser
 | 
						|
var cleanTelnet chan telnetUser // intentionally blocking
 | 
						|
 | 
						|
// HTTP
 | 
						|
var demuxHttpClient chan bufferedConn
 | 
						|
var authReqs chan authReq
 | 
						|
var valAuthReqs chan authReq
 | 
						|
var delAuthReqs chan authReq
 | 
						|
 | 
						|
// usage prints the command-line synopsis and flag defaults, then exits with
// status 1. main installs it as flag.Usage.
//
// All of the output goes to stderr. flag.PrintDefaults writes to the flag
// package's output, which is stderr by default. The trailing blank line used
// to go to stdout, splitting the usage text across two streams.
func usage() {
	fmt.Fprintf(os.Stderr, "\nusage: go run chatserver*.go\n")
	flag.PrintDefaults()
	fmt.Fprintln(os.Stderr)

	os.Exit(1)
}
 | 
						|
 | 
						|
// genAuthCode returns a random one-time code: 12 bytes from crypto/rand,
// encoded with URL-safe base64 (16 characters, no padding since 12 is a
// multiple of 3). The only error is a failure of the system random source.
// https://blog.questionable.services/article/generating-secure-random-numbers-crypto-rand/
func genAuthCode() (string, error) {
	const codeBytes = 12

	raw := make([]byte, codeBytes)
	// crypto/rand.Read returns a nil error only if it filled all of raw.
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}

	encoded := base64.URLEncoding.EncodeToString(raw)
	return encoded, nil
}
 | 
						|
 | 
						|
func muxTcp(conn bufferedConn) {
 | 
						|
	// Wish List for protocol detection
 | 
						|
	// * PROXY protocol (and loop)
 | 
						|
	// * HTTP CONNECT (proxy) (and loop)
 | 
						|
	// * tls (and loop) https://github.com/polvi/sni
 | 
						|
	// * http/ws
 | 
						|
	// * irc
 | 
						|
	// * fallback to telnet
 | 
						|
 | 
						|
	// At this piont we've already at least one byte via Peek()
 | 
						|
	// so the first packet is available in the buffer
 | 
						|
 | 
						|
	// Note: Realistically no tls/http/irc client is going to send so few bytes
 | 
						|
	//       (and no router is going to chunk so small)
 | 
						|
	//       that it cannot reasonably detect the protocol in the first packet
 | 
						|
	//       However, typical MTU is 1,500 and HTTP can have a 2k URL
 | 
						|
	//       so expecting to find the "HTTP/1.1" in the Peek is not always reasonable
 | 
						|
	n := conn.Buffered()
 | 
						|
	firstMsg, err := conn.Peek(n)
 | 
						|
	if nil != err {
 | 
						|
		conn.Close()
 | 
						|
		return
 | 
						|
	}
 | 
						|
	var protocol string
 | 
						|
	// between A and z
 | 
						|
	if firstMsg[0] >= 65 && firstMsg[0] <= 122 {
 | 
						|
		i := bytes.Index(firstMsg, []byte(" /"))
 | 
						|
		if -1 != i {
 | 
						|
			protocol = "HTTP"
 | 
						|
			// very likely HTTP
 | 
						|
			j := bytes.IndexAny(firstMsg, "\r\n")
 | 
						|
			if -1 != j {
 | 
						|
				k := bytes.Index(bytes.ToLower(firstMsg[:j]), []byte("HTTP/1"))
 | 
						|
				if -1 != k {
 | 
						|
					// positively HTTP
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	} else if 0x16 /*22*/ == firstMsg[0] {
 | 
						|
		// Because I don't always remember off the top of my head what the first byte is
 | 
						|
		// http://blog.fourthbit.com/2014/12/23/traffic-analysis-of-an-ssl-slash-tls-session
 | 
						|
		// https://tlseminar.github.io/first-few-milliseconds/
 | 
						|
		// TODO I want to learn about ALPN
 | 
						|
		protocol = "TLS"
 | 
						|
	}
 | 
						|
 | 
						|
	if "" == protocol {
 | 
						|
		// Throw away the first bytes
 | 
						|
		b := make([]byte, 4096)
 | 
						|
		conn.Read(b)
 | 
						|
		fmt.Fprintf(conn, "\n\nWelcome to Sample Chat! You're not an HTTP client, assuming Telnet.\nYou must authenticate via email to participate\n\nEmail: ")
 | 
						|
		wantsServerHello <- conn
 | 
						|
		return
 | 
						|
	} else if "HTTP" != protocol {
 | 
						|
		defer conn.Close()
 | 
						|
		fmt.Fprintf(conn, "\n\nNot yet supported. Try HTTP or Telnet\n\n")
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	demuxHttpClient <- conn
 | 
						|
}
 | 
						|
 | 
						|
func testForHello(netConn net.Conn) {
 | 
						|
	ts := time.Now()
 | 
						|
	fmt.Fprintf(os.Stdout, "[New Connection] (%s) welcome %s\n", ts, netConn.RemoteAddr().String())
 | 
						|
 | 
						|
	m := sync.Mutex{}
 | 
						|
	virgin := true
 | 
						|
 | 
						|
	bufConn := newBufferedConn(netConn)
 | 
						|
	go func() {
 | 
						|
		// Cause first packet to be loaded into buffer
 | 
						|
		_, err := bufConn.Peek(1)
 | 
						|
		if nil != err {
 | 
						|
			panic(err)
 | 
						|
		}
 | 
						|
 | 
						|
		m.Lock()
 | 
						|
		if virgin {
 | 
						|
			virgin = false
 | 
						|
			gotClientHello <- bufConn
 | 
						|
		} else {
 | 
						|
			wantsServerHello <- bufConn
 | 
						|
		}
 | 
						|
		m.Unlock()
 | 
						|
	}()
 | 
						|
 | 
						|
	// Wait for a hello packet of some sort from the client
 | 
						|
	// (obviously this wouldn't work in extremely high latency situations)
 | 
						|
	time.Sleep(250 * 1000000)
 | 
						|
 | 
						|
	// If we still haven't received data from the client
 | 
						|
	// assume that the client must be expecting a welcome from us
 | 
						|
	m.Lock()
 | 
						|
	if virgin {
 | 
						|
		virgin = false
 | 
						|
		// Defer as to not block and prolonging the mutex
 | 
						|
		// (not that those few cycles much matter...)
 | 
						|
		defer fmt.Fprintf(netConn,
 | 
						|
			"\n\nWelcome to Sample Chat! You appear to be using Telnet (http is also available on this port)."+
 | 
						|
				"\nYou must authenticate via email to participate\n\nEmail: ")
 | 
						|
	}
 | 
						|
	m.Unlock()
 | 
						|
}
 | 
						|
 | 
						|
func sendAuthCode(cnf ConfMailer, to string) (string, error) {
 | 
						|
	code, err := genAuthCode()
 | 
						|
	if nil != err {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	// TODO use go text templates with HTML escaping
 | 
						|
	text := "Your authorization code:\n\n" + code
 | 
						|
	html := "Your authorization code:<br><br>" + code
 | 
						|
 | 
						|
	// https://stackoverflow.com/questions/24493116/how-to-send-a-post-request-in-go
 | 
						|
	// https://stackoverflow.com/questions/16673766/basic-http-auth-in-go
 | 
						|
	client := http.Client{}
 | 
						|
 | 
						|
	form := url.Values{}
 | 
						|
	form.Add("from", cnf.From)
 | 
						|
	form.Add("to", to)
 | 
						|
	form.Add("subject", "Sample Chat Auth Code: "+code)
 | 
						|
	form.Add("text", text)
 | 
						|
	form.Add("html", html)
 | 
						|
 | 
						|
	req, err := http.NewRequest("POST", cnf.Url, strings.NewReader(form.Encode()))
 | 
						|
	if nil != err {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	//req.PostForm = form ??
 | 
						|
	req.Header.Add("User-Agent", "golang http.Client - Sample Chat App Authenticator")
 | 
						|
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
 | 
						|
	req.SetBasicAuth("api", cnf.ApiKey)
 | 
						|
 | 
						|
	resp, err := client.Do(req)
 | 
						|
	if nil != err {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	defer resp.Body.Close()
 | 
						|
	// Security XXX
 | 
						|
	// we trust mailgun implicitly and this is just a demo
 | 
						|
	// hence no DoS check on body size for now
 | 
						|
	body, err := ioutil.ReadAll(resp.Body)
 | 
						|
	if nil != err {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	if resp.StatusCode < 200 || resp.StatusCode >= 300 || "{" != string(body[0]) {
 | 
						|
		fmt.Fprintf(os.Stdout, "[Mailgun] Uh-oh...\n[Maigun] Baby Brent says: %s\n", body)
 | 
						|
	}
 | 
						|
 | 
						|
	return code, nil
 | 
						|
}
 | 
						|
 | 
						|
func main() {
 | 
						|
	flag.Usage = usage
 | 
						|
	port := flag.Uint("port", 0, "tcp telnet chat port")
 | 
						|
	confname := flag.String("conf", "./config.yml", "yaml config file")
 | 
						|
	flag.Parse()
 | 
						|
 | 
						|
	confstr, err := ioutil.ReadFile(*confname)
 | 
						|
	fmt.Fprintf(os.Stdout, "-conf=%s\n", *confname)
 | 
						|
	if nil != err {
 | 
						|
		fmt.Fprintf(os.Stderr, "%s\nUsing defaults instead\n", err)
 | 
						|
		confstr = []byte("{\"port\":" + strconv.Itoa(int(*port)) + "}")
 | 
						|
	}
 | 
						|
	err = yaml.Unmarshal(confstr, &config)
 | 
						|
	if nil != err {
 | 
						|
		config = Conf{}
 | 
						|
	}
 | 
						|
	if "" == config.RootPath {
 | 
						|
		// TODO Maybe embed the public dir into the binary
 | 
						|
		// (and provide a flag with path for override - like gitea)
 | 
						|
		config.RootPath = "./public"
 | 
						|
	}
 | 
						|
 | 
						|
	// The magical sorting hat
 | 
						|
	virginConns = make(chan net.Conn, 128)
 | 
						|
 | 
						|
	// TCP & Authentication
 | 
						|
	telnetConns := make(map[string]telnetUser)
 | 
						|
	wantsServerHello = make(chan bufferedConn, 128)
 | 
						|
	authTelnet = make(chan telnetUser, 128)
 | 
						|
 | 
						|
	// HTTP & Authentication
 | 
						|
	myAuthReqs := make(map[string]authReq)
 | 
						|
	authReqs = make(chan authReq, 128)
 | 
						|
	valAuthReqs = make(chan authReq, 128)
 | 
						|
	delAuthReqs = make(chan authReq, 128)
 | 
						|
	gotClientHello = make(chan bufferedConn, 128)
 | 
						|
	demuxHttpClient = make(chan bufferedConn, 128)
 | 
						|
 | 
						|
	//myRooms["general"] = make(chan chatMsg, 128)
 | 
						|
	// Note: I had considered dynamically select on channels for rooms.
 | 
						|
	// https://stackoverflow.com/questions/19992334/how-to-listen-to-n-channels-dynamic-select-statement
 | 
						|
	// I don't think that's actually the best approach, but I just wanted to save the link
 | 
						|
 | 
						|
	broadcastMsg = make(chan chatMsg, 128)
 | 
						|
	myChatHist.msgs = make([]*chatMsg, 128)
 | 
						|
 | 
						|
	var addr string
 | 
						|
	if 0 != int(*port) {
 | 
						|
		addr = config.Addr + ":" + strconv.Itoa(int(*port))
 | 
						|
	} else {
 | 
						|
		addr = config.Addr + ":" + strconv.Itoa(int(config.Port))
 | 
						|
	}
 | 
						|
 | 
						|
	// https://golang.org/pkg/net/#Conn
 | 
						|
	sock, err := net.Listen("tcp", addr)
 | 
						|
	if nil != err {
 | 
						|
		fmt.Fprintf(os.Stderr, "Couldn't bind to TCP socket %q: %s\n", addr, err)
 | 
						|
		os.Exit(2)
 | 
						|
	}
 | 
						|
	fmt.Println("Listening on", addr)
 | 
						|
 | 
						|
	go func() {
 | 
						|
		for {
 | 
						|
			conn, err := sock.Accept()
 | 
						|
			if err != nil {
 | 
						|
				// Not sure what kind of error this could be or how it could happen.
 | 
						|
				// Could a connection abort or end before it's handled?
 | 
						|
				fmt.Fprintf(os.Stderr, "Error accepting connection:\n%s\n", err)
 | 
						|
			}
 | 
						|
			virginConns <- conn
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	// Learning by Example
 | 
						|
	// https://github.com/emicklei/go-restful/blob/master/examples/restful-multi-containers.go
 | 
						|
	// https://github.com/emicklei/go-restful/blob/master/examples/restful-basic-authentication.go
 | 
						|
	// https://github.com/emicklei/go-restful/blob/master/examples/restful-serve-static.go
 | 
						|
	// https://github.com/emicklei/go-restful/blob/master/examples/restful-pre-post-filters.go
 | 
						|
	container := restful.NewContainer()
 | 
						|
 | 
						|
	wsStatic := new(restful.WebService)
 | 
						|
	wsStatic.Path("/")
 | 
						|
	wsStatic.Route(wsStatic.GET("/").To(serveStatic))
 | 
						|
	wsStatic.Route(wsStatic.GET("/{subpath:*}").To(serveStatic))
 | 
						|
	container.Add(wsStatic)
 | 
						|
 | 
						|
	cors := restful.CrossOriginResourceSharing{ExposeHeaders: []string{"Authorization"}, CookiesAllowed: false, Container: container}
 | 
						|
	wsApi := new(restful.WebService)
 | 
						|
	wsApi.Path("/api").Consumes(restful.MIME_JSON).Produces(restful.MIME_JSON).Filter(cors.Filter)
 | 
						|
	wsApi.Route(wsApi.GET("/hello").To(serveHello))
 | 
						|
	wsApi.Route(wsApi.POST("/sessions").To(requestAuth))
 | 
						|
	wsApi.Route(wsApi.POST("/sessions/{cid}").To(issueToken))
 | 
						|
	wsApi.Route(wsApi.GET("/rooms/{room}").Filter(requireToken).To(listMsgs))
 | 
						|
	wsApi.Route(wsApi.POST("/rooms/{room}").Filter(requireToken).To(postMsg))
 | 
						|
	container.Add(wsApi)
 | 
						|
 | 
						|
	server := &http.Server{
 | 
						|
		Addr:    addr,
 | 
						|
		Handler: container,
 | 
						|
	}
 | 
						|
	myHttpServer := newHttpServer(sock)
 | 
						|
	go func() {
 | 
						|
		server.Serve(myHttpServer)
 | 
						|
	}()
 | 
						|
 | 
						|
	// Main event loop handling most access to shared data
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case conn := <-virginConns:
 | 
						|
			// This is short lived
 | 
						|
			go testForHello(conn)
 | 
						|
		case u := <-authTelnet:
 | 
						|
			// allow to receive messages
 | 
						|
			// (and be counted among the users)
 | 
						|
			_, ok := telnetConns[u.email]
 | 
						|
			if ok {
 | 
						|
				// this is a blocking channel, and that's important
 | 
						|
				cleanTelnet <- telnetConns[u.email]
 | 
						|
			}
 | 
						|
			telnetConns[u.email] = u
 | 
						|
			// is chan chan the right way to handle this?
 | 
						|
			u.userCount <- len(telnetConns)
 | 
						|
			broadcastMsg <- chatMsg{
 | 
						|
				sender:     nil,
 | 
						|
				Message:    fmt.Sprintf("<%s> joined #general\r\n", u.email),
 | 
						|
				ReceivedAt: time.Now(),
 | 
						|
				Channel:    "general",
 | 
						|
				User:       "system",
 | 
						|
			}
 | 
						|
		case ar := <-authReqs:
 | 
						|
			myAuthReqs[ar.Cid] = ar
 | 
						|
		case ar := <-valAuthReqs:
 | 
						|
			// TODO In this case it's probably more conventional (and efficient) to
 | 
						|
			// use struct with a mutex and the authReqs map than a chan chan
 | 
						|
			av, ok := myAuthReqs[ar.Cid]
 | 
						|
			//ar.Chan <- nil // TODO
 | 
						|
			if ok {
 | 
						|
				ar.Chan <- av
 | 
						|
			} else {
 | 
						|
				// sending empty object so that I can still send a copy
 | 
						|
				// rather than a pointer above. Maybe not the right way
 | 
						|
				// to do this, but it works for now.
 | 
						|
				ar.Chan <- authReq{}
 | 
						|
			}
 | 
						|
		case ar := <-delAuthReqs:
 | 
						|
			delete(myAuthReqs, ar.Cid)
 | 
						|
		case bufConn := <-wantsServerHello:
 | 
						|
			go handleTelnetConn(bufConn)
 | 
						|
		case u := <-cleanTelnet:
 | 
						|
			close(u.newMsg)
 | 
						|
			// we can safely ignore this error, if any
 | 
						|
			u.bufConn.Close()
 | 
						|
			delete(telnetConns, u.email)
 | 
						|
		case bufConn := <-gotClientHello:
 | 
						|
			go muxTcp(bufConn)
 | 
						|
		case bufConn := <-demuxHttpClient:
 | 
						|
			// this will be Accept()ed immediately by the go-restful container
 | 
						|
			// NOTE: we don't store these HTTP connections for broadcast
 | 
						|
			// since we manage the session by HTTP Auth Bearer rather than TCP
 | 
						|
			myHttpServer.chans <- bufConn
 | 
						|
		case msg := <-broadcastMsg:
 | 
						|
			// copy comes in, pointer gets saved (and not GC'd, I hope)
 | 
						|
			myChatHist.msgs[myChatHist.i] = &msg
 | 
						|
			myChatHist.i += 1
 | 
						|
			if myChatHist.c < cap(myChatHist.msgs) {
 | 
						|
				myChatHist.c += 1
 | 
						|
			}
 | 
						|
			myChatHist.i %= len(myChatHist.msgs)
 | 
						|
 | 
						|
			// print the system message (the "log")
 | 
						|
			t := msg.ReceivedAt
 | 
						|
			tf := "%d-%02d-%02d %02d:%02d:%02d (%s)"
 | 
						|
			var sender string
 | 
						|
			if nil != msg.sender {
 | 
						|
				sender = msg.sender.RemoteAddr().String()
 | 
						|
			} else {
 | 
						|
				sender = "system"
 | 
						|
			}
 | 
						|
			// Tangential thought:
 | 
						|
			// I wonder if we could use IP detection to get a Telnet client's tz
 | 
						|
			// ... could probably make time for this in the authentication loop
 | 
						|
			zone, _ := msg.ReceivedAt.Zone()
 | 
						|
			fmt.Fprintf(os.Stdout, tf+" [%s] (%s): %s\r\n",
 | 
						|
				t.Year(), t.Month(), t.Day(),
 | 
						|
				t.Hour(), t.Minute(), t.Second(), zone,
 | 
						|
				sender,
 | 
						|
				msg.User, msg.Message)
 | 
						|
 | 
						|
			for _, u := range telnetConns {
 | 
						|
				// Don't echo back to the original client
 | 
						|
				if msg.sender == u.bufConn {
 | 
						|
					continue
 | 
						|
				}
 | 
						|
 | 
						|
				msg := fmt.Sprintf(tf+" [%s]: %s", t.Year(), t.Month(), t.Day(), t.Hour(),
 | 
						|
					t.Minute(), t.Second(), zone, msg.User, msg.Message)
 | 
						|
				select {
 | 
						|
				case u.newMsg <- msg:
 | 
						|
					// all is well, client was ready to receive
 | 
						|
				default:
 | 
						|
					// Rate Limit: Reasonable poor man's DoS prevention (Part 2)
 | 
						|
					// This client's send channel buffer is full.
 | 
						|
					// It is consuming data too slowly. It may be malicious.
 | 
						|
					// In the case that it's experiencing network issues,
 | 
						|
					// well, these things happen when you're having network issues.
 | 
						|
					// It can reconnect.
 | 
						|
					cleanTelnet <- u
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 |