// An example chat server in Go.
package main
// TODO learn about chan chan's
// http://marcio.io/2015/07/handling-1-million-requests-per-minute-with-golang/
import (
	"bufio"
	"crypto/rand"
	"encoding/base64"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"gopkg.in/yaml.v2"
)
// Conf is the top-level server configuration that main decodes from the
// YAML config file.
//
// The mailer settings are a separate named type (ConfMailer) instead of an
// anonymous nested struct so that they can be handed to sendAuthCode on
// their own.
// TODO: Learn if passing nested structs is desirable?
type Conf struct {
	// Port is the TCP port to listen on; main uses it when the
	// -telnet-port flag is left at 0.
	Port uint `yaml:"port,omitempty"`
	// Mailer configures the HTTP mail API used to deliver auth codes.
	// NOTE(review): no yaml tag here, so the key name is whatever the YAML
	// library derives from the field name (presumably "mailer") — confirm.
	Mailer ConfMailer
}

// ConfMailer holds the settings sendAuthCode needs to talk to the mail API.
type ConfMailer struct {
	// Url is the endpoint the auth-code email is POSTed to.
	Url string `yaml:"url,omitempty"`
	// ApiKey is sent as the HTTP basic-auth password (the user is "api").
	ApiKey string `yaml:"api_key,omitempty"`
	// From is the sender address placed in the "from" form field.
	From string `yaml:"from,omitempty"`
}
// bufferedConn pairs a net.Conn with a bufio.Reader so that pending bytes
// can be inspected without consuming them, which net.Conn cannot do natively.
// https://stackoverflow.com/questions/51472020/how-to-get-the-size-of-available-tcp-data
type bufferedConn struct {
	r    *bufio.Reader // buffered reader over Conn; Peek and Buffered use it
	rout io.Reader     // optional alternate source; Read prefers it when non-nil
	net.Conn
}

// newBufferedConn wraps c in a bufferedConn whose reads go through a fresh
// bufio.Reader. The alternate reader (rout) is left unset.
func newBufferedConn(c net.Conn) bufferedConn {
	return bufferedConn{
		r:    bufio.NewReader(c),
		rout: nil,
		Conn: c,
	}
}

// Peek returns the next n bytes without advancing the reader, as
// bufio.Reader.Peek does.
func (bc bufferedConn) Peek(n int) ([]byte, error) {
	peeked, err := bc.r.Peek(n)
	return peeked, err
}

// Buffered reports how many bytes can currently be read from the buffer.
func (bc bufferedConn) Buffered() int {
	return bc.r.Buffered()
}

// Read reads from the alternate reader when one is set, otherwise from the
// buffered reader, shadowing the embedded net.Conn's Read.
func (bc bufferedConn) Read(p []byte) (int, error) {
	var src io.Reader = bc.r
	if bc.rout != nil {
		src = bc.rout
	}
	return src.Read(p)
}
// Just making these all globals right now
// because... I can clean it up later

