Compare commits

..

No commits in common. "main" and "v0.1.6" have entirely different histories.
main ... v0.1.6

7 changed files with 50 additions and 232 deletions

View file

@ -1,34 +0,0 @@
package rabbit
import (
"errors"
"fmt"
"net/url"
)
type Address struct {
Username string
Password string
Host string
Port uint16
Vhost string
}
func (a *Address) makeAddr() (string, error) {
if a.Host == "" {
return "", errors.New("no host provided")
}
if a.Vhost == "" {
return "", errors.New("no vhost provided")
}
//"amqp://username:password@host:port/vhost"
return fmt.Sprintf("amqp://%v:%v@%v:%v/%v",
url.QueryEscape(a.Username),
url.QueryEscape(a.Password),
a.Host,
a.Port,
url.PathEscape(a.Vhost),
), nil
}

View file

@ -4,20 +4,22 @@ import (
"context" "context"
amqp "github.com/rabbitmq/amqp091-go" amqp "github.com/rabbitmq/amqp091-go"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"time" "log"
) )
type Consumer interface { type Consumer interface {
Start(ctx context.Context, chanLen uint) <-chan []byte Start() chan []byte
} }
type consumeHandler struct { type consumeHandler struct {
client *Client ctx context.Context
client *Client
chanLen int
} }
func (c *consumeHandler) Start(ctx context.Context, chanLen uint) <-chan []byte { func (c *consumeHandler) Start() chan []byte {
msgCh := make(chan []byte, chanLen) msgCh := make(chan []byte, c.chanLen)
go runConsumer(ctx, c.client, msgCh) go runConsumer(c.ctx, c.client, msgCh)
return msgCh return msgCh
} }
@ -28,18 +30,9 @@ func runConsumer(ctx context.Context, client *Client, msgCh chan []byte) {
limiter := rate.NewLimiter(rate.Every(client.opts.consumerRateLimit), client.opts.consumerBurstSize) limiter := rate.NewLimiter(rate.Every(client.opts.consumerRateLimit), client.opts.consumerBurstSize)
reconnectSignal := make(chan struct{}, 1)
reconnectTrigger := func() {
select {
case reconnectSignal <- struct{}{}:
default:
}
}
// initial consume
deliveries, err := client.consume() deliveries, err := client.consume()
if err != nil { if err != nil {
client.logger.Printf("Could not start consuming: %s\n", err) log.Printf("Could not start consuming: %s\n", err)
return return
} }
@ -48,66 +41,33 @@ func runConsumer(ctx context.Context, client *Client, msgCh chan []byte) {
for { for {
select { select {
case <-runCtx.Done(): case <-runCtx.Done():
err = client.Close() err = client.Close()
if err != nil { if err != nil {
client.logger.Printf("Close failed: %s\n", err) log.Printf("Close failed: %s\n", err)
} }
return return
case <-reconnectSignal: case amqErr := <-chClosedCh:
select { log.Printf("AMQP Channel closed due to: %s\n", amqErr)
case <-runCtx.Done():
return
case <-time.After(client.opts.reconnectDelay):
}
deliveries, err = client.consume() deliveries, err = client.consume()
if err != nil { if err != nil {
client.logger.Printf("Consume reconnect failed, retry: %s\n", err) log.Println("Error trying to consume, will try again")
reconnectTrigger()
continue continue
} }
// re-create closing chan
chClosedCh = make(chan *amqp.Error, 1) chClosedCh = make(chan *amqp.Error, 1)
client.Channel.NotifyClose(chClosedCh) client.Channel.NotifyClose(chClosedCh)
client.logger.Println("Consume reconnect success")
case amqErr := <-chClosedCh: case delivery := <-deliveries:
client.logger.Printf("AMQP Channel closed due to: %s Reconnecting...\n", amqErr) msgCh <- delivery.Body
reconnectTrigger() log.Printf("Received message: %s\n", delivery.Body)
case delivery, ok := <-deliveries: if err = delivery.Ack(false); err != nil {
if !ok { log.Printf("Error acknowledging message: %s\n", err)
client.logger.Println("Deliveries channel closed unexpectedly")
reconnectTrigger()
continue
}
if err = limiter.Wait(runCtx); err != nil {
client.logger.Printf("Wait limiter failed: %s\n", err)
}
select {
case <-runCtx.Done():
if err = delivery.Nack(false, true); err != nil {
client.logger.Printf("Error nacking message: %s\n", err)
}
err = client.Close()
if err != nil {
client.logger.Printf("Close failed: %s\n", err)
}
return
case msgCh <- delivery.Body:
client.logger.Printf("Received message: %s\n", delivery.Body)
if err = delivery.Ack(false); err != nil {
client.logger.Printf("Error acknowledging message: %s\n", err)
}
} }
limiter.Wait(runCtx)
} }
} }
} }

