package main // Lot's of learning right out of the gate: // https://stackoverflow.com/questions/51472020/how-to-get-the-size-of-available-tcp-data import ( "bufio" "flag" "fmt" "io" "io/ioutil" "net" "os" "strconv" "sync" "time" "gopkg.in/yaml.v2" ) type Conf struct { Port uint `yaml:"port,omitempty"` Mailer struct { ApiKey string `yaml:"api_key,omitempty"` From string `yaml:"from,omitempty"` } } type bufferedConn struct { r *bufio.Reader rout io.Reader net.Conn } func newBufferedConn(c net.Conn) bufferedConn { return bufferedConn{bufio.NewReader(c), nil, c} } func (b bufferedConn) Peek(n int) ([]byte, error) { return b.r.Peek(n) } func (b bufferedConn) Buffered() (int) { return b.r.Buffered() } func (b bufferedConn) Read(p []byte) (int, error) { if b.rout != nil { return b.rout.Read(p) } return b.r.Read(p) } type myMsg struct { sender net.Conn bytes []byte receivedAt time.Time } var firstMsgs chan myMsg var myMsgs chan myMsg var myUnsortedConns map[net.Conn]bool var myRawConns map[net.Conn]bool var newConns chan net.Conn func usage() { fmt.Fprintf(os.Stderr, "\nusage: go run chatserver.go\n") flag.PrintDefaults(); fmt.Println() os.Exit(1) } func handleRaw(conn bufferedConn) { // TODO // What happens if this is being read from range // when it's being added here (data race)? // Should I use a channel here instead? // TODO see https://jameshfisher.com/2017/04/18/golang-tcp-server.html myRawConns[conn] = true // Handle all subsequent packets buf := make([]byte, 1024) for { fmt.Fprintf(os.Stdout, "[raw] Waiting for message...\n"); count, err := conn.Read(buf) if nil != err { if io.EOF != err { fmt.Fprintf(os.Stderr, "Non-EOF socket error: %s\n", err) } fmt.Fprintf(os.Stdout, "Ending socket\n") break } // Fun fact: if the buffer's current length (not capacity) is 0 // then the Read returns 0 without error if 0 == count { fmt.Fprintf(os.Stdout, "Weird") continue } fmt.Fprintf(os.Stdout, "Queing message...\n"); myMsgs <- myMsg{ receivedAt: time.Now(), sender: conn, bytes: buf[0:count], } } } func handleSorted(conn bufferedConn) { // at this piont we've already at least one byte via Peek() // so the first packet is available in the buffer n := conn.Buffered() firstMsg, err := conn.Peek(n) if nil != err { panic(err) } firstMsgs <- myMsg{ receivedAt: time.Now(), sender: conn, bytes: firstMsg, } // TODO // * TCP-CHAT // * HTTP // * TLS // Handle all subsequent packets buf := make([]byte, 1024) for { fmt.Fprintf(os.Stdout, "[sortable] Waiting for message...\n"); count, err := conn.Read(buf) if nil != err { if io.EOF != err { fmt.Fprintf(os.Stderr, "Non-EOF socket error: %s\n", err) } fmt.Fprintf(os.Stdout, "Ending socket\n") break } // Fun fact: if the buffer's current length (not capacity) is 0 // then the Read returns 0 without error if 0 == count { // fmt.Fprintf(os.Stdout, "Weird") continue } myMsgs <- myMsg{ receivedAt: time.Now(), sender: conn, bytes: buf[0:count], } } } // TODO https://github.com/polvi/sni func handleConnection(conn net.Conn) { fmt.Fprintf(os.Stdout, "Accepting socket\n") m := sync.Mutex{} virgin := true myUnsortedConns[conn] = true // Why don't these work? //buf := make([]byte, 0, 1024) //buf := []byte{} // But this does bufConn := newBufferedConn(conn) go func() { // Handle First Packet fmsg, err := bufConn.Peek(1) if nil != err { panic(err) } fmt.Fprintf(os.Stdout, "[First Byte] %s\n", fmsg) m.Lock(); if virgin { virgin = false go handleSorted(bufConn) } else { go handleRaw(bufConn) } m.Unlock(); }() time.Sleep(250 * 1000000) // If we still haven't received data from the client // assume that the client must be expecting a welcome from us m.Lock() if virgin { virgin = false // don't block for this // let it be handled after the unlock defer fmt.Fprintf(conn, "Welcome! This is an open relay chat server. There is no security yet.\n") } m.Unlock() } func main() { flag.Usage = usage port := flag.Uint("telnet-port", 0, "tcp telnet chat port") confname := flag.String("config", "./config.yml", "ymal config file") flag.Parse() var config Conf confstr, err := ioutil.ReadFile(*confname) fmt.Fprintf(os.Stdout, "-config=%s\n", *confname) if nil != err { fmt.Fprintf(os.Stderr, "%s\nUsing defaults instead\n", err) confstr = []byte("{\"port\":" + strconv.Itoa(int(*port)) + "}") } err = yaml.Unmarshal(confstr, &config) if nil != err { config = Conf{} } firstMsgs = make(chan myMsg, 128) myMsgs = make(chan myMsg, 128) newConns = make(chan net.Conn, 128) myRawConns = make(map[net.Conn]bool) myUnsortedConns = make(map[net.Conn]bool) var addr string if 0 != int(*port) { addr = ":" + strconv.Itoa(int(*port)) } else { addr = ":" + strconv.Itoa(int(config.Port)) } // https://golang.org/pkg/net/#Conn sock, err := net.Listen("tcp", addr) if nil != err { fmt.Fprintf(os.Stderr, "Couldn't bind to TCP socket %q: %s\n", addr, err) os.Exit(2) } fmt.Println("Listening on", addr); go func() { for { conn, err := sock.Accept() if err != nil { // Not sure what kind of error this could be or how it could happen. // Could a connection abort or end before it's handled? fmt.Fprintf(os.Stderr, "Error accepting connection:\n%s\n", err) } newConns <- conn } }() for { select { case conn := <- newConns: ts := time.Now() fmt.Fprintf(os.Stdout, "[Handle New Connection] [Timestamp] %s\n", ts) go handleConnection(conn) case msg := <- myMsgs: ts, err := msg.receivedAt.MarshalJSON() if nil != err { fmt.Fprintf(os.Stderr, "[Error] %s\n", err) } fmt.Fprintf(os.Stdout, "[Timestamp] %s\n", ts) fmt.Fprintf(os.Stdout, "[Remote] %s\n", msg.sender.RemoteAddr().String()) fmt.Fprintf(os.Stdout, "[Message] %s\n", msg.bytes); for conn, _ := range myRawConns { if msg.sender == conn { continue } // backlogged connections could prevent a next write, // so this should be refactored into a goroutine // And what to do about slow clients that get behind (or DoS)? // SetDeadTime and Disconnect them? conn.Write(msg.bytes) } case msg := <- firstMsgs: fmt.Fprintf(os.Stdout, "f [First Message]\n") ts, err := msg.receivedAt.MarshalJSON() if nil != err { fmt.Fprintf(os.Stderr, "f [Error] %s\n", err) } fmt.Fprintf(os.Stdout, "f [Timestamp] %s\n", ts) fmt.Fprintf(os.Stdout, "f [Remote] %s\n", msg.sender.RemoteAddr().String()) fmt.Fprintf(os.Stdout, "f [Message] %s\n", msg.bytes); } } }