diff --git a/websocket/client.go b/websocket/client.go index 5b80ba81..657aa49e 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -58,8 +58,6 @@ func NewClient(ctx context.Context, conn *websocket.Conn) (*Client, error) { userID: user, passwordGeneration: generation, consumer: consumer, - done: make(chan struct{}), - send: make(chan []byte, 100), }, nil } @@ -116,6 +114,8 @@ func (c *Client) Start() error { defer c.mux.Unlock() c.running = true + c.send = make(chan []byte, 100) + c.done = make(chan struct{}) go c.runWatcher() go c.clientReader() diff --git a/websocket/websocket.go b/websocket/websocket.go index 57820449..14b5e785 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -28,14 +28,15 @@ type Hub struct { // Inbound messages from the clients. broadcast chan []byte - mux sync.Mutex - once sync.Once + mux sync.Mutex + running bool + once sync.Once } func (h *Hub) run() { - defer func() { - close(h.closed) - }() + defer close(h.closed) + defer h.Stop() + for { select { case <-h.quit: @@ -59,8 +60,7 @@ func (h *Hub) run() { for _, id := range staleClients { if client, ok := h.clients[id]; ok { if client != nil { - client.conn.Close() - close(client.send) + client.Stop() } delete(h.clients, id) } @@ -105,6 +105,13 @@ func (h *Hub) Unregister(client *Client) error { } func (h *Hub) Write(msg []byte) (int, error) { + h.mux.Lock() + if !h.running { + h.mux.Unlock() + return 0, fmt.Errorf("websocket writer is not running") + } + h.mux.Unlock() + tmp := make([]byte, len(msg)) copy(tmp, msg) timer := time.NewTimer(5 * time.Second) @@ -118,6 +125,15 @@ func (h *Hub) Write(msg []byte) (int, error) { } func (h *Hub) Start() error { + h.mux.Lock() + defer h.mux.Unlock() + + if h.running { + return nil + } + + h.running = true + go h.run() return nil } @@ -130,11 +146,22 @@ func (h *Hub) Close() error { } func (h *Hub) Stop() error { + h.mux.Lock() + defer h.mux.Unlock() + + if !h.running { + return nil + } + + h.running = false h.Close() return h.Wait() } func (h *Hub) Wait() error { + if !h.running { + return nil + } timer := time.NewTimer(60 * time.Second) defer timer.Stop() select {