144 lines
3.6 KiB
Go
144 lines
3.6 KiB
Go
package chat
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strconv"
|
|
|
|
"juwan-backend/app/chat/api/internal/svc"
|
|
|
|
"github.com/wwweww/go-wst/hybrid"
|
|
"github.com/wwweww/go-wst/protocol"
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
)
|
|
|
|
// WsMessage is an inbound frame sent by a chat client and decoded by
// Handler.OnMessage. Type selects the action; the remaining fields are
// optional and are interpreted by the handler for that action.
type WsMessage struct {
	// Type is the action name: "create_group", "create_dm", "join",
	// "leave", "message" or "history".
	Type string `json:"type"`
	// SessionId identifies the chat session the action applies to.
	SessionId int64 `json:"sessionId,omitempty"`
	// TargetId presumably names the peer user (e.g. for "create_dm") —
	// TODO confirm against handleCreateDM.
	TargetId int64 `json:"targetId,omitempty"`
	// Content is the message text.
	Content string `json:"content,omitempty"`
	// Name presumably carries a group name for "create_group" — verify.
	Name string `json:"name,omitempty"`
	// MsgType is a client-supplied sub-type of the message payload.
	MsgType string `json:"msgType,omitempty"`
}
|
|
|
|
// WsResponse is an outbound frame sent to chat clients, either directly on a
// connection (SendJSON) or JSON-encoded for broadcast and offline storage.
// Field order matters: it fixes the key order of the encoded JSON.
type WsResponse struct {
	// Type names the frame, e.g. "connected" or "error".
	Type string `json:"type"`
	// SessionId is the chat session the frame refers to, if any.
	SessionId int64 `json:"sessionId,omitempty"`
	// SenderId is the originating user, if any.
	SenderId int64 `json:"senderId,omitempty"`
	// Content is a human-readable text payload.
	Content string `json:"content,omitempty"`
	// Data is an arbitrary structured payload, encoded as-is.
	Data interface{} `json:"data,omitempty"`
}
|
|
|
|
type Handler struct {
|
|
svcCtx *svc.ServiceContext
|
|
server *hybrid.Server
|
|
}
|
|
|
|
var _ protocol.StatefulHandler = (*Handler)(nil)
|
|
var _ protocol.StatelessHandler = (*Handler)(nil)
|
|
|
|
func NewHandler(svcCtx *svc.ServiceContext) *Handler {
|
|
return &Handler{
|
|
svcCtx: svcCtx,
|
|
}
|
|
}
|
|
|
|
func (h *Handler) SetServer(s *hybrid.Server) {
|
|
h.server = s
|
|
}
|
|
|
|
func (h *Handler) OnConnect(conn protocol.Connection) error {
|
|
logx.Infof("chat connected: id=%s userID=%s protocol=%s", conn.ID(), conn.UserID(), conn.Protocol())
|
|
if uid := conn.UserID(); uid != "" {
|
|
h.server.BindUser(conn, uid)
|
|
}
|
|
return conn.SendJSON(context.Background(), WsResponse{
|
|
Type: "connected",
|
|
Content: "chat service connected",
|
|
})
|
|
}
|
|
|
|
func (h *Handler) OnMessage(conn protocol.Connection, raw []byte) error {
|
|
var msg WsMessage
|
|
if err := json.Unmarshal(raw, &msg); err != nil {
|
|
return conn.SendJSON(context.Background(), WsResponse{
|
|
Type: "error",
|
|
Content: "invalid message format",
|
|
})
|
|
}
|
|
|
|
switch msg.Type {
|
|
case "create_group":
|
|
return h.handleCreateGroup(conn, &msg)
|
|
case "create_dm":
|
|
return h.handleCreateDM(conn, &msg)
|
|
case "join":
|
|
return h.handleJoin(conn, &msg)
|
|
case "leave":
|
|
return h.handleLeave(conn, &msg)
|
|
case "message":
|
|
return h.handleMessage(conn, &msg)
|
|
case "history":
|
|
return h.handleHistory(conn, &msg)
|
|
default:
|
|
return conn.SendJSON(context.Background(), WsResponse{
|
|
Type: "error",
|
|
Content: fmt.Sprintf("unknown message type: %s", msg.Type),
|
|
})
|
|
}
|
|
}
|
|
|
|
func (h *Handler) OnDisconnect(conn protocol.Connection, err error) {
|
|
logx.Infof("chat disconnected: userID=%s err=%v", conn.UserID(), err)
|
|
}
|
|
|
|
func (h *Handler) Fetch(ctx context.Context, req protocol.FetchRequest) ([]protocol.Message, error) {
|
|
return h.svcCtx.MsgStore.Fetch(ctx, req.UserID, req.SinceID, req.Limit)
|
|
}
|
|
|
|
func (h *Handler) Send(ctx context.Context, req protocol.SendRequest) error {
|
|
msg := protocol.Message{
|
|
Type: "message",
|
|
Topic: req.Topic,
|
|
Data: req.Data,
|
|
}
|
|
return h.svcCtx.MsgStore.Store(ctx, req.UserID, msg)
|
|
}
|
|
|
|
func (h *Handler) getUserId(conn protocol.Connection) int64 {
|
|
uid, _ := strconv.ParseInt(conn.UserID(), 10, 64)
|
|
return uid
|
|
}
|
|
|
|
func (h *Handler) broadcastToParticipants(participants []int64, resp WsResponse) {
|
|
data, err := json.Marshal(resp)
|
|
if err != nil {
|
|
logx.Errorf("marshal error: %v", err)
|
|
return
|
|
}
|
|
|
|
userIDs := make([]string, len(participants))
|
|
for i, p := range participants {
|
|
userIDs[i] = strconv.FormatInt(p, 10)
|
|
}
|
|
|
|
if err := h.server.BroadcastTo(userIDs, data); err != nil {
|
|
logx.Errorf("broadcastTo failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func (h *Handler) storeOfflineMessage(userID string, resp WsResponse) {
|
|
data, err := json.Marshal(resp)
|
|
if err != nil {
|
|
return
|
|
}
|
|
msg := protocol.Message{
|
|
Type: "chat",
|
|
Data: data,
|
|
}
|
|
if storeErr := h.svcCtx.MsgStore.Store(context.Background(), userID, msg); storeErr != nil {
|
|
logx.Errorf("store offline msg for %s failed: %v", userID, storeErr)
|
|
}
|
|
}
|