diff --git a/publisher.go b/publisher.go index e71a9b8..0743889 100644 --- a/publisher.go +++ b/publisher.go @@ -1,24 +1,34 @@ package rabbit import ( + "context" "errors" "log" "time" ) type Publisher interface { - Start(chanLen uint) <-chan []byte + Start(chanLen uint) chan<- []byte } type pubHandler struct { client *Client } -func (p *pubHandler) Start(chanLen uint) <-chan []byte { +func (p *pubHandler) Start(ctx context.Context, chanLen uint) chan<- []byte { ch := make(chan []byte, chanLen) go func() { for { select { + case <-ctx.Done(): + for msg := range ch { + if err := p.push(msg); err != nil { + log.Printf("Error publishing message (shutdown): %s", err) + } + } + log.Println("Publisher stopped") + return + case msg := <-ch: if err := p.push(msg); err != nil { log.Printf("Error publishing message: %s", err)