garm/websocket/websocket.go
Gabriel Adrian Samfira 4d7fcbe23a Safely close the quit channel
Prevent accidental closure of an already closed channel.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
2024-01-06 14:15:52 +00:00

143 lines
2.5 KiB
Go

package websocket
import (
"context"
"fmt"
"sync"
"time"
)
// NewHub builds a Hub bound to ctx. The client map and every channel are
// allocated here; NewHub itself does not start any goroutine.
func NewHub(ctx context.Context) *Hub {
	// Capacity shared by the broadcast, register and unregister queues.
	const queueSize = 100

	hub := &Hub{ctx: ctx}
	hub.clients = make(map[string]*Client)
	hub.broadcast = make(chan []byte, queueSize)
	hub.register = make(chan *Client, queueSize)
	hub.unregister = make(chan *Client, queueSize)
	// Signalling channels: unbuffered, only ever closed.
	hub.closed = make(chan struct{})
	hub.quit = make(chan struct{})
	return hub
}
// Hub fans messages out to a set of registered clients. Its channels are
// drained by the run loop.
type Hub struct {
	// ctx ends the run loop when it is done.
	ctx context.Context
	// closed is closed by the run loop when it returns; quit is closed by
	// Close to ask the run loop to return.
	closed, quit chan struct{}

	// clients holds the registered clients, keyed by client id.
	clients map[string]*Client

	// broadcast queues the messages accepted by Write until the run loop
	// delivers them to every registered client.
	broadcast chan []byte

	// register and unregister queue clients that the run loop should add to
	// or remove from the clients map.
	register, unregister chan *Client

	// mux is held while the clients map is modified.
	mux sync.Mutex
	// once ensures quit is closed a single time.
	once sync.Once
}
// run is the hub's event loop. It adds and removes clients and delivers every
// broadcast message to all registered clients, until quit is closed or ctx is
// done. The closed channel is closed when run returns.
//
// A client that does not accept a message within sendTimeout is treated as
// stale: its connection and send channel are closed and it is dropped from
// the hub once the current message has been offered to every client.
//
// Shutdown is also honoured while a delivery is waiting on a slow client, so
// a stalled peer cannot delay the loop's exit by the full timeout.
func (h *Hub) run() {
	// How long one client may stall a broadcast before it is dropped.
	const sendTimeout = 5 * time.Second

	defer close(h.closed)

	for {
		select {
		case <-h.quit:
			return
		case <-h.ctx.Done():
			return
		case client := <-h.register:
			if client == nil {
				continue
			}
			h.mux.Lock()
			h.clients[client.id] = client
			h.mux.Unlock()
		case client := <-h.unregister:
			if client == nil {
				continue
			}
			h.mux.Lock()
			// Only tear the client down if its id is still registered, so a
			// repeated unregister does not close the send channel twice.
			if _, ok := h.clients[client.id]; ok {
				client.conn.Close()
				close(client.send)
				delete(h.clients, client.id)
			}
			h.mux.Unlock()
		case message := <-h.broadcast:
			var stale []string
			for id, client := range h.clients {
				if client == nil {
					stale = append(stale, id)
					continue
				}

				// Fast path: the client has room for the message, so no
				// timer needs to be allocated at all.
				select {
				case client.send <- message:
					continue
				default:
				}

				// Slow path: wait for the client, but give up on timeout
				// and bail out immediately if the hub is shutting down.
				timer := time.NewTimer(sendTimeout)
				select {
				case client.send <- message:
				case <-timer.C:
					stale = append(stale, id)
				case <-h.quit:
					timer.Stop()
					return
				case <-h.ctx.Done():
					timer.Stop()
					return
				}
				timer.Stop()
			}

			if len(stale) == 0 {
				continue
			}

			// Removal is deferred until after the range loop so the map is
			// only modified while mux is held.
			h.mux.Lock()
			for _, id := range stale {
				if client, ok := h.clients[id]; ok {
					if client != nil {
						client.conn.Close()
						close(client.send)
					}
					delete(h.clients, id)
				}
			}
			h.mux.Unlock()
		}
	}
}
// Register queues client to be added to the hub by the run loop.
//
// It returns an error instead of queueing when the hub has been closed or its
// context is done. Without that check the send would either park the client
// in a queue nobody reads or, once the queue is full, block forever.
// Otherwise it blocks until there is room in the register queue and returns
// nil.
func (h *Hub) Register(client *Client) error {
	// Check for shutdown first so a closed hub is reported deterministically
	// even while the register queue still has free slots.
	select {
	case <-h.quit:
		return fmt.Errorf("registering client: hub is stopped")
	case <-h.ctx.Done():
		return fmt.Errorf("registering client: %w", h.ctx.Err())
	default:
	}

	select {
	case h.register <- client:
		return nil
	case <-h.quit:
		return fmt.Errorf("registering client: hub is stopped")
	case <-h.ctx.Done():
		return fmt.Errorf("registering client: %w", h.ctx.Err())
	}
}
// Write queues a copy of msg on the broadcast channel.
//
// It returns len(msg) and a nil error once the message is queued. If the
// queue stays full for 5 seconds it returns 0 and a timeout error.
func (h *Hub) Write(msg []byte) (int, error) {
	// Queue a private copy so the queued data does not alias the caller's
	// buffer.
	tmp := make([]byte, len(msg))
	copy(tmp, msg)

	// Fast path: there is room in the queue, so no timer is needed.
	select {
	case h.broadcast <- tmp:
		return len(tmp), nil
	default:
	}

	// Slow path: wait for room, releasing the timer as soon as we are done
	// rather than leaving it pending for the full timeout.
	timer := time.NewTimer(5 * time.Second)
	defer timer.Stop()

	select {
	case h.broadcast <- tmp:
		return len(tmp), nil
	case <-timer.C:
		return 0, fmt.Errorf("timed out sending message to client")
	}
}
// Start launches the hub's run loop in a new goroutine and returns without
// waiting for it. The returned error is always nil.
//
// NOTE(review): nothing here guards against repeated calls; each call starts
// another run loop on the same hub.
func (h *Hub) Start() error {
go h.run()
return nil
}
// Close asks the run loop to exit by closing the quit channel. It may be
// called any number of times; only the first call closes the channel. Close
// does not wait for the loop to finish and always returns nil.
func (h *Hub) Close() error {
	signalQuit := func() {
		close(h.quit)
	}
	h.once.Do(signalQuit)
	return nil
}
// Stop closes the hub and then waits for the run loop to report that it has
// exited. It returns nil once the closed channel is closed, or an error if
// that has not happened within 60 seconds.
func (h *Hub) Stop() error {
	_ = h.Close()

	expired := time.After(60 * time.Second)
	select {
	case <-expired:
		return fmt.Errorf("timed out waiting for hub stop")
	case <-h.closed:
	}
	return nil
}
// Wait blocks until the run loop has exited or 60 seconds have elapsed,
// whichever happens first. It does not tell the caller which one occurred.
func (h *Hub) Wait() {
	giveUp := time.After(60 * time.Second)
	select {
	case <-giveUp:
	case <-h.closed:
	}
}