Unit test runs to completion, but fails

This commit is contained in:
Neale Pickett 2015-02-10 13:38:21 -07:00
parent 0e29d159f1
commit 29a9268df8
6 changed files with 254 additions and 212 deletions

View File

@ -3,6 +3,7 @@ package main
import ( import (
"strconv" "strconv"
"strings" "strings"
"fmt"
) )
type Message struct { type Message struct {
@ -14,7 +15,7 @@ type Message struct {
Text string Text string
} }
func Parse(v string) (Message, error) { func NewMessage(v string) (Message, error) {
var m Message var m Message
var parts []string var parts []string
var lhs string var lhs string
@ -37,7 +38,7 @@ func Parse(v string) (Message, error) {
m.FullSender = parts[0][1:] m.FullSender = parts[0][1:]
parts = parts[1:] parts = parts[1:]
n, u, _ := nuhost(m.FullSender) n, u, _ := SplitTarget(m.FullSender)
if u != "" { if u != "" {
m.Sender = n m.Sender = n
} }
@ -47,7 +48,7 @@ func Parse(v string) (Message, error) {
switch m.Command { switch m.Command {
case "PRIVMSG", "NOTICE": case "PRIVMSG", "NOTICE":
switch { switch {
case isChannel(parts[1]): case IsChannel(parts[1]):
m.Forum = parts[1] m.Forum = parts[1]
case m.FullSender == ".": case m.FullSender == ".":
m.Forum = parts[1] m.Forum = parts[1]
@ -101,3 +102,32 @@ func (m Message) String() string {
return fmt.Sprintf("%s %s %s %s %s :%s", m.FullSender, m.Command, m.Sender, m.Forum, args, m.Text) return fmt.Sprintf("%s %s %s %s %s :%s", m.FullSender, m.Command, m.Sender, m.Forum, args, m.Text)
} }
func SplitTarget(s string) (string, string, string) {
var parts []string
parts = strings.SplitN(s, "!", 2)
if len(parts) == 1 {
return s, "", ""
}
nick := parts[0]
parts = strings.SplitN(parts[1], "@", 2)
if len(parts) == 1 {
return s, "", ""
}
return nick, parts[0], parts[1]
}
func IsChannel(s string) bool {
if s == "" {
return false
}
switch s[0] {
case '#', '+', '!', '&':
return true
default:
return false
}
}

View File

@ -2,12 +2,14 @@ package main
import ( import (
"bufio" "bufio"
"crypto/tls"
"fmt" "fmt"
"github.com/nealey/spongy/logfile" "github.com/nealey/spongy/logfile"
"io" "io"
"log" "log"
"net" "net"
"os" "os"
"os/user"
"path" "path"
"strings" "strings"
"time" "time"
@ -42,6 +44,8 @@ func ReadLines(fn string) ([]string, error) {
type Network struct { type Network struct {
running bool running bool
Nick string
basePath string basePath string
conn io.ReadWriteCloser conn io.ReadWriteCloser
@ -50,41 +54,27 @@ type Network struct {
outq chan string outq chan string
} }
func NewNetwork(basePath string) (*Network, error) { func NewNetwork(basePath string) *Network {
nicks, err := ReadLines(path.Join(basePath, "nicks")) nw := Network{
if err != nil {
return nil, err
}
gecoses, err := ReadLines(path.Join(basePath, "gecos"))
if err != nil {
return nil, err
}
return &Network{
running: true, running: true,
basePath: basePath, basePath: basePath,
servers: servers,
nicks: nicks,
gecos: gecoses[0],
logq: make(chan Message, 20), logq: make(chan Message, 20),
}, err
go n.LogLoop()
} }
func (n *Network) Close() { go nw.LogLoop()
n.conn.Close()
close(n.logq) return &nw
close(n.inq)
close(n.outq)
} }
func (n *Network) WatchOutqDirectory() { func (nw *Network) Close() {
outqDirname := path.Join(n.basePath, "outq") nw.conn.Close()
close(nw.logq)
close(nw.inq)
close(nw.outq)
}
func (nw *Network) WatchOutqDirectory() {
outqDirname := path.Join(nw.basePath, "outq")
dir, err := os.Open(outqDirname) dir, err := os.Open(outqDirname)
if err != nil { if err != nil {
@ -93,18 +83,18 @@ func (n *Network) WatchOutqDirectory() {
defer dir.Close() defer dir.Close()
// XXX: Do this with fsnotify // XXX: Do this with fsnotify
for n.running { for nw.running {
entities, _ := dir.Readdirnames(0) entities, _ := dir.Readdirnames(0)
for _, fn := range entities { for _, fn := range entities {
pathname := path.Join(outqDirname, fn) pathname := path.Join(outqDirname, fn)
n.HandleInfile(pathname) nw.HandleInfile(pathname)
} }
_, _ = dir.Seek(0, 0) _, _ = dir.Seek(0, 0)
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
} }
func (n *Network) HandleInfile(fn string) { func (nw *Network) HandleInfile(fn string) {
f, err := os.Open(fn) f, err := os.Open(fn)
if err != nil { if err != nil {
return return
@ -119,91 +109,144 @@ func (n *Network) HandleInfile(fn string) {
inf := bufio.NewScanner(f) inf := bufio.NewScanner(f)
for inf.Scan() { for inf.Scan() {
txt := inf.Text() txt := inf.Text()
n.outq <- txt nw.outq <- txt
} }
} }
func (n *Network) LogLoop() { func (nw *Network) LogLoop() {
logf := logfile.NewLogFile(int(maxlogsize)) logf := logfile.NewLogfile(int(maxlogsize))
defer logf.Close() defer logf.Close()
for m := range logq { for m := range nw.logq {
logf.Log(m.String()) logf.Log(m.String())
} }
} }
func (n *Network) ServerWriteLoop() { func (nw *Network) ServerWriteLoop() {
for v := range n.outq { for v := range nw.outq {
m, _ := Parse(v) m, _ := NewMessage(v)
n.logq <- m nw.logq <- m
fmt.Fprintln(n.conn, v) fmt.Fprintln(nw.conn, v)
} }
} }
func (n *Network) ServerReadLoop() { func (nw *Network) ServerReadLoop() {
scanner := bufio.NewScanner(conn) scanner := bufio.NewScanner(nw.conn)
for scanner.Scan() { for scanner.Scan() {
n.inq <- scanner.Text() nw.inq <- scanner.Text()
} }
close(n.inq) close(nw.inq)
} }
func (n *Network) MessageDispatch() { func (nw *Network) NextNick() {
for line := n.inq { nicks, err := ReadLines(path.Join(nw.basePath, "nick"))
if err != nil {
log.Print(err)
return
}
// Make up some alternates if they weren't provided
if len(nicks) == 1 {
nicks = append(nicks, nicks[0] + "_")
nicks = append(nicks, nicks[0] + "__")
nicks = append(nicks, nicks[0] + "___")
}
nextidx := 0
for idx, n := range nicks {
if n == nw.Nick {
nextidx = idx + 1
}
}
nw.Nick = nicks[nextidx % len(nicks)]
nw.outq <- "NICK " + nw.Nick
}
func (nw *Network) JoinChannels() {
chans, err := ReadLines(path.Join(nw.basePath, "channels"))
if err != nil {
log.Print(err)
return
}
for _, ch := range chans {
nw.outq <- "JOIN " + ch
}
}
func (nw *Network) MessageDispatch() {
for line := range nw.inq {
m, err := NewMessage(line) m, err := NewMessage(line)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
continue continue
} }
n.logq <- m nw.logq <- m
// XXX: Add in a handler subprocess call // XXX: Add in a handler subprocess call
switch m.Command { switch m.Command {
case "PING": case "PING":
n.outq <- "PONG: " + m.Text nw.outq <- "PONG: " + m.Text
case "001":
nw.JoinChannels()
case "433": case "433":
nick = nick + "_" nw.NextNick()
outq <- fmt.Sprintf("NICK %s", nick)
} }
} }
} }
func (n *Network) ConnectToServer(server string) bool { func (nw *Network) ConnectToServer(server string) bool {
var err error var err error
var name string
names, err := ReadLines(path.Join(nw.basePath, "name"))
if err != nil {
me, err := user.Current()
if err != nil {
log.Fatal(err)
}
name = me.Name
} else {
name = names[0]
}
switch (server[0]) { switch (server[0]) {
case '|': case '|':
parts := strings.Split(server[1:], " ") parts := strings.Split(server[1:], " ")
n.conn, err = StartStdioProcess(parts[0], parts[1:]) nw.conn, err = StartStdioProcess(parts[0], parts[1:])
case '^': case '^':
n.conn, err = net.Dial("tcp", server[1:]) nw.conn, err = net.Dial("tcp", server[1:])
default: default:
log.Print("Not validating server certificate!") log.Print("Not validating server certificate!")
config := &tls.Config{ config := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
n.conn, err = tls.Dial("tcp", host, config) nw.conn, err = tls.Dial("tcp", server, config)
} }
if err != nil { if err != nil {
log.Print(err) log.Print(err)
time.sleep(2 * time.Second) time.Sleep(2 * time.Second)
return false return false
} }
fmt.Fprintf(nw.conn, "USER g g g :%s\n", name)
nw.NextNick()
return true return true
} }
func (n *Network) Connect(){ func (nw *Network) Connect(){
serverIndex := 0 serverIndex := 0
for n.running { for nw.running {
servers, err := ReadLines(path.Join(basePath, "servers")) servers, err := ReadLines(path.Join(nw.basePath, "servers"))
if err != nil { if err != nil {
serverIndex := 0 serverIndex = 0
log.Print(err) log.Print(err)
time.sleep(8) time.Sleep(8)
continue continue
} }
@ -213,18 +256,18 @@ func (n *Network) Connect(){
server := servers[serverIndex] server := servers[serverIndex]
serverIndex += 1 serverIndex += 1
if ! n.ConnectToServer(server) { if ! nw.ConnectToServer(server) {
continue continue
} }
n.inq = make(chan string, 20) nw.inq = make(chan string, 20)
n.outq = make(chan string, 20) nw.outq = make(chan string, 20)
go n.ServerWriteLoop() go nw.ServerWriteLoop()
go n.MessageDispatch() go nw.MessageDispatch()
n.ServerReadLoop() nw.ServerReadLoop()
close(n.outq) close(nw.outq)
} }
} }

9
spongy/network_test.go Normal file
View File

@ -0,0 +1,9 @@
package main
import (
"testing"
)
func testConnect(t *testing.T) {
return
}

View File

@ -3,20 +3,25 @@ package main
import ( import (
"io" "io"
"os/exec" "os/exec"
"log"
) )
type ReadWriteCloserWrapper { type ReadWriteCloserWrapper struct {
Reader io.ReadCloser Reader io.ReadCloser
Writer io.WriteCloser Writer io.WriteCloser
cmd *exec.Cmd
} }
def NewReadWriteCloseWrapper(r io.ReadCloser, w io.WriteCloser) *ReadWriteCloserWrapper { func NewReadWriteCloseWrapper(r io.ReadCloser, w io.WriteCloser) *ReadWriteCloserWrapper {
return &ReadWriteCloserWrapper{r, w} return &ReadWriteCloserWrapper{r, w, nil}
} }
def (w *ReadWriteCloserWrapper) Close() (error) { func (w *ReadWriteCloserWrapper) Close() (error) {
err1 := w.Reader.Close() err1 := w.Reader.Close()
err2 := w.Writer.Close() err2 := w.Writer.Close()
if w.cmd != nil{
w.cmd.Wait()
}
switch { switch {
case err1 != nil: case err1 != nil:
@ -27,36 +32,43 @@ def (w *ReadWriteCloserWrapper) Close() (error) {
return nil return nil
} }
def (w *ReadWriteCloserWrapper) Read(p []byte) (n int, err error) { func (w *ReadWriteCloserWrapper) Read(p []byte) (n int, err error) {
n, err := w.Reader.Read(p) n, err = w.Reader.Read(p)
return return
} }
def (w *ReadWriteCloserWrapper) Write(p []byte) (n int, err error) { func (w *ReadWriteCloserWrapper) Write(p []byte) (n int, err error) {
n, err := w.Writer.Write(p) n, err = w.Writer.Write(p)
return return
} }
def StartStdioProcess(name string, args []string) (*ReadWriteCloserWrapper, error) { func StartStdioProcess(name string, args []string) (w *ReadWriteCloserWrapper, err error) {
var w ReadWriteCloserWrapper w = new(ReadWriteCloserWrapper)
cmd := exec.Command(name, args...) cmd := exec.Command(name, args...)
w.Reader, err := cmd.StdoutPipe() if cmd == nil {
log.Fatalf("Can't run command: %v %v", name, args)
}
w.Reader, err = cmd.StdoutPipe()
if err != nil { if err != nil {
return nil, err return nil, err
} }
w.Writer, err := cmd.StdinPipe() w.Writer, err = cmd.StdinPipe()
if err != nil { if err != nil {
w.Reader.Close()
return nil, err return nil, err
} }
if err = cmd.Start(); err != nil { if err = cmd.Start(); err != nil {
w.Reader.Close()
w.Writer.Close()
return nil, err return nil, err
} }
go cmd.Wait() w.cmd = cmd
return &w, nil return
} }

View File

@ -0,0 +1,37 @@
package main
import (
"bytes"
"testing"
)
func TestRWCWCat(t *testing.T) {
proc, err := StartStdioProcess("cat", []string{})
if err != nil {
t.Error(err)
}
out := []byte("Hello, World\n")
p := make([]byte, 0, 50)
n, err := proc.Write(out)
if err != nil {
t.Error(err)
}
if n != len(out) {
t.Errorf("Wrong number of bytes in Write: wanted %d, got %d", len(out), n)
}
n, err = proc.Read(p)
if err != nil {
t.Error(err)
}
if n != len(out) {
t.Errorf("Wrong number of bytes in Read: wanted %d, got %d", len(out), n)
}
if 0 != bytes.Compare(p, out) {
t.Errorf("Mangled read")
}
proc.Close()
}

View File

@ -1,110 +1,65 @@
package main package main
import ( import (
"bufio"
"crypto/tls"
"flag" "flag"
"fmt" "fmt"
"github.com/nealey/spongy/logfile"
"log" "log"
"net"
"os" "os"
"strings" "path"
"time" "time"
) )
var running bool = true var running bool = true
var nick string
var gecos string
var maxlogsize uint var maxlogsize uint
var logq chan Message
func isChannel(s string) bool { func exists(filename string) bool {
if s == "" { _, err := os.Stat(filename); if err != nil {
return false return false
} }
switch s[0] {
case '#', '&', '!', '+', '.', '-':
return true return true
default:
return false
}
} }
func (m Message) String() string { func runsvdir(dirname string) {
args := strings.Join(m.Args, " ") services := make(map[string]*Network)
return fmt.Sprintf("%s %s %s %s %s :%s", m.FullSender, m.Command, m.Sender, m.Forum, args, m.Text)
}
func logLoop() { dir, err := os.Open(dirname)
logf := logfile.NewLogfile(int(maxlogsize))
defer logf.Close()
for m := range logq {
logf.Log(m.String())
}
}
func nuhost(s string) (string, string, string) {
var parts []string
parts = strings.SplitN(s, "!", 2)
if len(parts) == 1 {
return s, "", ""
}
n := parts[0]
parts = strings.SplitN(parts[1], "@", 2)
if len(parts) == 1 {
return s, "", ""
}
return n, parts[0], parts[1]
}
func dispatch(outq chan<- string, m Message) {
logq <- m
switch m.Command {
case "PING":
outq <- "PONG :" + m.Text
case "433":
nick = nick + "_"
outq <- fmt.Sprintf("NICK %s", nick)
}
}
func handleInfile(path string, outq chan<- string) {
f, err := os.Open(path)
if err != nil { if err != nil {
return log.Fatal(err)
}
defer f.Close()
os.Remove(path)
inf := bufio.NewScanner(f)
for inf.Scan() {
txt := inf.Text()
outq <- txt
}
} }
defer dir.Close()
func monitorDirectory(dirname string, dir *os.File, outq chan<- string) {
latest := time.Unix(0, 0)
for running { for running {
fi, err := dir.Stat() dn, err := dir.Readdirnames(0); if err != nil {
if err != nil { log.Fatal(err)
break
} }
current := fi.ModTime()
if current.After(latest) { found := make(map[string]bool)
latest = current
dn, _ := dir.Readdirnames(0)
for _, fn := range dn { for _, fn := range dn {
path := dirname + string(os.PathSeparator) + fn fpath := path.Join(dirname, fn)
handleInfile(path, outq) if exists(path.Join(fpath, "down")) {
continue
} }
if _, ok := services[fpath]; ! ok {
if ! exists(path.Join(fpath, "servers")) {
continue
}
newnet := NewNetwork(fpath)
services[fpath] = newnet
go newnet.Connect()
}
found[fpath] = true
}
// If anything vanished, disconnect it
for fpath, nw := range services {
if _, ok := found[fpath]; ! ok {
nw.Close()
}
}
_, _ = dir.Seek(0, 0) _, _ = dir.Seek(0, 0)
} time.Sleep(20 * time.Second)
time.Sleep(500 * time.Millisecond)
} }
} }
@ -114,51 +69,7 @@ func usage() {
} }
func main() { func main() {
dotls := flag.Bool("notls", true, "Disable TLS security")
outqdir := flag.String("outq", "outq", "Output queue directory")
flag.UintVar(&maxlogsize, "logsize", 1000, "Log entries before rotating") flag.UintVar(&maxlogsize, "logsize", 1000, "Log entries before rotating")
flag.StringVar(&gecos, "gecos", "Bob The Merry Slug", "Gecos entry (full name)")
flag.Parse() flag.Parse()
if flag.NArg() != 2 {
fmt.Fprintln(os.Stderr, "Error: must specify nickname and host")
os.Exit(69)
}
dir, err := os.Open(*outqdir)
if err != nil {
log.Fatal(err)
}
defer dir.Close()
nick := flag.Arg(0)
host := flag.Arg(1)
conn, err := connect(host, *dotls)
if err != nil {
log.Fatal(err)
}
inq := make(chan string)
outq := make(chan string)
logq = make(chan Message)
go logLoop()
go readLoop(conn, inq)
go writeLoop(conn, outq)
go monitorDirectory(*outqdir, dir, outq)
outq <- fmt.Sprintf("NICK %s", nick)
outq <- fmt.Sprintf("USER %s %s %s: %s", nick, nick, nick, gecos)
for v := range inq {
p, err := Parse(v)
if err != nil {
continue
}
dispatch(outq, p)
}
running = false running = false
close(outq)
close(logq)
} }