Unit test runs to completion, but fails

This commit is contained in:
Neale Pickett 2015-02-10 13:38:21 -07:00
parent 0e29d159f1
commit 29a9268df8
6 changed files with 254 additions and 212 deletions

View File

@ -3,6 +3,7 @@ package main
import (
"strconv"
"strings"
"fmt"
)
type Message struct {
@ -14,7 +15,7 @@ type Message struct {
Text string
}
func Parse(v string) (Message, error) {
func NewMessage(v string) (Message, error) {
var m Message
var parts []string
var lhs string
@ -37,7 +38,7 @@ func Parse(v string) (Message, error) {
m.FullSender = parts[0][1:]
parts = parts[1:]
n, u, _ := nuhost(m.FullSender)
n, u, _ := SplitTarget(m.FullSender)
if u != "" {
m.Sender = n
}
@ -47,7 +48,7 @@ func Parse(v string) (Message, error) {
switch m.Command {
case "PRIVMSG", "NOTICE":
switch {
case isChannel(parts[1]):
case IsChannel(parts[1]):
m.Forum = parts[1]
case m.FullSender == ".":
m.Forum = parts[1]
@ -101,3 +102,32 @@ func (m Message) String() string {
return fmt.Sprintf("%s %s %s %s %s :%s", m.FullSender, m.Command, m.Sender, m.Forum, args, m.Text)
}
func SplitTarget(s string) (string, string, string) {
var parts []string
parts = strings.SplitN(s, "!", 2)
if len(parts) == 1 {
return s, "", ""
}
nick := parts[0]
parts = strings.SplitN(parts[1], "@", 2)
if len(parts) == 1 {
return s, "", ""
}
return nick, parts[0], parts[1]
}
func IsChannel(s string) bool {
if s == "" {
return false
}
switch s[0] {
case '#', '+', '!', '&':
return true
default:
return false
}
}

View File

@ -2,12 +2,14 @@ package main
import (
"bufio"
"crypto/tls"
"fmt"
"github.com/nealey/spongy/logfile"
"io"
"log"
"net"
"os"
"os/user"
"path"
"strings"
"time"
@ -42,6 +44,8 @@ func ReadLines(fn string) ([]string, error) {
type Network struct {
running bool
Nick string
basePath string
conn io.ReadWriteCloser
@ -50,41 +54,27 @@ type Network struct {
outq chan string
}
func NewNetwork(basePath string) (*Network, error) {
nicks, err := ReadLines(path.Join(basePath, "nicks"))
if err != nil {
return nil, err
}
gecoses, err := ReadLines(path.Join(basePath, "gecos"))
if err != nil {
return nil, err
}
return &Network{
func NewNetwork(basePath string) *Network {
nw := Network{
running: true,
basePath: basePath,
servers: servers,
nicks: nicks,
gecos: gecoses[0],
logq: make(chan Message, 20),
}, err
}
go n.LogLoop()
go nw.LogLoop()
return &nw
}
func (n *Network) Close() {
n.conn.Close()
close(n.logq)
close(n.inq)
close(n.outq)
func (nw *Network) Close() {
nw.conn.Close()
close(nw.logq)
close(nw.inq)
close(nw.outq)
}
func (n *Network) WatchOutqDirectory() {
outqDirname := path.Join(n.basePath, "outq")
func (nw *Network) WatchOutqDirectory() {
outqDirname := path.Join(nw.basePath, "outq")
dir, err := os.Open(outqDirname)
if err != nil {
@ -93,18 +83,18 @@ func (n *Network) WatchOutqDirectory() {
defer dir.Close()
// XXX: Do this with fsnotify
for n.running {
for nw.running {
entities, _ := dir.Readdirnames(0)
for _, fn := range entities {
pathname := path.Join(outqDirname, fn)
n.HandleInfile(pathname)
nw.HandleInfile(pathname)
}
_, _ = dir.Seek(0, 0)
time.Sleep(500 * time.Millisecond)
}
}
func (n *Network) HandleInfile(fn string) {
func (nw *Network) HandleInfile(fn string) {
f, err := os.Open(fn)
if err != nil {
return
@ -119,91 +109,144 @@ func (n *Network) HandleInfile(fn string) {
inf := bufio.NewScanner(f)
for inf.Scan() {
txt := inf.Text()
n.outq <- txt
nw.outq <- txt
}
}
func (n *Network) LogLoop() {
logf := logfile.NewLogFile(int(maxlogsize))
func (nw *Network) LogLoop() {
logf := logfile.NewLogfile(int(maxlogsize))
defer logf.Close()
for m := range logq {
for m := range nw.logq {
logf.Log(m.String())
}
}
func (n *Network) ServerWriteLoop() {
for v := range n.outq {
m, _ := Parse(v)
n.logq <- m
fmt.Fprintln(n.conn, v)
func (nw *Network) ServerWriteLoop() {
for v := range nw.outq {
m, _ := NewMessage(v)
nw.logq <- m
fmt.Fprintln(nw.conn, v)
}
}
func (n *Network) ServerReadLoop() {
scanner := bufio.NewScanner(conn)
func (nw *Network) ServerReadLoop() {
scanner := bufio.NewScanner(nw.conn)
for scanner.Scan() {
n.inq <- scanner.Text()
nw.inq <- scanner.Text()
}
close(n.inq)
close(nw.inq)
}
func (n *Network) MessageDispatch() {
for line := n.inq {
func (nw *Network) NextNick() {
nicks, err := ReadLines(path.Join(nw.basePath, "nick"))
if err != nil {
log.Print(err)
return
}
// Make up some alternates if they weren't provided
if len(nicks) == 1 {
nicks = append(nicks, nicks[0] + "_")
nicks = append(nicks, nicks[0] + "__")
nicks = append(nicks, nicks[0] + "___")
}
nextidx := 0
for idx, n := range nicks {
if n == nw.Nick {
nextidx = idx + 1
}
}
nw.Nick = nicks[nextidx % len(nicks)]
nw.outq <- "NICK " + nw.Nick
}
func (nw *Network) JoinChannels() {
chans, err := ReadLines(path.Join(nw.basePath, "channels"))
if err != nil {
log.Print(err)
return
}
for _, ch := range chans {
nw.outq <- "JOIN " + ch
}
}
func (nw *Network) MessageDispatch() {
for line := range nw.inq {
m, err := NewMessage(line)
if err != nil {
log.Print(err)
continue
}
n.logq <- m
nw.logq <- m
// XXX: Add in a handler subprocess call
switch m.Command {
case "PING":
n.outq <- "PONG: " + m.Text
nw.outq <- "PONG: " + m.Text
case "001":
nw.JoinChannels()
case "433":
nick = nick + "_"
outq <- fmt.Sprintf("NICK %s", nick)
nw.NextNick()
}
}
}
func (n *Network) ConnectToServer(server string) bool {
func (nw *Network) ConnectToServer(server string) bool {
var err error
var name string
names, err := ReadLines(path.Join(nw.basePath, "name"))
if err != nil {
me, err := user.Current()
if err != nil {
log.Fatal(err)
}
name = me.Name
} else {
name = names[0]
}
switch (server[0]) {
case '|':
parts := strings.Split(server[1:], " ")
n.conn, err = StartStdioProcess(parts[0], parts[1:])
nw.conn, err = StartStdioProcess(parts[0], parts[1:])
case '^':
n.conn, err = net.Dial("tcp", server[1:])
nw.conn, err = net.Dial("tcp", server[1:])
default:
log.Print("Not validating server certificate!")
config := &tls.Config{
InsecureSkipVerify: true,
}
n.conn, err = tls.Dial("tcp", host, config)
nw.conn, err = tls.Dial("tcp", server, config)
}
if err != nil {
log.Print(err)
time.sleep(2 * time.Second)
time.Sleep(2 * time.Second)
return false
}
fmt.Fprintf(nw.conn, "USER g g g :%s\n", name)
nw.NextNick()
return true
}
func (n *Network) Connect(){
func (nw *Network) Connect(){
serverIndex := 0
for n.running {
servers, err := ReadLines(path.Join(basePath, "servers"))
for nw.running {
servers, err := ReadLines(path.Join(nw.basePath, "servers"))
if err != nil {
serverIndex := 0
serverIndex = 0
log.Print(err)
time.sleep(8)
time.Sleep(8)
continue
}
@ -213,18 +256,18 @@ func (n *Network) Connect(){
server := servers[serverIndex]
serverIndex += 1
if ! n.ConnectToServer(server) {
if ! nw.ConnectToServer(server) {
continue
}
n.inq = make(chan string, 20)
n.outq = make(chan string, 20)
nw.inq = make(chan string, 20)
nw.outq = make(chan string, 20)
go n.ServerWriteLoop()
go n.MessageDispatch()
n.ServerReadLoop()
go nw.ServerWriteLoop()
go nw.MessageDispatch()
nw.ServerReadLoop()
close(n.outq)
close(nw.outq)
}
}

9
spongy/network_test.go Normal file
View File

@ -0,0 +1,9 @@
package main
import (
"testing"
)
func testConnect(t *testing.T) {
return
}

View File

@ -3,20 +3,25 @@ package main
import (
"io"
"os/exec"
"log"
)
type ReadWriteCloserWrapper {
type ReadWriteCloserWrapper struct {
Reader io.ReadCloser
Writer io.WriteCloser
cmd *exec.Cmd
}
def NewReadWriteCloseWrapper(r io.ReadCloser, w io.WriteCloser) *ReadWriteCloserWrapper {
return &ReadWriteCloserWrapper{r, w}
func NewReadWriteCloseWrapper(r io.ReadCloser, w io.WriteCloser) *ReadWriteCloserWrapper {
return &ReadWriteCloserWrapper{r, w, nil}
}
def (w *ReadWriteCloserWrapper) Close() (error) {
func (w *ReadWriteCloserWrapper) Close() (error) {
err1 := w.Reader.Close()
err2 := w.Writer.Close()
if w.cmd != nil{
w.cmd.Wait()
}
switch {
case err1 != nil:
@ -27,36 +32,43 @@ def (w *ReadWriteCloserWrapper) Close() (error) {
return nil
}
def (w *ReadWriteCloserWrapper) Read(p []byte) (n int, err error) {
n, err := w.Reader.Read(p)
func (w *ReadWriteCloserWrapper) Read(p []byte) (n int, err error) {
n, err = w.Reader.Read(p)
return
}
def (w *ReadWriteCloserWrapper) Write(p []byte) (n int, err error) {
n, err := w.Writer.Write(p)
func (w *ReadWriteCloserWrapper) Write(p []byte) (n int, err error) {
n, err = w.Writer.Write(p)
return
}
def StartStdioProcess(name string, args []string) (*ReadWriteCloserWrapper, error) {
var w ReadWriteCloserWrapper
func StartStdioProcess(name string, args []string) (w *ReadWriteCloserWrapper, err error) {
w = new(ReadWriteCloserWrapper)
cmd := exec.Command(name, args...)
w.Reader, err := cmd.StdoutPipe()
if cmd == nil {
log.Fatalf("Can't run command: %v %v", name, args)
}
w.Reader, err = cmd.StdoutPipe()
if err != nil {
return nil, err
}
w.Writer, err := cmd.StdinPipe()
w.Writer, err = cmd.StdinPipe()
if err != nil {
w.Reader.Close()
return nil, err
}
if err = cmd.Start(); err != nil {
w.Reader.Close()
w.Writer.Close()
return nil, err
}
go cmd.Wait()
w.cmd = cmd
return &w, nil
return
}

View File

@ -0,0 +1,37 @@
package main
import (
"bytes"
"testing"
)
func TestRWCWCat(t *testing.T) {
proc, err := StartStdioProcess("cat", []string{})
if err != nil {
t.Error(err)
}
out := []byte("Hello, World\n")
p := make([]byte, 0, 50)
n, err := proc.Write(out)
if err != nil {
t.Error(err)
}
if n != len(out) {
t.Errorf("Wrong number of bytes in Write: wanted %d, got %d", len(out), n)
}
n, err = proc.Read(p)
if err != nil {
t.Error(err)
}
if n != len(out) {
t.Errorf("Wrong number of bytes in Read: wanted %d, got %d", len(out), n)
}
if 0 != bytes.Compare(p, out) {
t.Errorf("Mangled read")
}
proc.Close()
}

View File

@ -1,110 +1,65 @@
package main
import (
"bufio"
"crypto/tls"
"flag"
"fmt"
"github.com/nealey/spongy/logfile"
"log"
"net"
"os"
"strings"
"path"
"time"
)
var running bool = true
var nick string
var gecos string
var maxlogsize uint
var logq chan Message
func isChannel(s string) bool {
if s == "" {
func exists(filename string) bool {
_, err := os.Stat(filename); if err != nil {
return false
}
switch s[0] {
case '#', '&', '!', '+', '.', '-':
return true
default:
return false
}
}
func (m Message) String() string {
args := strings.Join(m.Args, " ")
return fmt.Sprintf("%s %s %s %s %s :%s", m.FullSender, m.Command, m.Sender, m.Forum, args, m.Text)
}
func runsvdir(dirname string) {
services := make(map[string]*Network)
func logLoop() {
logf := logfile.NewLogfile(int(maxlogsize))
defer logf.Close()
for m := range logq {
logf.Log(m.String())
}
}
func nuhost(s string) (string, string, string) {
var parts []string
parts = strings.SplitN(s, "!", 2)
if len(parts) == 1 {
return s, "", ""
}
n := parts[0]
parts = strings.SplitN(parts[1], "@", 2)
if len(parts) == 1 {
return s, "", ""
}
return n, parts[0], parts[1]
}
func dispatch(outq chan<- string, m Message) {
logq <- m
switch m.Command {
case "PING":
outq <- "PONG :" + m.Text
case "433":
nick = nick + "_"
outq <- fmt.Sprintf("NICK %s", nick)
}
}
func handleInfile(path string, outq chan<- string) {
f, err := os.Open(path)
dir, err := os.Open(dirname)
if err != nil {
return
log.Fatal(err)
}
defer f.Close()
os.Remove(path)
inf := bufio.NewScanner(f)
for inf.Scan() {
txt := inf.Text()
outq <- txt
}
}
defer dir.Close()
func monitorDirectory(dirname string, dir *os.File, outq chan<- string) {
latest := time.Unix(0, 0)
for running {
fi, err := dir.Stat()
if err != nil {
break
dn, err := dir.Readdirnames(0); if err != nil {
log.Fatal(err)
}
current := fi.ModTime()
if current.After(latest) {
latest = current
dn, _ := dir.Readdirnames(0)
found := make(map[string]bool)
for _, fn := range dn {
path := dirname + string(os.PathSeparator) + fn
handleInfile(path, outq)
fpath := path.Join(dirname, fn)
if exists(path.Join(fpath, "down")) {
continue
}
if _, ok := services[fpath]; ! ok {
if ! exists(path.Join(fpath, "servers")) {
continue
}
newnet := NewNetwork(fpath)
services[fpath] = newnet
go newnet.Connect()
}
found[fpath] = true
}
// If anything vanished, disconnect it
for fpath, nw := range services {
if _, ok := found[fpath]; ! ok {
nw.Close()
}
}
_, _ = dir.Seek(0, 0)
}
time.Sleep(500 * time.Millisecond)
time.Sleep(20 * time.Second)
}
}
@ -114,51 +69,7 @@ func usage() {
}
func main() {
dotls := flag.Bool("notls", true, "Disable TLS security")
outqdir := flag.String("outq", "outq", "Output queue directory")
flag.UintVar(&maxlogsize, "logsize", 1000, "Log entries before rotating")
flag.StringVar(&gecos, "gecos", "Bob The Merry Slug", "Gecos entry (full name)")
flag.Parse()
if flag.NArg() != 2 {
fmt.Fprintln(os.Stderr, "Error: must specify nickname and host")
os.Exit(69)
}
dir, err := os.Open(*outqdir)
if err != nil {
log.Fatal(err)
}
defer dir.Close()
nick := flag.Arg(0)
host := flag.Arg(1)
conn, err := connect(host, *dotls)
if err != nil {
log.Fatal(err)
}
inq := make(chan string)
outq := make(chan string)
logq = make(chan Message)
go logLoop()
go readLoop(conn, inq)
go writeLoop(conn, outq)
go monitorDirectory(*outqdir, dir, outq)
outq <- fmt.Sprintf("NICK %s", nick)
outq <- fmt.Sprintf("USER %s %s %s: %s", nick, nick, nick, gecos)
for v := range inq {
p, err := Parse(v)
if err != nil {
continue
}
dispatch(outq, p)
}
running = false
close(outq)
close(logq)
}