diff --git a/channel.go b/channel.go new file mode 100644 index 0000000..b1fc4fa --- /dev/null +++ b/channel.go @@ -0,0 +1,40 @@ +package webreal + +import "sync" + +type Channel struct { + mu sync.RWMutex + subscribers map[string]*Client +} + +func NewChannel() *Channel { + return &Channel{ + subscribers: map[string]*Client{}, + } +} + +// 添加客户端 +func (c *Channel) Add(client *Client) { + c.mu.Lock() + defer c.mu.Unlock() + + c.subscribers[client.id] = client +} + +// 移除客户端 +func (c *Channel) Remove(client *Client) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.subscribers, client.id) +} + +// 遍历客户端 +func (c *Channel) Range(f func(client *Client)) { + c.mu.RLock() + defer c.mu.RUnlock() + + for _, client := range c.subscribers { + f(client) + } +} diff --git a/client.go b/client.go index b34fe75..7220f77 100644 --- a/client.go +++ b/client.go @@ -103,19 +103,16 @@ func (c *Client) close() { } // 订阅 -func (c *Client) Subscribe(channel string) bool { +func (c *Client) Subscribe(channel string) { c.mu.Lock() defer c.mu.Unlock() - _, found := c.channels[channel] - if found { - return false + if _, found := c.channels[channel]; found { + return } - if c.hub.Subscribe(channel, c) { - c.channels[channel] = struct{}{} - return true - } - return false + + c.hub.Subscribe(channel, c) + c.channels[channel] = struct{}{} } // 退订 @@ -123,9 +120,8 @@ func (c *Client) Unsubscribe(channel string) { c.mu.Lock() defer c.mu.Unlock() - if c.hub.Unsubscribe(channel, c) { - delete(c.channels, channel) - } + c.hub.Unsubscribe(channel, c) + delete(c.channels, channel) } // 退订所有 @@ -139,26 +135,6 @@ func (c *Client) UnsubscribeAll() { c.channels = map[string]struct{}{} } -// 获取已订阅的主题列表 -func (c *Client) Channels() []string { - c.mu.Lock() - defer c.mu.Unlock() - - var channels []string - for key := range c.channels { - channels = append(channels, key) - } - return channels -} - -// 获取已订阅长度 -func (c *Client) ChannelsLen() int { - c.mu.Lock() - defer c.mu.Unlock() - - return len(c.channels) -} - // 获取客户端ID func (c *Client) ID() string { return c.id diff --git a/subscription.go b/subscription.go index 2b90107..d876dcc 100644 --- a/subscription.go +++ b/subscription.go @@ -4,57 +4,44 @@ import "sync" // 订阅中心,存储的所有频道和客户端订阅的对应关系 type SubscriptionHub struct { - mu sync.RWMutex - - // channel => clientId => Client - subscribers map[string]map[string]*Client + mu sync.Mutex + channels map[string]*Channel } func NewSubscriptionHub() *SubscriptionHub { return &SubscriptionHub{ - subscribers: map[string]map[string]*Client{}, + channels: map[string]*Channel{}, } } // 订阅主题 -func (s *SubscriptionHub) Subscribe(channel string, client *Client) bool { +func (s *SubscriptionHub) Subscribe(channel string, client *Client) { s.mu.Lock() defer s.mu.Unlock() - if _, ok := s.subscribers[channel]; !ok { - s.subscribers[channel] = map[string]*Client{} + if _, found := s.channels[channel]; !found { + s.channels[channel] = NewChannel() } - s.subscribers[channel][client.id] = client - return true + s.channels[channel].Add(client) } // 退订主题 -func (s *SubscriptionHub) Unsubscribe(channel string, client *Client) bool { - s.mu.Lock() - defer s.mu.Unlock() - - if _, ok := s.subscribers[channel]; !ok { - return false - } - if _, ok := s.subscribers[channel][client.id]; !ok { - return false +func (s *SubscriptionHub) Unsubscribe(channel string, client *Client) { + if _, found := s.channels[channel]; !found { + return } - delete(s.subscribers[channel], client.id) - return true + s.channels[channel].Remove(client) } // 向客户端推送主题消息 func (s *SubscriptionHub) Publish(channel string, msg []byte) { - s.mu.RLock() - defer s.mu.RUnlock() - - if _, ok := s.subscribers[channel]; !ok { + if _, found := s.channels[channel]; !found { return } - for _, c := range s.subscribers[channel] { + s.channels[channel].Range(func(c *Client) { c.Write(msg) - } + }) }