// chat.go/chatserver.go
//
// 490 lines
// 14 KiB
// Go

package main
// TODO learn more about chan chan's
// http://marcio.io/2015/07/handling-1-million-requests-per-minute-with-golang/
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/base64"
"flag"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/emicklei/go-restful"
"gopkg.in/yaml.v2"
)
// Conf is the top-level server configuration. main fills the package-level
// `config` value by decoding a YAML document into it.
//
// The mailer settings live in their own named type (ConfMailer) instead of an
// anonymous nested struct, which lets them be passed on their own
// (sendAuthCode takes a ConfMailer).
// TODO Learn if passing nested structs is desirable?
type Conf struct {
// Addr is the host part of the listen address; main joins it with ":" and the port.
Addr string `yaml:"addr,omitempty"`
// Port is the TCP port to listen on; main only uses it when the -port flag is 0.
Port uint `yaml:"port,omitempty"`
// Mailer carries no explicit yaml tag.
// NOTE(review): presumably decoded from the "mailer" key (yaml.v2 default naming) - confirm.
Mailer ConfMailer
// RootPath is defaulted to "./public" by main when left empty.
// NOTE(review): presumably the static file root used by serveStatic (not visible here) - confirm.
RootPath string `yaml:"root_path,omitempty"`
}
// ConfMailer holds the settings sendAuthCode needs to deliver an auth code
// through an HTTP mail API (the log messages refer to it as Mailgun).
type ConfMailer struct {
// Url is the endpoint the form-encoded message is POSTed to.
Url string `yaml:"url,omitempty"`
// ApiKey is sent as the HTTP basic-auth password (the user name is "api").
ApiKey string `yaml:"api_key,omitempty"`
// From is used as the "from" form field of the outgoing message.
From string `yaml:"from,omitempty"`
}
// bufferedConn pairs a net.Conn with a bufio.Reader so that pending bytes can
// be inspected without consuming them, which net.Conn cannot do natively.
// https://stackoverflow.com/questions/51472020/how-to-get-the-size-of-available-tcp-data
//
// All reads must go through the buffered reader (see Read); everything else
// (Write, Close, addresses, deadlines) is promoted from the embedded net.Conn.
type bufferedConn struct {
	r *bufio.Reader
	// A swappable output reader ("rout") was considered here, see
	// https://github.com/polvi/sni/blob/master/sni.go#L135
	net.Conn
}

// newBufferedConn wraps c in a bufferedConn backed by a default-sized bufio.Reader.
func newBufferedConn(c net.Conn) bufferedConn {
	reader := bufio.NewReader(c)
	return bufferedConn{r: reader, Conn: c}
}

// Peek returns the next n bytes without advancing the reader.
func (bc bufferedConn) Peek(n int) ([]byte, error) {
	peeked, err := bc.r.Peek(n)
	return peeked, err
}

// Buffered reports how many bytes can currently be read from the buffer.
func (bc bufferedConn) Buffered() int {
	pending := bc.r.Buffered()
	return pending
}

// Read shadows the embedded net.Conn's Read so that bytes already pulled into
// the buffer (for example by Peek) are handed out first.
func (bc bufferedConn) Read(p []byte) (int, error) {
	// (If the "rout" reader is ever added, it would take precedence here.)
	count, err := bc.r.Read(p)
	return count, err
}
// chatMsg is a single chat message as it travels through broadcastMsg and is
// kept in the chat history.
type chatMsg struct {
// sender is the connection the message came from; main sets it to nil for
// system messages and compares it against each telnet user's connection so a
// message is not echoed back to its author. Unexported, so it carries no JSON tag.
sender net.Conn
// Message is the message text.
Message string `json:"message"`
// ReceivedAt is the timestamp printed in the server log and in telnet output.
ReceivedAt time.Time `json:"received_at"`
// Channel is the room name (main uses "general" for join notices).
Channel string `json:"channel"`
// User identifies the author (main uses "system" for join notices).
User string `json:"user"`
}
// Poor-Man's container/ring (circular buffer)
// main allocates msgs with a fixed length and overwrites the oldest entry once
// the write index wraps around.
type chatHist struct {
// msgs is the fixed-size backing store (allocated in main).
msgs []*chatMsg
i int // current index: the slot the next message is written to (wraps modulo len(msgs))
c int // current count (number of elements); main stops incrementing it at cap(msgs)
}
// Multi-use: state and channels shared by more than one protocol.
var (
	config         Conf
	virginConns    chan net.Conn
	gotClientHello chan bufferedConn
	myChatHist     chatHist
	broadcastMsg   chan chatMsg
)

