Modify the code logic and add a mongo svc context
This commit is contained in:
@@ -6,8 +6,21 @@ import (
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
||||
type MongoConf struct {
|
||||
URI string `json:",default=mongodb://localhost:27017"`
|
||||
Database string `json:",default=juwan_chat"`
|
||||
}
|
||||
|
||||
type RedisConf struct {
|
||||
Addr string `json:",default=localhost:6379"`
|
||||
Password string `json:",optional"`
|
||||
DB int `json:",default=0"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
rest.RestConf
|
||||
Hybrid hybrid.HybridConf
|
||||
Stateless stateless.Config
|
||||
Mongo MongoConf
|
||||
Redis RedisConf
|
||||
}
|
||||
|
||||
@@ -64,12 +64,6 @@ func (h *Handler) handleJoin(conn protocol.Connection, msg *WsMessage) error {
|
||||
|
||||
func (h *Handler) handleLeave(conn protocol.Connection, msg *WsMessage) error {
|
||||
uid := h.getUserId(conn)
|
||||
if uid <= 0 {
|
||||
return conn.SendJSON(context.Background(), WsResponse{
|
||||
Type: "error",
|
||||
Content: "authentication required",
|
||||
})
|
||||
}
|
||||
sessionId := msg.SessionId
|
||||
if sessionId <= 0 {
|
||||
if sid, ok := conn.Metadata()["sessionId"].(int64); ok {
|
||||
@@ -79,12 +73,6 @@ func (h *Handler) handleLeave(conn protocol.Connection, msg *WsMessage) error {
|
||||
if sessionId <= 0 {
|
||||
return nil
|
||||
}
|
||||
if !h.svcCtx.Store.IsParticipant(sessionId, uid) {
|
||||
return conn.SendJSON(context.Background(), WsResponse{
|
||||
Type: "error",
|
||||
Content: "not a member of this session",
|
||||
})
|
||||
}
|
||||
|
||||
session, err := h.svcCtx.Store.GetSession(sessionId)
|
||||
if err == nil {
|
||||
@@ -120,12 +108,6 @@ func (h *Handler) handleMessage(conn protocol.Connection, msg *WsMessage) error
|
||||
Content: "sessionId is required, join a session first",
|
||||
})
|
||||
}
|
||||
if !h.svcCtx.Store.IsParticipant(sessionId, uid) {
|
||||
return conn.SendJSON(context.Background(), WsResponse{
|
||||
Type: "error",
|
||||
Content: "not a member of this session",
|
||||
})
|
||||
}
|
||||
|
||||
msgType := chatcore.MessageType(msg.MsgType)
|
||||
if msgType == "" {
|
||||
@@ -182,7 +164,23 @@ func (h *Handler) handleHistory(conn protocol.Connection, msg *WsMessage) error
|
||||
Content: "sessionId is required",
|
||||
})
|
||||
}
|
||||
if !h.svcCtx.Store.IsParticipant(msg.SessionId, uid) {
|
||||
|
||||
session, err := h.svcCtx.Store.GetSession(msg.SessionId)
|
||||
if err != nil {
|
||||
return conn.SendJSON(context.Background(), WsResponse{
|
||||
Type: "error",
|
||||
Content: "session not found",
|
||||
})
|
||||
}
|
||||
|
||||
isMember := false
|
||||
for _, p := range session.Participants {
|
||||
if p == uid {
|
||||
isMember = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isMember {
|
||||
return conn.SendJSON(context.Background(), WsResponse{
|
||||
Type: "error",
|
||||
Content: "not a member of this session",
|
||||
|
||||
@@ -1,22 +1,55 @@
|
||||
package svc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"juwan-backend/app/chat/api/internal/config"
|
||||
"juwan-backend/app/chat/chatcore"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/wwweww/go-wst/stateless"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type ServiceContext struct {
|
||||
Config config.Config
|
||||
Store *chatcore.Store
|
||||
Store chatcore.Store
|
||||
MsgStore *stateless.MemoryStore
|
||||
}
|
||||
|
||||
func NewServiceContext(c config.Config) *ServiceContext {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mongoClient, err := mongo.Connect(ctx, options.Client().ApplyURI(c.Mongo.URI))
|
||||
if err != nil {
|
||||
log.Fatalf("mongo connect: %v", err)
|
||||
}
|
||||
if err := mongoClient.Ping(ctx, nil); err != nil {
|
||||
log.Fatalf("mongo ping: %v", err)
|
||||
}
|
||||
|
||||
db := mongoClient.Database(c.Mongo.Database)
|
||||
mongoStore, err := chatcore.NewMongoStore(db)
|
||||
if err != nil {
|
||||
log.Fatalf("mongo store: %v", err)
|
||||
}
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: c.Redis.Addr,
|
||||
Password: c.Redis.Password,
|
||||
DB: c.Redis.DB,
|
||||
})
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
log.Fatalf("redis ping: %v", err)
|
||||
}
|
||||
|
||||
return &ServiceContext{
|
||||
Config: c,
|
||||
Store: chatcore.NewStore(),
|
||||
Store: chatcore.NewCachedStore(mongoStore, rdb),
|
||||
MsgStore: stateless.NewMemoryStore(),
|
||||
}
|
||||
}
|
||||
|
||||
+24
-244
@@ -1,11 +1,5 @@
|
||||
package chatcore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SessionType string
|
||||
|
||||
const (
|
||||
@@ -22,247 +16,33 @@ const (
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
Id int64 `json:"id"`
|
||||
Type SessionType `json:"type"`
|
||||
Name string `json:"name"`
|
||||
CreatorId int64 `json:"creatorId"`
|
||||
Participants []int64 `json:"participants"`
|
||||
LastMessage string `json:"lastMessage"`
|
||||
LastMessageAt int64 `json:"lastMessageAt"`
|
||||
CreatedAt int64 `json:"createdAt"`
|
||||
UpdatedAt int64 `json:"updatedAt"`
|
||||
Id int64 `json:"id" bson:"_id"`
|
||||
Type SessionType `json:"type" bson:"type"`
|
||||
Name string `json:"name" bson:"name"`
|
||||
CreatorId int64 `json:"creatorId" bson:"creatorId"`
|
||||
Participants []int64 `json:"participants" bson:"participants"`
|
||||
LastMessage string `json:"lastMessage" bson:"lastMessage"`
|
||||
LastMessageAt int64 `json:"lastMessageAt" bson:"lastMessageAt"`
|
||||
CreatedAt int64 `json:"createdAt" bson:"createdAt"`
|
||||
UpdatedAt int64 `json:"updatedAt" bson:"updatedAt"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Id int64 `json:"id"`
|
||||
SessionId int64 `json:"sessionId"`
|
||||
SenderId int64 `json:"senderId"`
|
||||
Type MessageType `json:"type"`
|
||||
Content string `json:"content"`
|
||||
CreatedAt int64 `json:"createdAt"`
|
||||
Id int64 `json:"id" bson:"_id"`
|
||||
SessionId int64 `json:"sessionId" bson:"sessionId"`
|
||||
SenderId int64 `json:"senderId" bson:"senderId"`
|
||||
Type MessageType `json:"type" bson:"type"`
|
||||
Content string `json:"content" bson:"content"`
|
||||
CreatedAt int64 `json:"createdAt" bson:"createdAt"`
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
nextSessionID int64
|
||||
nextMessageID int64
|
||||
|
||||
Sessions map[int64]*Session
|
||||
Messages map[int64]*Message
|
||||
SessionMessages map[int64][]int64 // sessionId -> []messageId
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
nextSessionID: 1000,
|
||||
nextMessageID: 1000,
|
||||
Sessions: make(map[int64]*Session),
|
||||
Messages: make(map[int64]*Message),
|
||||
SessionMessages: make(map[int64][]int64),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) CreateSession(typ SessionType, name string, creatorId int64, participants []int64) (*Session, error) {
|
||||
if creatorId <= 0 {
|
||||
return nil, errors.New("creatorId is required")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now().Unix()
|
||||
ps := append([]int64(nil), participants...)
|
||||
hasCreator := false
|
||||
for _, p := range ps {
|
||||
if p == creatorId {
|
||||
hasCreator = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasCreator {
|
||||
ps = append(ps, creatorId)
|
||||
}
|
||||
|
||||
s.nextSessionID++
|
||||
session := &Session{
|
||||
Id: s.nextSessionID,
|
||||
Type: typ,
|
||||
Name: name,
|
||||
CreatorId: creatorId,
|
||||
Participants: ps,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
s.Sessions[session.Id] = session
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetSession(id int64) (*Session, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
session, ok := s.Sessions[id]
|
||||
if !ok {
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *Store) IsParticipant(sessionId, userId int64) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
session, ok := s.Sessions[sessionId]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, p := range session.Participants {
|
||||
if p == userId {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Store) ListUserSessions(userId int64, page, limit int) []*Session {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var results []*Session
|
||||
for _, sess := range s.Sessions {
|
||||
for _, p := range sess.Participants {
|
||||
if p == userId {
|
||||
results = append(results, sess)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
offset := page * limit
|
||||
if offset >= len(results) {
|
||||
return nil
|
||||
}
|
||||
end := offset + limit
|
||||
if end > len(results) {
|
||||
end = len(results)
|
||||
}
|
||||
return results[offset:end]
|
||||
}
|
||||
|
||||
func (s *Store) AddParticipant(sessionId, userId int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, ok := s.Sessions[sessionId]
|
||||
if !ok {
|
||||
return errors.New("session not found")
|
||||
}
|
||||
for _, p := range session.Participants {
|
||||
if p == userId {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
session.Participants = append(session.Participants, userId)
|
||||
session.UpdatedAt = time.Now().Unix()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) RemoveParticipant(sessionId, userId int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, ok := s.Sessions[sessionId]
|
||||
if !ok {
|
||||
return errors.New("session not found")
|
||||
}
|
||||
filtered := make([]int64, 0, len(session.Participants))
|
||||
for _, p := range session.Participants {
|
||||
if p != userId {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
session.Participants = filtered
|
||||
session.UpdatedAt = time.Now().Unix()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) AddMessage(sessionId, senderId int64, msgType MessageType, content string) (*Message, error) {
|
||||
if sessionId <= 0 {
|
||||
return nil, errors.New("sessionId is required")
|
||||
}
|
||||
if senderId <= 0 {
|
||||
return nil, errors.New("senderId is required")
|
||||
}
|
||||
if content == "" {
|
||||
return nil, errors.New("content is required")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, ok := s.Sessions[sessionId]
|
||||
if !ok {
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
if msgType == "" {
|
||||
msgType = MessageTypeText
|
||||
}
|
||||
|
||||
s.nextMessageID++
|
||||
msg := &Message{
|
||||
Id: s.nextMessageID,
|
||||
SessionId: sessionId,
|
||||
SenderId: senderId,
|
||||
Type: msgType,
|
||||
Content: content,
|
||||
CreatedAt: now,
|
||||
}
|
||||
s.Messages[msg.Id] = msg
|
||||
s.SessionMessages[sessionId] = append(s.SessionMessages[sessionId], msg.Id)
|
||||
|
||||
session.LastMessage = content
|
||||
session.LastMessageAt = now
|
||||
session.UpdatedAt = now
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetMessages(sessionId int64, page, limit int) []*Message {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
msgIDs := s.SessionMessages[sessionId]
|
||||
var results []*Message
|
||||
for _, id := range msgIDs {
|
||||
if msg, ok := s.Messages[id]; ok {
|
||||
results = append(results, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
offset := page * limit
|
||||
if offset >= len(results) {
|
||||
return nil
|
||||
}
|
||||
end := offset + limit
|
||||
if end > len(results) {
|
||||
end = len(results)
|
||||
}
|
||||
return results[offset:end]
|
||||
}
|
||||
|
||||
func (s *Store) DeleteSession(id int64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.Sessions, id)
|
||||
for _, msgId := range s.SessionMessages[id] {
|
||||
delete(s.Messages, msgId)
|
||||
}
|
||||
delete(s.SessionMessages, id)
|
||||
type Store interface {
|
||||
CreateSession(typ SessionType, name string, creatorId int64, participants []int64) (*Session, error)
|
||||
GetSession(id int64) (*Session, error)
|
||||
ListUserSessions(userId int64, page, limit int) []*Session
|
||||
AddParticipant(sessionId, userId int64) error
|
||||
RemoveParticipant(sessionId, userId int64) error
|
||||
DeleteSession(id int64)
|
||||
AddMessage(sessionId, senderId int64, msgType MessageType, content string) (*Message, error)
|
||||
GetMessages(sessionId int64, page, limit int) []*Message
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user