// myMsg is one chunk of bytes read from a client connection, queued for the
// main loop to log and broadcast.
type myMsg struct {
	sender     net.Conn  // connection the bytes were read from
	bytes      []byte    // raw payload exactly as read
	receivedAt time.Time // when the payload was read
	channel    string    // chat channel name; the visible code always uses "general"
}
var firstMsgs chan myMsg
6 years ago
var myChans map[string](chan myMsg)
var myMsgs chan myMsg
var myUnsortedConns map[net.Conn]bool
var myRawConns map[net.Conn]bool
var newConns chan net.Conn
var newTcpChat chan bufferedConn
var delTcpChat chan bufferedConn
var newHttpChat chan bufferedConn
var delHttpChat chan bufferedConn
// usage is installed as flag.Usage: it prints a short usage line and the
// flag defaults, emits a blank line on stdout, and exits with status 1.
func usage() {
	fmt.Fprint(os.Stderr, "\nusage: go run chatserver.go\n")
	flag.PrintDefaults()
	fmt.Fprintln(os.Stdout)
	os.Exit(1)
}
// https://blog.questionable.services/article/generating-secure-random-numbers-crypto-rand/
// genAuthCode returns a fresh authentication code: 12 bytes from crypto/rand
// encoded with the URL-safe base64 alphabet. It returns an empty string and
// the error if the random source fails.
func genAuthCode() (string, error) {
	const numBytes = 12

	raw := make([]byte, numBytes)
	// rand.Read reports a nil error only when it filled the whole slice.
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}

	code := base64.URLEncoding.EncodeToString(raw)
	return code, nil
}
func handleRaw(conn bufferedConn) {
// TODO
// What happens if this is being read from range
// when it's being added here (data race)?
// Should I use a channel here instead?
// TODO see https://jameshfisher.com/2017/04/18/golang-tcp-server.html
6 years ago
var email string
var code string
var authn bool
// Handle all subsequent packets
6 years ago
buffer := make([]byte, 1024)
for {
fmt.Fprintf(os.Stdout, "[raw] Waiting for message...\n");
6 years ago
count, err := conn.Read(buffer)
if nil != err {
if io.EOF != err {
fmt.Fprintf(os.Stderr, "Non-EOF socket error: %s\n", err)
}
fmt.Fprintf(os.Stdout, "Ending socket\n")
6 years ago
// TODO put this in a channel to prevent data races
delTcpChat <- conn
break
}
6 years ago
buf := buffer[:count]
// Fun fact: if the buffer's current length (not capacity) is 0
// then the Read returns 0 without error
if 0 == count {
fmt.Fprintf(os.Stdout, "Weird")
6 years ago
break
}
if !authn {
if "" == email {
fmt.Fprintf(os.Stdout, "buf{%s}\n", buf[:count])
// TODO use safer email testing
email = strings.TrimSpace(string(buf[:count]))
emailParts := strings.Split(email, "@")
if 2 != len(emailParts) {
fmt.Fprintf(conn, "Email: ")
continue
}
fmt.Fprintf(os.Stdout, "email: '%v'\n", []byte(email))
code, err = sendAuthCode(config.Mailer, strings.TrimSpace(email))
if nil != err {
// TODO handle better
panic(err)
}
fmt.Fprintf(conn, "Auth Code: ")
continue
}
if code != strings.TrimSpace(string(buf[:count])) {
fmt.Fprintf(conn, "Incorrect Code\nAuth Code: ")
} else {
authn = true
fmt.Fprintf(conn, "Welcome to #general! (TODO `/help' for list of commands)\n")
// TODO number of users
//fmt.Fprintf(conn, "Welcome to #general! TODO `/list' to see channels. `/join chname' to switch.\n")
}
continue
}
6 years ago
fmt.Fprintf(os.Stdout, "Queing message...\n");
//myChans["general"] <- myMsg{
myMsgs <- myMsg{
receivedAt: time.Now(),
sender: conn,
bytes: buf[0:count],
6 years ago
channel: "general",
}
}
}
func handleSorted(conn bufferedConn) {
// at this piont we've already at least one byte via Peek()
// so the first packet is available in the buffer
n := conn.Buffered()
firstMsg, err := conn.Peek(n)
if nil != err {
panic(err)
}
firstMsgs <- myMsg{
receivedAt: time.Now(),
sender: conn,
bytes: firstMsg,
6 years ago
channel: "general",
}
// TODO
// * TCP-CHAT
// * HTTP
// * TLS
// Handle all subsequent packets
buf := make([]byte, 1024)
for {
fmt.Fprintf(os.Stdout, "[sortable] Waiting for message...\n");
count, err := conn.Read(buf)
if nil != err {
if io.EOF != err {
fmt.Fprintf(os.Stderr, "Non-EOF socket error: %s\n", err)
}
fmt.Fprintf(os.Stdout, "Ending socket\n")
break
}
// Fun fact: if the buffer's current length (not capacity) is 0
// then the Read returns 0 without error
if 0 == count {
// fmt.Fprintf(os.Stdout, "Weird")
continue
}
//myChans["general"] <- myMsg{
myMsgs <- myMsg{
receivedAt: time.Now(),
sender: conn,
bytes: buf[0:count],
6 years ago
channel: "general",
}
}
}
// TODO https://github.com/polvi/sni
func handleConnection(netConn net.Conn) {
fmt.Fprintf(os.Stdout, "Accepting socket\n")
m := sync.Mutex{}
virgin := true
// Why don't these work?
//buf := make([]byte, 0, 1024)
//buf := []byte{}
// But this does
6 years ago
bufConn := newBufferedConn(netConn)
myUnsortedConns[bufConn] = true
go func() {
// Handle First Packet
fmsg, err := bufConn.Peek(1)
if nil != err {
panic(err)
}
fmt.Fprintf(os.Stdout, "[First Byte] %s\n", fmsg)
m.Lock();
if virgin {
virgin = false
newHttpChat <- bufConn
} else {
6 years ago
// TODO probably needs to go into a channel
newTcpChat <- bufConn
}
m.Unlock();
}()
time.Sleep(250 * 1000000)
// If we still haven't received data from the client
// assume that the client must be expecting a welcome from us
m.Lock()
if virgin {
virgin = false
// don't block for this
// let it be handled after the unlock
6 years ago
defer fmt.Fprintf(netConn, "Welcome to Sample Chat! You appear to be using Telnet.\nYou must authenticate via email to participate\nEmail: ")
}
m.Unlock()
}
6 years ago
// sendAuthCode generates a new auth code and mails it to the address `to`
// by POSTing a form to the mail API described by cnf (HTTP basic auth with
// user "api" and cnf.ApiKey as the password).
//
// It returns the code on success. It returns an empty code and an error if
// the code cannot be generated, the request cannot be built or sent, the
// response cannot be read, or the API answers with a non-2xx status.
func sendAuthCode(cnf ConfMailer, to string) (string, error) {
	// Bound the whole request so a stalled mail API cannot hang the caller.
	const mailerTimeout = 30 * time.Second

	code, err := genAuthCode()
	if err != nil {
		return "", err
	}

	// TODO use go text templates with HTML escaping
	text := "Your authorization code:\n\n" + code
	html := "Your authorization code:<br><br>" + code

	// https://stackoverflow.com/questions/24493116/how-to-send-a-post-request-in-go
	// https://stackoverflow.com/questions/16673766/basic-http-auth-in-go
	client := http.Client{Timeout: mailerTimeout}

	form := url.Values{}
	form.Add("from", cnf.From)
	form.Add("to", to)
	form.Add("subject", "Sample Chat Auth Code: "+code)
	form.Add("text", text)
	form.Add("html", html)

	req, err := http.NewRequest("POST", cnf.Url, strings.NewReader(form.Encode()))
	if err != nil {
		return "", err
	}
	//req.PostForm = form
	req.Header.Add("User-Agent", "golang http.Client - Sample Chat App Authenticator")
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth("api", cnf.ApiKey)

	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	fmt.Fprintf(os.Stdout, "Here's what Mailgun had to say about the event: %s\n", body)

	// client.Do only fails on transport errors; a rejected request still
	// comes back with a nil error, so the status must be checked explicitly.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("mailer responded with status %s", resp.Status)
	}

	return code, nil
}
var config Conf
func main() {
flag.Usage = usage
port := flag.Uint("telnet-port", 0, "tcp telnet chat port")
6 years ago
confname := flag.String("conf", "./config.yml", "yaml config file")
flag.Parse()
confstr, err := ioutil.ReadFile(*confname)
6 years ago
fmt.Fprintf(os.Stdout, "-conf=%s\n", *confname)
if nil != err {
fmt.Fprintf(os.Stderr, "%s\nUsing defaults instead\n", err)
confstr = []byte("{\"port\":" + strconv.Itoa(int(*port)) + "}")
}
err = yaml.Unmarshal(confstr, &config)
if nil != err {
config = Conf{}
}
firstMsgs = make(chan myMsg, 128)
6 years ago
myChans = make(map[string](chan myMsg))
newConns = make(chan net.Conn, 128)
newTcpChat = make(chan bufferedConn, 128)
newHttpChat = make(chan bufferedConn, 128)
myRawConns = make(map[net.Conn]bool)
myUnsortedConns = make(map[net.Conn]bool)
6 years ago
// TODO dynamically select on channels?
// https://stackoverflow.com/questions/19992334/how-to-listen-to-n-channels-dynamic-select-statement
//myChans["general"] = make(chan myMsg, 128)
myMsgs = make(chan myMsg, 128)
6 years ago
var addr string
if 0 != int(*port) {
addr = ":" + strconv.Itoa(int(*port))
} else {
addr = ":" + strconv.Itoa(int(config.Port))
}
// https://golang.org/pkg/net/#Conn
sock, err := net.Listen("tcp", addr)
if nil != err {
fmt.Fprintf(os.Stderr, "Couldn't bind to TCP socket %q: %s\n", addr, err)
os.Exit(2)
}
fmt.Println("Listening on", addr);
go func() {
for {
conn, err := sock.Accept()
if err != nil {
// Not sure what kind of error this could be or how it could happen.
// Could a connection abort or end before it's handled?
fmt.Fprintf(os.Stderr, "Error accepting connection:\n%s\n", err)
}
newConns <- conn
}
}()
for {
select {
case conn := <- newConns:
ts := time.Now()
fmt.Fprintf(os.Stdout, "[Handle New Connection] [Timestamp] %s\n", ts)
// This is short lived
go handleConnection(conn)
case bufConn := <- newTcpChat:
myRawConns[bufConn] = true
go handleRaw(bufConn)
case bufConn := <- newHttpChat:
go handleSorted(bufConn)
case bufConn := <- delHttpChat:
bufConn.Close();
delete(myRawConns, bufConn)
//case msg := <- myChans["general"]:
//delete(myChans["general"], bufConn)
case msg := <- myMsgs:
ts, err := msg.receivedAt.MarshalJSON()
if nil != err {
fmt.Fprintf(os.Stderr, "[Error] %s\n", err)
}
fmt.Fprintf(os.Stdout, "[Timestamp] %s\n", ts)
fmt.Fprintf(os.Stdout, "[Remote] %s\n", msg.sender.RemoteAddr().String())
fmt.Fprintf(os.Stdout, "[Message] %s\n", msg.bytes);
for conn, _ := range myRawConns {
if msg.sender == conn {
continue
}
// backlogged connections could prevent a next write,
// so this should be refactored into a goroutine
// And what to do about slow clients that get behind (or DoS)?
// SetDeadTime and Disconnect them?
conn.Write(msg.bytes)
}
case msg := <- firstMsgs:
fmt.Fprintf(os.Stdout, "f [First Message]\n")
ts, err := msg.receivedAt.MarshalJSON()
if nil != err {
fmt.Fprintf(os.Stderr, "f [Error] %s\n", err)
}
fmt.Fprintf(os.Stdout, "f [Timestamp] %s\n", ts)
fmt.Fprintf(os.Stdout, "f [Remote] %s\n", msg.sender.RemoteAddr().String())
fmt.Fprintf(os.Stdout, "f [Message] %s\n", msg.bytes);
}
}
}