View file

@ -1,10 +1,8 @@
package rabbit package rabbit
import ( import (
"errors" "context"
"fmt"
amqp "github.com/rabbitmq/amqp091-go" amqp "github.com/rabbitmq/amqp091-go"
"io"
"log" "log"
"os" "os"
"sync" "sync"
@ -13,7 +11,7 @@ import (
type Client struct { type Client struct {
mutex *sync.Mutex mutex *sync.Mutex
queueOptions *QueueOpts queueName string
logger *log.Logger logger *log.Logger
connection *amqp.Connection connection *amqp.Connection
Channel *amqp.Channel Channel *amqp.Channel
@ -23,68 +21,44 @@ type Client struct {
notifyConfirm chan amqp.Confirmation notifyConfirm chan amqp.Confirmation
isReady bool isReady bool
opts options opts options
connected chan struct{}
} }
type options struct { type options struct {
connectTimeout time.Duration
reconnectDelay time.Duration reconnectDelay time.Duration
reInitDelay time.Duration reInitDelay time.Duration
resendDelay time.Duration resendDelay time.Duration
consumerRateLimit time.Duration consumerRateLimit time.Duration
consumerBurstSize int consumerBurstSize int
logger *log.Logger
} }
type QueueOpts struct { func NewClient(addr, queueName string, opts ...Option) *Client {
QueueName string if addr == "" {
Durable bool log.Fatal(errNoAddr)
AutoDelete bool
Exclusive bool
NoWait bool
Args amqp.Table
}
func NewClient(address Address, queueOpts QueueOpts, opts ...Option) (*Client, error) {
l := log.New(os.Stdout, "", log.LstdFlags)
addr, err := address.makeAddr()
if err != nil {
return nil, errors.Join(errBadAddr, err)
} }
if queueOpts.QueueName == "" { if queueName == "" {
l.Fatal(errNoQueue) log.Fatal(errNoQueue)
} }
client := Client{ client := Client{
mutex: &sync.Mutex{}, mutex: &sync.Mutex{},
queueOptions: &queueOpts, logger: log.New(os.Stdout, "", log.LstdFlags),
done: make(chan bool), queueName: queueName,
connected: make(chan struct{}), done: make(chan bool),
logger: l,
} }
o := options{ o := options{
connectTimeout: 15 * time.Second, reconnectDelay: 5,
reconnectDelay: 5 * time.Second, reInitDelay: 2,
reInitDelay: 2 * time.Second, resendDelay: 5,
resendDelay: 5 * time.Second,
consumerRateLimit: time.Millisecond * 500, consumerRateLimit: time.Millisecond * 500,
consumerBurstSize: 10, consumerBurstSize: 10,
logger: log.New(io.Discard, "", 0),
} }
for _, opt := range opts { for _, opt := range opts {
opt(&o) opt(&o)
} }
client.logger = o.logger
if err = client.connectAndSignal(addr, o.connectTimeout); err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
}
go client.handleReconnect(addr) go client.handleReconnect(addr)
return &client, nil return &client, nil
@ -94,37 +68,6 @@ func NewPublisher(client *Client) Publisher {
return &pubHandler{client: client} return &pubHandler{client: client}
} }
func NewConsumer(client *Client) Consumer { func NewConsumer(ctx context.Context, client *Client, chanLen int) Consumer {
return &consumeHandler{client: client} return &consumeHandler{ctx: ctx, client: client, chanLen: chanLen}
}
func (c *Client) connectAndSignal(addr string, timeout time.Duration) error {
type result struct {
conn *amqp.Connection
err error
}
resCh := make(chan result, 1)
go func() {
conn, err := amqp.Dial(addr)
resCh <- result{conn, err}
}()
select {
case <-time.After(timeout):
return fmt.Errorf("connection timeout after %v", timeout)
case res := <-resCh:
if res.err != nil {
return res.err
}
c.changeConnection(res.conn)
if err := c.init(res.conn); err != nil {
res.conn.Close()
return fmt.Errorf("init failed: %w", err)
}
close(c.connected)
return nil
}
} }

View file

@ -3,7 +3,7 @@ package rabbit
import "errors" import "errors"
var ( var (
errBadAddr = errors.New("bad address") errNoAddr = errors.New("no address")
errNoQueue = errors.New("no queue") errNoQueue = errors.New("no queue")
errNotConnected = errors.New("not connected to a server") errNotConnected = errors.New("not connected to a server")
errAlreadyClosed = errors.New("already closed: not connected to the server") errAlreadyClosed = errors.New("already closed: not connected to the server")

View file

@ -1,20 +1,9 @@
package rabbit package rabbit
import ( import "time"
"io"
"log"
"os"
"time"
)
type Option func(*options) type Option func(*options)
func WithConnectTimeout(t time.Duration) Option {
return func(op *options) {
op.connectTimeout = t
}
}
func WithReconnectDelay(t time.Duration) Option { func WithReconnectDelay(t time.Duration) Option {
return func(op *options) { return func(op *options) {
op.reconnectDelay = t op.reconnectDelay = t
@ -44,17 +33,3 @@ func WithConsumerBurstSize(t int) Option {
op.consumerBurstSize = t op.consumerBurstSize = t
} }
} }
func WithLogger(l *log.Logger) Option {
return func(op *options) { op.logger = l }
}
func WithLogging(enabled bool) Option {
return func(op *options) {
if enabled {
op.logger = log.New(os.Stdout, "", log.LstdFlags)
} else {
op.logger = log.New(io.Discard, "", 0)
}
}
}

View file

@ -1,44 +1,19 @@
package rabbit package rabbit
import ( import (
"context"
"errors" "errors"
"time" "time"
) )
type Publisher interface { type Publisher interface {
Start(ctx context.Context, chanLen uint) chan<- []byte Push(data []byte) error
} }
type pubHandler struct { type pubHandler struct {
client *Client client *Client
} }
func (p *pubHandler) Start(ctx context.Context, chanLen uint) chan<- []byte { func (p *pubHandler) Push(data []byte) error {
ch := make(chan []byte, chanLen)
go func() {
for {
select {
case <-ctx.Done():
for msg := range ch {
if err := p.push(msg); err != nil {
p.client.logger.Printf("Error publishing message (shutdown): %s", err)
}
}
p.client.logger.Println("Publisher stopped")
return
case msg := <-ch:
if err := p.push(msg); err != nil {
p.client.logger.Printf("Error publishing message: %s", err)
}
}
}
}()
return ch
}
func (p *pubHandler) push(data []byte) error {
p.client.mutex.Lock() p.client.mutex.Lock()
if !p.client.isReady { if !p.client.isReady {
p.client.mutex.Unlock() p.client.mutex.Unlock()

View file

@ -8,6 +8,10 @@ import (
func (c *Client) handleReconnect(addr string) { func (c *Client) handleReconnect(addr string) {
for { for {
c.mutex.Lock()
c.isReady = false
c.mutex.Unlock()
c.logger.Println("Connecting to server...") c.logger.Println("Connecting to server...")
conn, err := c.connect(addr) conn, err := c.connect(addr)
@ -41,6 +45,10 @@ func (c *Client) connect(addr string) (*amqp.Connection, error) {
func (c *Client) handleReInit(conn *amqp.Connection) bool { func (c *Client) handleReInit(conn *amqp.Connection) bool {
for { for {
c.mutex.Lock()
c.isReady = false
c.mutex.Unlock()
if err := c.init(conn); err != nil { if err := c.init(conn); err != nil {
c.logger.Printf("Failed to initialize connection: %s", err) c.logger.Printf("Failed to initialize connection: %s", err)
@ -78,16 +86,7 @@ func (c *Client) init(conn *amqp.Connection) error {
return err return err
} }
//_, err = ch.QueueDeclare(c.queueName, false, false, false, false, nil) _, err = ch.QueueDeclare(c.queueName, false, false, false, false, nil)
_, err = ch.QueueDeclare(
c.queueOptions.QueueName,
c.queueOptions.Durable,
c.queueOptions.AutoDelete,
c.queueOptions.Exclusive,
c.queueOptions.NoWait,
c.queueOptions.Args,
)
if err != nil { if err != nil {
return err return err
} }
@ -129,7 +128,7 @@ func (c *Client) unsafePush(data []byte) error {
return c.Channel.PublishWithContext( return c.Channel.PublishWithContext(
ctx, ctx,
"", "",
c.queueOptions.QueueName, c.queueName,
false, false,
false, false,
amqp.Publishing{ amqp.Publishing{
@ -156,7 +155,7 @@ func (c *Client) consume() (<-chan amqp.Delivery, error) {
} }
return c.Channel.Consume( return c.Channel.Consume(
c.queueOptions.QueueName, c.queueName,
"", "",
false, false,
false, false,