diff --git a/spongy/logfile.go b/spongy/logfile.go index 1a6e518..57b27be 100644 --- a/spongy/logfile.go +++ b/spongy/logfile.go @@ -1,20 +1,22 @@ -package logfile +package main import ( "fmt" "os" + "path" "time" ) type Logfile struct { + baseDir string file *os.File name string nlines int maxlines int } -func NewLogfile(maxlines int) (*Logfile) { - return &Logfile{nil, "", 0, maxlines} +func NewLogfile(baseDir string, maxlines int) (*Logfile) { + return &Logfile{baseDir, nil, "", 0, maxlines} } func (lf *Logfile) Close() { @@ -34,7 +36,8 @@ func (lf *Logfile) writeln(s string) error { func (lf *Logfile) rotate() error { fn := fmt.Sprintf("%s.log", time.Now().UTC().Format(time.RFC3339)) - newf, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) + pathn := path.Join(lf.baseDir, "log", fn) + newf, err := os.OpenFile(pathn, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) if err != nil { return err } @@ -58,8 +61,8 @@ func (lf *Logfile) rotate() error { lf.file = newf // Record symlink to new log - os.Remove("current") - os.Symlink(fn, "current") + os.Remove(path.Join(lf.baseDir, "log", "current")) + os.Symlink(fn, path.Join(lf.baseDir, "log", "current")) logmsg := fmt.Sprintf(". PREVLOG %s", lf.name) lf.writeln(logmsg) diff --git a/spongy/network.go b/spongy/network.go index 94c5fb7..8292e79 100644 --- a/spongy/network.go +++ b/spongy/network.go @@ -4,7 +4,6 @@ import ( "bufio" "crypto/tls" "fmt" - "github.com/nealey/spongy/logfile" "io" "log" "net" @@ -15,8 +14,8 @@ import ( "time" ) -// This gets called a lot. -// So it's easy to fix stuff while running. +// This gets called every time the data's needed. +// That makes it so you can change stuff while running. func ReadLines(fn string) ([]string, error) { lines := make([]string, 0) @@ -47,6 +46,7 @@ type Network struct { Nick string basePath string + serverIndex int conn io.ReadWriteCloser logq chan Message @@ -67,10 +67,11 @@ func NewNetwork(basePath string) *Network { } func (nw *Network) Close() { - nw.conn.Close() + nw.running = false close(nw.logq) - close(nw.inq) - close(nw.outq) + if nw.conn != nil { + nw.conn.Close() + } } func (nw *Network) WatchOutqDirectory() { @@ -114,7 +115,7 @@ func (nw *Network) HandleInfile(fn string) { } func (nw *Network) LogLoop() { - logf := logfile.NewLogfile(int(maxlogsize)) + logf := NewLogfile(nw.basePath, int(maxlogsize)) defer logf.Close() for m := range nw.logq { @@ -130,14 +131,6 @@ func (nw *Network) ServerWriteLoop() { } } -func (nw *Network) ServerReadLoop() { - scanner := bufio.NewScanner(nw.conn) - for scanner.Scan() { - nw.inq <- scanner.Text() - } - close(nw.inq) -} - func (nw *Network) NextNick() { nicks, err := ReadLines(path.Join(nw.basePath, "nick")) if err != nil { @@ -188,7 +181,7 @@ func (nw *Network) MessageDispatch() { switch m.Command { case "PING": - nw.outq <- "PONG: " + m.Text + nw.outq <- "PONG :" + m.Text case "001": nw.JoinChannels() case "433": @@ -197,20 +190,17 @@ func (nw *Network) MessageDispatch() { } } -func (nw *Network) ConnectToServer(server string) bool { - var err error - var name string - - names, err := ReadLines(path.Join(nw.basePath, "name")) +func (nw *Network) ConnectToNextServer() bool { + servers, err := ReadLines(path.Join(nw.basePath, "server")) if err != nil { - me, err := user.Current() - if err != nil { - log.Fatal(err) - } - name = me.Name - } else { - name = names[0] + log.Printf("Couldn't find any servers to connect to in %s", nw.basePath) + return false } + + if nw.serverIndex > len(servers) { + nw.serverIndex = 0 + } + server := servers[nw.serverIndex] switch (server[0]) { case '|': @@ -228,45 +218,57 @@ func (nw *Network) ConnectToServer(server string) bool { if err != nil { log.Print(err) - time.Sleep(2 * time.Second) return false } - fmt.Fprintf(nw.conn, "USER g g g :%s\n", name) - nw.NextNick() - return true } + +func (nw *Network) login() { + var name string + + names, err := ReadLines(path.Join(nw.basePath, "name")) + if err == nil { + name = names[0] + } + if name == "" { + me, err := user.Current() + if err == nil { + name = me.Name + } + } + + if name == "" { + name = "Charlie" + } + + nw.outq <- "USER g g g :" + name + nw.NextNick() +} + func (nw *Network) Connect(){ - serverIndex := 0 for nw.running { - servers, err := ReadLines(path.Join(nw.basePath, "servers")) - if err != nil { - serverIndex = 0 - log.Print(err) - time.Sleep(8) - continue - } - - if serverIndex > len(servers) { - serverIndex = 0 - } - server := servers[serverIndex] - serverIndex += 1 - - if ! nw.ConnectToServer(server) { + if ! nw.ConnectToNextServer() { + time.Sleep(8 * time.Second) continue } nw.inq = make(chan string, 20) nw.outq = make(chan string, 20) - + go nw.ServerWriteLoop() go nw.MessageDispatch() - nw.ServerReadLoop() + + nw.login() + scanner := bufio.NewScanner(nw.conn) + for scanner.Scan() { + nw.inq <- scanner.Text() + } + + close(nw.inq) close(nw.outq) } } diff --git a/spongy/network_test.go b/spongy/network_test.go index f820038..91b1761 100644 --- a/spongy/network_test.go +++ b/spongy/network_test.go @@ -1,9 +1,62 @@ package main import ( + "io/ioutil" + "os" + "path" "testing" + "time" ) -func testConnect(t *testing.T) { +func writeFile(fn string, data string) { + ioutil.WriteFile(fn, []byte(data), os.ModePerm) +} + +func createNetwork(t *testing.T) (base string) { + base, err := ioutil.TempDir("", "spongy-test") + if err != nil { + t.Fatal(err) + } + + writeFile(path.Join(base, "nick"), "spongy_test") + writeFile(path.Join(base, "server"), "moo.slashnet.org:6697") + os.Mkdir(path.Join(base, "outq"), os.ModePerm) + os.Mkdir(path.Join(base, "log"), os.ModePerm) + + return +} + +func TestCreateNetwork(t *testing.T) { + base := createNetwork(t) + + if fi, err := os.Stat(path.Join(base, "nick")); err != nil { + t.Error(err) + } else if fi.IsDir() { + t.Error("%s is not a regular file", path.Join(base, "nick")) + } + + os.RemoveAll(base) + if _, err := os.Stat(path.Join(base, "outq")); err == nil { + t.Error("Didn't unlink outq") + } +} + +func TestConnect(t *testing.T) { + base := createNetwork(t) + defer os.RemoveAll(base) + + n := NewNetwork(base) + go n.Connect() + + time.Sleep(5 * time.Second) + + logBytes, err := ioutil.ReadFile(path.Join(base, "log", "current")) + if err != nil { + n.Close() + t.Fatal(err) + } + t.Log("logBytes: ", logBytes) + + n.Close() return }