mirror of https://github.com/nealey/vail.git
Give repeater nicer interface, with methods
This commit is contained in:
parent
f70c2aaad1
commit
d5f6038670
46
repeater.go
46
repeater.go
|
@ -5,33 +5,47 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Repeater struct {
|
type Repeater struct {
|
||||||
Joins chan io.Writer
|
joins chan io.Writer
|
||||||
Parts chan io.Writer
|
parts chan io.Writer
|
||||||
Sends chan []byte
|
sends chan []byte
|
||||||
subscribers []io.Writer
|
subscribers []io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRepeater() *Repeater {
|
func NewRepeater() *Repeater {
|
||||||
return &Repeater{
|
return &Repeater{
|
||||||
Joins: make(chan io.Writer, 5),
|
joins: make(chan io.Writer, 5),
|
||||||
Parts: make(chan io.Writer, 5),
|
parts: make(chan io.Writer, 5),
|
||||||
Sends: make(chan []byte, 5),
|
sends: make(chan []byte, 5),
|
||||||
subscribers: make([]io.Writer, 0, 20),
|
subscribers: make([]io.Writer, 0, 20),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Repeater) Run() {
|
func (r *Repeater) Join(w io.Writer) {
|
||||||
for {
|
r.joins <- w
|
||||||
r.loop()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Repeater) loop() {
|
func (r *Repeater) Part(w io.Writer) {
|
||||||
|
r.parts <- w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Repeater) Send(p []byte) {
|
||||||
|
r.sends <- p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Repeater) Close() {
|
||||||
|
close(r.sends)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Repeater) Run() {
|
||||||
|
for r.loop() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Repeater) loop() bool {
|
||||||
select {
|
select {
|
||||||
case w := <- r.Joins:
|
case w := <- r.joins:
|
||||||
// Add subscriber
|
// Add subscriber
|
||||||
r.subscribers = append(r.subscribers, w)
|
r.subscribers = append(r.subscribers, w)
|
||||||
case w := <- r.Parts:
|
case w := <- r.parts:
|
||||||
// Remove subscriber
|
// Remove subscriber
|
||||||
for i, s := range r.subscribers {
|
for i, s := range r.subscribers {
|
||||||
if s == w {
|
if s == w {
|
||||||
|
@ -40,9 +54,13 @@ func (r *Repeater) loop() {
|
||||||
r.subscribers = r.subscribers[:nsubs-1]
|
r.subscribers = r.subscribers[:nsubs-1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case p := <- r.Sends:
|
case p, ok := <- r.sends:
|
||||||
|
if ! ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
for _, s := range r.subscribers {
|
for _, s := range r.subscribers {
|
||||||
s.Write(p)
|
s.Write(p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,21 +9,21 @@ func TestRepeater(t *testing.T) {
|
||||||
r := NewRepeater()
|
r := NewRepeater()
|
||||||
|
|
||||||
buf1 := bytes.NewBufferString("buf1")
|
buf1 := bytes.NewBufferString("buf1")
|
||||||
r.Joins <- buf1
|
r.Join(buf1)
|
||||||
r.loop()
|
r.loop()
|
||||||
if len(r.subscribers) != 1 {
|
if len(r.subscribers) != 1 {
|
||||||
t.Error("Joining did nothing")
|
t.Error("Joining did nothing")
|
||||||
}
|
}
|
||||||
r.Sends <- []byte("moo")
|
r.Send([]byte("moo"))
|
||||||
r.loop()
|
r.loop()
|
||||||
if buf1.String() != "buf1moo" {
|
if buf1.String() != "buf1moo" {
|
||||||
t.Error("Client 1 not repeating", buf1)
|
t.Error("Client 1 not repeating", buf1)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf2 := bytes.NewBufferString("buf2")
|
buf2 := bytes.NewBufferString("buf2")
|
||||||
r.Joins <- buf2
|
r.Join(buf2)
|
||||||
r.loop()
|
r.loop()
|
||||||
r.Sends <- []byte("bar")
|
r.Send([]byte("bar"))
|
||||||
r.loop()
|
r.loop()
|
||||||
if buf1.String() != "buf1moobar" {
|
if buf1.String() != "buf1moobar" {
|
||||||
t.Error("Client 1 not repeating", buf1)
|
t.Error("Client 1 not repeating", buf1)
|
||||||
|
@ -32,9 +32,9 @@ func TestRepeater(t *testing.T) {
|
||||||
t.Error("Client 2 not repeating", buf2)
|
t.Error("Client 2 not repeating", buf2)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Parts <- buf1
|
r.Part(buf1)
|
||||||
r.loop()
|
r.loop()
|
||||||
r.Sends <- []byte("baz")
|
r.Send([]byte("baz"))
|
||||||
r.loop()
|
r.loop()
|
||||||
if buf1.String() != "buf1moobar" {
|
if buf1.String() != "buf1moobar" {
|
||||||
t.Error("Client 1 still getting data after part", buf1)
|
t.Error("Client 1 still getting data after part", buf1)
|
||||||
|
@ -42,4 +42,12 @@ func TestRepeater(t *testing.T) {
|
||||||
if buf2.String() != "buf2barbaz" {
|
if buf2.String() != "buf2barbaz" {
|
||||||
t.Error("Client 2 not getting data after part", buf2)
|
t.Error("Client 2 not getting data after part", buf2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.Close()
|
||||||
|
if r.loop() {
|
||||||
|
t.Error("Closed send didn't terminate loop")
|
||||||
|
}
|
||||||
|
if r.loop() {
|
||||||
|
t.Error("Second loop in terminated channel didn't terminate")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue