diff --git a/repeater.go b/repeater.go index 38854c8..0a00af8 100644 --- a/repeater.go +++ b/repeater.go @@ -5,33 +5,47 @@ import ( ) type Repeater struct { - Joins chan io.Writer - Parts chan io.Writer - Sends chan []byte + joins chan io.Writer + parts chan io.Writer + sends chan []byte subscribers []io.Writer } func NewRepeater() *Repeater { return &Repeater{ - Joins: make(chan io.Writer, 5), - Parts: make(chan io.Writer, 5), - Sends: make(chan []byte, 5), + joins: make(chan io.Writer, 5), + parts: make(chan io.Writer, 5), + sends: make(chan []byte, 5), subscribers: make([]io.Writer, 0, 20), } } -func (r *Repeater) Run() { - for { - r.loop() - } +func (r *Repeater) Join(w io.Writer) { + r.joins <- w } -func (r *Repeater) loop() { +func (r *Repeater) Part(w io.Writer) { + r.parts <- w +} + +func (r *Repeater) Send(p []byte) { + r.sends <- p +} + +func (r *Repeater) Close() { + close(r.sends) +} + +func (r *Repeater) Run() { + for r.loop() {} +} + +func (r *Repeater) loop() bool { select { - case w := <- r.Joins: + case w := <- r.joins: // Add subscriber r.subscribers = append(r.subscribers, w) - case w := <- r.Parts: + case w := <- r.parts: // Remove subscriber for i, s := range r.subscribers { if s == w { @@ -40,9 +54,13 @@ func (r *Repeater) loop() { r.subscribers = r.subscribers[:nsubs-1] } } - case p := <- r.Sends: + case p, ok := <- r.sends: + if ! ok { + return false + } for _, s := range r.subscribers { s.Write(p) } } + return true } diff --git a/repeater_test.go b/repeater_test.go index 828b180..87e4657 100644 --- a/repeater_test.go +++ b/repeater_test.go @@ -9,21 +9,21 @@ func TestRepeater(t *testing.T) { r := NewRepeater() buf1 := bytes.NewBufferString("buf1") - r.Joins <- buf1 + r.Join(buf1) r.loop() if len(r.subscribers) != 1 { t.Error("Joining did nothing") } - r.Sends <- []byte("moo") + r.Send([]byte("moo")) r.loop() if buf1.String() != "buf1moo" { t.Error("Client 1 not repeating", buf1) } buf2 := bytes.NewBufferString("buf2") - r.Joins <- buf2 + r.Join(buf2) r.loop() - r.Sends <- []byte("bar") + r.Send([]byte("bar")) r.loop() if buf1.String() != "buf1moobar" { t.Error("Client 1 not repeating", buf1) @@ -32,9 +32,9 @@ func TestRepeater(t *testing.T) { t.Error("Client 2 not repeating", buf2) } - r.Parts <- buf1 + r.Part(buf1) r.loop() - r.Sends <- []byte("baz") + r.Send([]byte("baz")) r.loop() if buf1.String() != "buf1moobar" { t.Error("Client 1 still getting data after part", buf1) @@ -42,4 +42,12 @@ func TestRepeater(t *testing.T) { if buf2.String() != "buf2barbaz" { t.Error("Client 2 not getting data after part", buf2) } + + r.Close() + if r.loop() { + t.Error("Closed send didn't terminate loop") + } + if r.loop() { + t.Error("Second loop in terminated channel didn't terminate") + } }