mirror of https://github.com/nealey/vail.git
Give repeater nicer interface, with methods
This commit is contained in:
parent
f70c2aaad1
commit
d5f6038670
46
repeater.go
46
repeater.go
|
@ -5,33 +5,47 @@ import (
|
|||
)
|
||||
|
||||
type Repeater struct {
|
||||
Joins chan io.Writer
|
||||
Parts chan io.Writer
|
||||
Sends chan []byte
|
||||
joins chan io.Writer
|
||||
parts chan io.Writer
|
||||
sends chan []byte
|
||||
subscribers []io.Writer
|
||||
}
|
||||
|
||||
func NewRepeater() *Repeater {
|
||||
return &Repeater{
|
||||
Joins: make(chan io.Writer, 5),
|
||||
Parts: make(chan io.Writer, 5),
|
||||
Sends: make(chan []byte, 5),
|
||||
joins: make(chan io.Writer, 5),
|
||||
parts: make(chan io.Writer, 5),
|
||||
sends: make(chan []byte, 5),
|
||||
subscribers: make([]io.Writer, 0, 20),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Repeater) Run() {
|
||||
for {
|
||||
r.loop()
|
||||
}
|
||||
func (r *Repeater) Join(w io.Writer) {
|
||||
r.joins <- w
|
||||
}
|
||||
|
||||
func (r *Repeater) loop() {
|
||||
func (r *Repeater) Part(w io.Writer) {
|
||||
r.parts <- w
|
||||
}
|
||||
|
||||
func (r *Repeater) Send(p []byte) {
|
||||
r.sends <- p
|
||||
}
|
||||
|
||||
func (r *Repeater) Close() {
|
||||
close(r.sends)
|
||||
}
|
||||
|
||||
func (r *Repeater) Run() {
|
||||
for r.loop() {}
|
||||
}
|
||||
|
||||
func (r *Repeater) loop() bool {
|
||||
select {
|
||||
case w := <- r.Joins:
|
||||
case w := <- r.joins:
|
||||
// Add subscriber
|
||||
r.subscribers = append(r.subscribers, w)
|
||||
case w := <- r.Parts:
|
||||
case w := <- r.parts:
|
||||
// Remove subscriber
|
||||
for i, s := range r.subscribers {
|
||||
if s == w {
|
||||
|
@ -40,9 +54,13 @@ func (r *Repeater) loop() {
|
|||
r.subscribers = r.subscribers[:nsubs-1]
|
||||
}
|
||||
}
|
||||
case p := <- r.Sends:
|
||||
case p, ok := <- r.sends:
|
||||
if ! ok {
|
||||
return false
|
||||
}
|
||||
for _, s := range r.subscribers {
|
||||
s.Write(p)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -9,21 +9,21 @@ func TestRepeater(t *testing.T) {
|
|||
r := NewRepeater()
|
||||
|
||||
buf1 := bytes.NewBufferString("buf1")
|
||||
r.Joins <- buf1
|
||||
r.Join(buf1)
|
||||
r.loop()
|
||||
if len(r.subscribers) != 1 {
|
||||
t.Error("Joining did nothing")
|
||||
}
|
||||
r.Sends <- []byte("moo")
|
||||
r.Send([]byte("moo"))
|
||||
r.loop()
|
||||
if buf1.String() != "buf1moo" {
|
||||
t.Error("Client 1 not repeating", buf1)
|
||||
}
|
||||
|
||||
buf2 := bytes.NewBufferString("buf2")
|
||||
r.Joins <- buf2
|
||||
r.Join(buf2)
|
||||
r.loop()
|
||||
r.Sends <- []byte("bar")
|
||||
r.Send([]byte("bar"))
|
||||
r.loop()
|
||||
if buf1.String() != "buf1moobar" {
|
||||
t.Error("Client 1 not repeating", buf1)
|
||||
|
@ -32,9 +32,9 @@ func TestRepeater(t *testing.T) {
|
|||
t.Error("Client 2 not repeating", buf2)
|
||||
}
|
||||
|
||||
r.Parts <- buf1
|
||||
r.Part(buf1)
|
||||
r.loop()
|
||||
r.Sends <- []byte("baz")
|
||||
r.Send([]byte("baz"))
|
||||
r.loop()
|
||||
if buf1.String() != "buf1moobar" {
|
||||
t.Error("Client 1 still getting data after part", buf1)
|
||||
|
@ -42,4 +42,12 @@ func TestRepeater(t *testing.T) {
|
|||
if buf2.String() != "buf2barbaz" {
|
||||
t.Error("Client 2 not getting data after part", buf2)
|
||||
}
|
||||
|
||||
r.Close()
|
||||
if r.loop() {
|
||||
t.Error("Closed send didn't terminate loop")
|
||||
}
|
||||
if r.loop() {
|
||||
t.Error("Second loop in terminated channel didn't terminate")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue