Give repeater nicer interface, with methods

This commit is contained in:
Neale Pickett 2020-04-12 17:30:00 -06:00
parent f70c2aaad1
commit d5f6038670
2 changed files with 46 additions and 20 deletions

View File

@ -5,33 +5,47 @@ import (
) )
type Repeater struct { type Repeater struct {
Joins chan io.Writer joins chan io.Writer
Parts chan io.Writer parts chan io.Writer
Sends chan []byte sends chan []byte
subscribers []io.Writer subscribers []io.Writer
} }
func NewRepeater() *Repeater { func NewRepeater() *Repeater {
return &Repeater{ return &Repeater{
Joins: make(chan io.Writer, 5), joins: make(chan io.Writer, 5),
Parts: make(chan io.Writer, 5), parts: make(chan io.Writer, 5),
Sends: make(chan []byte, 5), sends: make(chan []byte, 5),
subscribers: make([]io.Writer, 0, 20), subscribers: make([]io.Writer, 0, 20),
} }
} }
func (r *Repeater) Run() { func (r *Repeater) Join(w io.Writer) {
for { r.joins <- w
r.loop()
}
} }
func (r *Repeater) loop() { func (r *Repeater) Part(w io.Writer) {
r.parts <- w
}
func (r *Repeater) Send(p []byte) {
r.sends <- p
}
func (r *Repeater) Close() {
close(r.sends)
}
func (r *Repeater) Run() {
for r.loop() {}
}
func (r *Repeater) loop() bool {
select { select {
case w := <- r.Joins: case w := <- r.joins:
// Add subscriber // Add subscriber
r.subscribers = append(r.subscribers, w) r.subscribers = append(r.subscribers, w)
case w := <- r.Parts: case w := <- r.parts:
// Remove subscriber // Remove subscriber
for i, s := range r.subscribers { for i, s := range r.subscribers {
if s == w { if s == w {
@ -40,9 +54,13 @@ func (r *Repeater) loop() {
r.subscribers = r.subscribers[:nsubs-1] r.subscribers = r.subscribers[:nsubs-1]
} }
} }
case p := <- r.Sends: case p, ok := <- r.sends:
if ! ok {
return false
}
for _, s := range r.subscribers { for _, s := range r.subscribers {
s.Write(p) s.Write(p)
} }
} }
return true
} }

View File

@ -9,21 +9,21 @@ func TestRepeater(t *testing.T) {
r := NewRepeater() r := NewRepeater()
buf1 := bytes.NewBufferString("buf1") buf1 := bytes.NewBufferString("buf1")
r.Joins <- buf1 r.Join(buf1)
r.loop() r.loop()
if len(r.subscribers) != 1 { if len(r.subscribers) != 1 {
t.Error("Joining did nothing") t.Error("Joining did nothing")
} }
r.Sends <- []byte("moo") r.Send([]byte("moo"))
r.loop() r.loop()
if buf1.String() != "buf1moo" { if buf1.String() != "buf1moo" {
t.Error("Client 1 not repeating", buf1) t.Error("Client 1 not repeating", buf1)
} }
buf2 := bytes.NewBufferString("buf2") buf2 := bytes.NewBufferString("buf2")
r.Joins <- buf2 r.Join(buf2)
r.loop() r.loop()
r.Sends <- []byte("bar") r.Send([]byte("bar"))
r.loop() r.loop()
if buf1.String() != "buf1moobar" { if buf1.String() != "buf1moobar" {
t.Error("Client 1 not repeating", buf1) t.Error("Client 1 not repeating", buf1)
@ -32,9 +32,9 @@ func TestRepeater(t *testing.T) {
t.Error("Client 2 not repeating", buf2) t.Error("Client 2 not repeating", buf2)
} }
r.Parts <- buf1 r.Part(buf1)
r.loop() r.loop()
r.Sends <- []byte("baz") r.Send([]byte("baz"))
r.loop() r.loop()
if buf1.String() != "buf1moobar" { if buf1.String() != "buf1moobar" {
t.Error("Client 1 still getting data after part", buf1) t.Error("Client 1 still getting data after part", buf1)
@ -42,4 +42,12 @@ func TestRepeater(t *testing.T) {
if buf2.String() != "buf2barbaz" { if buf2.String() != "buf2barbaz" {
t.Error("Client 2 not getting data after part", buf2) t.Error("Client 2 not getting data after part", buf2)
} }
r.Close()
if r.loop() {
t.Error("Closed send didn't terminate loop")
}
if r.loop() {
t.Error("Second loop in terminated channel didn't terminate")
}
} }