Give repeater nicer interface, with methods

This commit is contained in:
Neale Pickett 2020-04-12 17:30:00 -06:00
parent f70c2aaad1
commit d5f6038670
2 changed files with 46 additions and 20 deletions

View File

@ -5,33 +5,47 @@ import (
)
type Repeater struct {
Joins chan io.Writer
Parts chan io.Writer
Sends chan []byte
joins chan io.Writer
parts chan io.Writer
sends chan []byte
subscribers []io.Writer
}
func NewRepeater() *Repeater {
return &Repeater{
Joins: make(chan io.Writer, 5),
Parts: make(chan io.Writer, 5),
Sends: make(chan []byte, 5),
joins: make(chan io.Writer, 5),
parts: make(chan io.Writer, 5),
sends: make(chan []byte, 5),
subscribers: make([]io.Writer, 0, 20),
}
}
func (r *Repeater) Run() {
for {
r.loop()
}
func (r *Repeater) Join(w io.Writer) {
r.joins <- w
}
func (r *Repeater) loop() {
func (r *Repeater) Part(w io.Writer) {
r.parts <- w
}
func (r *Repeater) Send(p []byte) {
r.sends <- p
}
func (r *Repeater) Close() {
close(r.sends)
}
func (r *Repeater) Run() {
for r.loop() {}
}
func (r *Repeater) loop() bool {
select {
case w := <- r.Joins:
case w := <- r.joins:
// Add subscriber
r.subscribers = append(r.subscribers, w)
case w := <- r.Parts:
case w := <- r.parts:
// Remove subscriber
for i, s := range r.subscribers {
if s == w {
@ -40,9 +54,13 @@ func (r *Repeater) loop() {
r.subscribers = r.subscribers[:nsubs-1]
}
}
case p := <- r.Sends:
case p, ok := <- r.sends:
if ! ok {
return false
}
for _, s := range r.subscribers {
s.Write(p)
}
}
return true
}

View File

@ -9,21 +9,21 @@ func TestRepeater(t *testing.T) {
r := NewRepeater()
buf1 := bytes.NewBufferString("buf1")
r.Joins <- buf1
r.Join(buf1)
r.loop()
if len(r.subscribers) != 1 {
t.Error("Joining did nothing")
}
r.Sends <- []byte("moo")
r.Send([]byte("moo"))
r.loop()
if buf1.String() != "buf1moo" {
t.Error("Client 1 not repeating", buf1)
}
buf2 := bytes.NewBufferString("buf2")
r.Joins <- buf2
r.Join(buf2)
r.loop()
r.Sends <- []byte("bar")
r.Send([]byte("bar"))
r.loop()
if buf1.String() != "buf1moobar" {
t.Error("Client 1 not repeating", buf1)
@ -32,9 +32,9 @@ func TestRepeater(t *testing.T) {
t.Error("Client 2 not repeating", buf2)
}
r.Parts <- buf1
r.Part(buf1)
r.loop()
r.Sends <- []byte("baz")
r.Send([]byte("baz"))
r.loop()
if buf1.String() != "buf1moobar" {
t.Error("Client 1 still getting data after part", buf1)
@ -42,4 +42,12 @@ func TestRepeater(t *testing.T) {
if buf2.String() != "buf2barbaz" {
t.Error("Client 2 not getting data after part", buf2)
}
r.Close()
if r.loop() {
t.Error("Closed send didn't terminate loop")
}
if r.loop() {
t.Error("Second loop in terminated channel didn't terminate")
}
}