// Telnet-specific channels.
var (
	wantsServerHello chan bufferedConn
	authTelnet       chan telnetUser
	cleanTelnet      chan telnetUser // intentionally blocking
)

// HTTP-specific channels.
var (
	demuxHttpClient chan bufferedConn
	authReqs        chan authReq
	valAuthReqs     chan authReq
	delAuthReqs     chan authReq
)
// usage is installed as flag.Usage: it prints a one-line synopsis to stderr,
// followed by the flag defaults and a blank line on stdout, then exits with
// status 1.
func usage() {
	fmt.Fprint(os.Stderr, "\nusage: go run chatserver*.go\n")
	flag.PrintDefaults()
	fmt.Print("\n")
	os.Exit(1)
}
// genAuthCode returns a fresh authorization code: 12 bytes from crypto/rand,
// encoded as 16 characters of URL-safe base64.
// It returns an empty string and the error if the random source fails.
// https://blog.questionable.services/article/generating-secure-random-numbers-crypto-rand/
func genAuthCode() (string, error) {
	const rawLen = 12
	raw := make([]byte, rawLen)
	// rand.Read only returns a nil error when the whole slice was filled.
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(raw)
	return encoded, nil
}
// muxTcp sniffs the bytes already buffered on conn and routes the connection
// by protocol:
//
//   - looks like HTTP ("<letter>... /")  -> handed to demuxHttpClient
//   - TLS handshake record (first byte 0x16) -> rejected with a notice and closed
//   - anything else -> treated as Telnet: the sniffed bytes are dropped, a
//     welcome prompt is written and the conn is handed to wantsServerHello
//
// The connection is closed if nothing is buffered or the peek fails.
func muxTcp(conn bufferedConn) {
	// Wish List for protocol detection
	// * PROXY protocol (and loop)
	// * HTTP CONNECT (proxy) (and loop)
	// * tls (and loop) https://github.com/polvi/sni
	// * http/ws
	// * irc
	// * fallback to telnet

	// At this point we've already read at least one byte via Peek()
	// so the first packet is available in the buffer
	// Note: Realistically no tls/http/irc client is going to send so few bytes
	// (and no router is going to chunk so small)
	// that it cannot reasonably detect the protocol in the first packet
	// However, typical MTU is 1,500 and HTTP can have a 2k URL
	// so expecting to find the "HTTP/1.1" in the Peek is not always reasonable
	n := conn.Buffered()
	if n < 1 {
		// Peek(0) succeeds with an empty slice, so indexing firstMsg[0] below
		// would panic. With nothing to sniff there is nothing to route.
		conn.Close()
		return
	}
	firstMsg, err := conn.Peek(n)
	if nil != err {
		conn.Close()
		return
	}

	var protocol string
	switch first := firstMsg[0]; {
	case first >= 'A' && first <= 'z':
		// between A and z
		if bytes.Contains(firstMsg, []byte(" /")) {
			// very likely HTTP
			// TODO to be positive, look for "HTTP/1" (case-insensitively) on
			// the request line, i.e. before the first "\r\n". Remember that
			// the request line may not fit in the first packet (see above).
			protocol = "HTTP"
		}
	case 0x16 /*22*/ == first:
		// Because I don't always remember off the top of my head what the first byte is
		// http://blog.fourthbit.com/2014/12/23/traffic-analysis-of-an-ssl-slash-tls-session
		// https://tlseminar.github.io/first-few-milliseconds/
		// TODO I want to learn about ALPN
		protocol = "TLS"
	}

	switch protocol {
	case "":
		// Throw away the first bytes. They are all in the buffer already, so
		// this read cannot block or fail; the result is deliberately ignored.
		_, _ = conn.Read(make([]byte, n))
		fmt.Fprintf(conn, "\n\nWelcome to Sample Chat! You're not an HTTP client, assuming Telnet.\nYou must authenticate via email to participate\n\nEmail: ")
		wantsServerHello <- conn
	case "HTTP":
		demuxHttpClient <- conn
	default:
		defer conn.Close()
		fmt.Fprintf(conn, "\n\nNot yet supported. Try HTTP or Telnet\n\n")
	}
}
// testForHello decides who speaks first on a freshly accepted connection.
//
// It waits up to 250ms for the client to send something:
//   - if data arrives in time, the buffered connection goes to gotClientHello
//     for protocol detection;
//   - otherwise the client is assumed to be a Telnet user waiting for a
//     greeting: the welcome prompt is written, and the connection goes to
//     wantsServerHello once the client eventually sends its first bytes.
//
// If the client goes away before sending anything the connection is closed.
// The function itself returns after roughly 250ms; the goroutine waiting for
// the first bytes may outlive it.
func testForHello(netConn net.Conn) {
	ts := time.Now()
	fmt.Fprintf(os.Stdout, "[New Connection] (%s) welcome %s\n", ts, netConn.RemoteAddr().String())

	// m guards virgin, which is true until either side has "spoken" first.
	var m sync.Mutex
	virgin := true

	bufConn := newBufferedConn(netConn)
	go func() {
		// Cause first packet to be loaded into buffer
		_, err := bufConn.Peek(1)
		if nil != err {
			// The client hung up (or the read failed) before saying anything.
			// That is routine for a public port (port scans, health checks)
			// and must not take the whole server down.
			m.Lock()
			virgin = false // no point in greeting a dead connection
			m.Unlock()
			fmt.Fprintf(os.Stderr, "[New Connection] %s left before saying hello: %s\n",
				netConn.RemoteAddr().String(), err)
			netConn.Close()
			return
		}

		m.Lock()
		clientSpokeFirst := virgin
		virgin = false
		m.Unlock()

		// The sends happen outside the critical section so that a full
		// channel cannot stall the greeting path on the mutex.
		if clientSpokeFirst {
			gotClientHello <- bufConn
		} else {
			wantsServerHello <- bufConn
		}
	}()

	// Wait for a hello packet of some sort from the client
	// (obviously this wouldn't work in extremely high latency situations)
	time.Sleep(250 * time.Millisecond)

	// If we still haven't received data from the client
	// assume that the client must be expecting a welcome from us
	m.Lock()
	sendWelcome := virgin
	virgin = false
	m.Unlock()

	// Written after the unlock so the network write doesn't prolong the mutex.
	if sendWelcome {
		fmt.Fprintf(netConn,
			"\n\nWelcome to Sample Chat! You appear to be using Telnet (http is also available on this port)."+
				"\nYou must authenticate via email to participate\n\nEmail: ")
	}
}
// sendAuthCode generates a fresh authorization code and mails it to `to` by
// POSTing a form-encoded message to the mail API described by cnf
// (HTTP basic auth: user "api", password cnf.ApiKey).
//
// It returns the generated code, or an empty string and an error when the
// code could not be generated, the request could not be built or sent, or the
// response could not be read.
//
// NOTE: a response that doesn't look like a success (non-2xx status, or a body
// that isn't a JSON object) is only logged; the code is still returned.
func sendAuthCode(cnf ConfMailer, to string) (string, error) {
	code, err := genAuthCode()
	if nil != err {
		return "", err
	}

	// TODO use go text templates with HTML escaping
	text := "Your authorization code:\n\n" + code
	html := "Your authorization code:<br><br>" + code

	// https://stackoverflow.com/questions/24493116/how-to-send-a-post-request-in-go
	// https://stackoverflow.com/questions/16673766/basic-http-auth-in-go
	form := url.Values{}
	form.Add("from", cnf.From)
	form.Add("to", to)
	form.Add("subject", "Sample Chat Auth Code: "+code)
	form.Add("text", text)
	form.Add("html", html)

	req, err := http.NewRequest("POST", cnf.Url, strings.NewReader(form.Encode()))
	if nil != err {
		return "", err
	}
	//req.PostForm = form ??
	req.Header.Add("User-Agent", "golang http.Client - Sample Chat App Authenticator")
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth("api", cnf.ApiKey)

	// The zero-value http.Client never times out; without a limit a stalled
	// mail API would hang the caller (and the user's login) indefinitely.
	client := http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if nil != err {
		return "", err
	}
	defer resp.Body.Close()

	// Security XXX
	// we trust mailgun implicitly and this is just a demo
	// hence no DoS check on body size for now
	body, err := ioutil.ReadAll(resp.Body)
	if nil != err {
		return "", err
	}

	// An empty body is treated like any other non-JSON reply; indexing
	// body[0] without the length check would panic on it.
	looksLikeJSON := len(body) > 0 && '{' == body[0]
	if resp.StatusCode < 200 || resp.StatusCode >= 300 || !looksLikeJSON {
		fmt.Fprintf(os.Stdout, "[Mailgun] Uh-oh...\n[Maigun] Baby Brent says: %s\n", body)
	}
	return code, nil
}
// main loads the configuration, binds the TCP listener, registers the HTTP
// routes and then runs the event loop that owns all shared state
// (telnet users, pending auth requests and the chat history).
//
// Every accepted connection is first sniffed (testForHello / muxTcp) and then
// either handed to the telnet handler or fed to the HTTP server through
// myHttpServer. All mutation of telnetConns, myAuthReqs and myChatHist
// happens on this goroutine only, which is why no locks are used.
func main() {
	flag.Usage = usage
	port := flag.Uint("port", 0, "tcp telnet chat port")
	confname := flag.String("conf", "./config.yml", "yaml config file")
	flag.Parse()

	confstr, err := ioutil.ReadFile(*confname)
	fmt.Fprintf(os.Stdout, "-conf=%s\n", *confname)
	if nil != err {
		fmt.Fprintf(os.Stderr, "%s\nUsing defaults instead\n", err)
		confstr = []byte("{\"port\":" + strconv.Itoa(int(*port)) + "}")
	}
	err = yaml.Unmarshal(confstr, &config)
	if nil != err {
		// Say so: a malformed config silently replaced by defaults is very
		// hard to diagnose.
		fmt.Fprintf(os.Stderr, "Couldn't parse config %q: %s\nUsing defaults instead\n", *confname, err)
		config = Conf{}
	}
	if "" == config.RootPath {
		// TODO Maybe embed the public dir into the binary
		// (and provide a flag with path for override - like gitea)
		config.RootPath = "./public"
	}

	// The magical sorting hat
	virginConns = make(chan net.Conn, 128)

	// TCP & Authentication
	telnetConns := make(map[string]telnetUser)
	wantsServerHello = make(chan bufferedConn, 128)
	authTelnet = make(chan telnetUser, 128)
	if nil == cleanTelnet {
		// Unbuffered on purpose ("intentionally blocking"). A nil channel
		// would block every sender and receiver forever.
		cleanTelnet = make(chan telnetUser)
	}

	// HTTP & Authentication
	myAuthReqs := make(map[string]authReq)
	authReqs = make(chan authReq, 128)
	valAuthReqs = make(chan authReq, 128)
	delAuthReqs = make(chan authReq, 128)
	gotClientHello = make(chan bufferedConn, 128)
	demuxHttpClient = make(chan bufferedConn, 128)

	//myRooms["general"] = make(chan chatMsg, 128)
	// Note: I had considered dynamically select on channels for rooms.
	// https://stackoverflow.com/questions/19992334/how-to-listen-to-n-channels-dynamic-select-statement
	// I don't think that's actually the best approach, but I just wanted to save the link
	broadcastMsg = make(chan chatMsg, 128)
	myChatHist.msgs = make([]*chatMsg, 128)

	var addr string
	if 0 != int(*port) {
		addr = config.Addr + ":" + strconv.Itoa(int(*port))
	} else {
		addr = config.Addr + ":" + strconv.Itoa(int(config.Port))
	}

	// https://golang.org/pkg/net/#Conn
	sock, err := net.Listen("tcp", addr)
	if nil != err {
		fmt.Fprintf(os.Stderr, "Couldn't bind to TCP socket %q: %s\n", addr, err)
		os.Exit(2)
	}
	fmt.Println("Listening on", addr)

	go func() {
		for {
			conn, err := sock.Accept()
			if err != nil {
				// Not sure what kind of error this could be or how it could happen.
				// Could a connection abort or end before it's handled?
				fmt.Fprintf(os.Stderr, "Error accepting connection:\n%s\n", err)
				// conn is nil here and must not be queued (it would be
				// dereferenced in testForHello). Back off briefly so a
				// persistent error doesn't become a hot loop.
				time.Sleep(50 * time.Millisecond)
				continue
			}
			virginConns <- conn
		}
	}()

	// Learning by Example
	// https://github.com/emicklei/go-restful/blob/master/examples/restful-multi-containers.go
	// https://github.com/emicklei/go-restful/blob/master/examples/restful-basic-authentication.go
	// https://github.com/emicklei/go-restful/blob/master/examples/restful-serve-static.go
	// https://github.com/emicklei/go-restful/blob/master/examples/restful-pre-post-filters.go
	container := restful.NewContainer()

	wsStatic := new(restful.WebService)
	wsStatic.Path("/")
	wsStatic.Route(wsStatic.GET("/").To(serveStatic))
	wsStatic.Route(wsStatic.GET("/{subpath:*}").To(serveStatic))
	container.Add(wsStatic)

	cors := restful.CrossOriginResourceSharing{ExposeHeaders: []string{"Authorization"}, CookiesAllowed: false, Container: container}
	wsApi := new(restful.WebService)
	wsApi.Path("/api").Consumes(restful.MIME_JSON).Produces(restful.MIME_JSON).Filter(cors.Filter)
	wsApi.Route(wsApi.GET("/hello").To(serveHello))
	wsApi.Route(wsApi.POST("/sessions").To(requestAuth))
	wsApi.Route(wsApi.POST("/sessions/{cid}").To(issueToken))
	wsApi.Route(wsApi.GET("/rooms/{room}").Filter(requireToken).To(listMsgs))
	wsApi.Route(wsApi.POST("/rooms/{room}").Filter(requireToken).To(postMsg))
	container.Add(wsApi)

	server := &http.Server{
		Addr:    addr,
		Handler: container,
	}
	myHttpServer := newHttpServer(sock)
	go func() {
		if err := server.Serve(myHttpServer); nil != err {
			fmt.Fprintf(os.Stderr, "HTTP server stopped: %s\n", err)
		}
	}()

	// dropTelnet disconnects a telnet user and forgets the session.
	// It must only be called from the event loop below (it touches
	// telnetConns). The loop calls it directly instead of sending to
	// cleanTelnet: the loop is the receiver of that unbuffered channel, so a
	// send from here could never complete.
	dropTelnet := func(u telnetUser) {
		// we can safely ignore this error, if any
		u.bufConn.Close()
		cur, ok := telnetConns[u.email]
		if !ok || cur.newMsg != u.newMsg {
			// Either this session was already dropped (closing newMsg again
			// would panic) or a newer login owns the email (its entry must
			// stay). Note this also leaves newMsg open for a user that was
			// never registered via authTelnet.
			return
		}
		close(u.newMsg)
		delete(telnetConns, u.email)
	}

	// Main event loop handling most access to shared data
	for {
		select {
		case conn := <-virginConns:
			// This is short lived
			go testForHello(conn)

		case u := <-authTelnet:
			// allow to receive messages
			// (and be counted among the users)
			if old, ok := telnetConns[u.email]; ok {
				// same email logged in again: kick the previous session
				dropTelnet(old)
			}
			telnetConns[u.email] = u
			// is chan chan the right way to handle this?
			u.userCount <- len(telnetConns)
			joined := chatMsg{
				sender:     nil,
				Message:    fmt.Sprintf("<%s> joined #general\r\n", u.email),
				ReceivedAt: time.Now(),
				Channel:    "general",
				User:       "system",
			}
			// This loop is also the only consumer of broadcastMsg, so a
			// blocking send would deadlock once the buffer is full.
			select {
			case broadcastMsg <- joined:
			default:
				go func() { broadcastMsg <- joined }()
			}

		case ar := <-authReqs:
			myAuthReqs[ar.Cid] = ar

		case ar := <-valAuthReqs:
			// TODO In this case it's probably more conventional (and efficient) to
			// use struct with a mutex and the authReqs map than a chan chan
			av, ok := myAuthReqs[ar.Cid]
			//ar.Chan <- nil // TODO
			if ok {
				ar.Chan <- av
			} else {
				// sending empty object so that I can still send a copy
				// rather than a pointer above. Maybe not the right way
				// to do this, but it works for now.
				ar.Chan <- authReq{}
			}

		case ar := <-delAuthReqs:
			delete(myAuthReqs, ar.Cid)

		case bufConn := <-wantsServerHello:
			go handleTelnetConn(bufConn)

		case u := <-cleanTelnet:
			dropTelnet(u)

		case bufConn := <-gotClientHello:
			go muxTcp(bufConn)

		case bufConn := <-demuxHttpClient:
			// this will be Accept()ed immediately by the go-restful container
			// NOTE: we don't store these HTTP connections for broadcast
			// since we manage the session by HTTP Auth Bearer rather than TCP
			myHttpServer.chans <- bufConn

		case msg := <-broadcastMsg:
			// copy comes in, pointer gets saved
			// (msg is a fresh variable for every received message)
			myChatHist.msgs[myChatHist.i] = &msg
			myChatHist.i += 1
			if myChatHist.c < cap(myChatHist.msgs) {
				myChatHist.c += 1
			}
			myChatHist.i %= len(myChatHist.msgs)

			// print the system message (the "log")
			t := msg.ReceivedAt
			tf := "%d-%02d-%02d %02d:%02d:%02d (%s)"
			var sender string
			if nil != msg.sender {
				sender = msg.sender.RemoteAddr().String()
			} else {
				sender = "system"
			}
			// Tangential thought:
			// I wonder if we could use IP detection to get a Telnet client's tz
			// ... could probably make time for this in the authentication loop
			zone, _ := msg.ReceivedAt.Zone()
			fmt.Fprintf(os.Stdout, tf+" [%s] (%s): %s\r\n",
				t.Year(), t.Month(), t.Day(),
				t.Hour(), t.Minute(), t.Second(), zone,
				sender,
				msg.User, msg.Message)

			// The line is the same for every recipient, so format it once.
			line := fmt.Sprintf(tf+" [%s]: %s", t.Year(), t.Month(), t.Day(), t.Hour(),
				t.Minute(), t.Second(), zone, msg.User, msg.Message)
			for _, u := range telnetConns {
				// Don't echo back to the original client
				if msg.sender == u.bufConn {
					continue
				}
				select {
				case u.newMsg <- line:
					// all is well, client was ready to receive
				default:
					// Rate Limit: Reasonable poor man's DoS prevention (Part 2)
					// This client's send channel buffer is full.
					// It is consuming data too slowly. It may be malicious.
					// In the case that it's experiencing network issues,
					// well, these things happen when you're having network issues.
					// It can reconnect.
					// (Deleting from a map while ranging over it is safe in Go.)
					dropTelnet(u)
				}
			}
		}
	}
}