195 lines
6.7 KiB
Java
195 lines
6.7 KiB
Java
package com.example.websocket;
|
||
|
||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||
import com.example.entity.ChatMessage;
|
||
import com.example.service.ChatMessageService;
|
||
import jakarta.annotation.Resource;
|
||
import jakarta.websocket.*;
|
||
import jakarta.websocket.server.PathParam;
|
||
import jakarta.websocket.server.ServerEndpoint;
|
||
import org.slf4j.Logger;
|
||
import org.slf4j.LoggerFactory;
|
||
import org.springframework.stereotype.Component;
|
||
|
||
import java.io.IOException;
|
||
import java.util.Date;
|
||
import java.util.Map;
|
||
import java.util.concurrent.ConcurrentHashMap;
|
||
|
||
@ServerEndpoint("/ws/chat/{userId}/{userType}")
|
||
@Component
|
||
public class ChatWebSocket {
|
||
|
||
private static final Logger log = LoggerFactory.getLogger(ChatWebSocket.class);
|
||
|
||
// 静态变量,用来记录当前在线连接数
|
||
private static int onlineCount = 0;
|
||
|
||
// 用户ID和WebSocket的映射关系
|
||
private static Map<String, ChatWebSocket> clients = new ConcurrentHashMap<>();
|
||
|
||
// 与某个客户端的连接会话,用于发送数据
|
||
private Session session;
|
||
|
||
// 当前连接用户ID
|
||
private Integer userId;
|
||
|
||
// 当前连接用户类型
|
||
private String userType;
|
||
|
||
// 注入Service,因为@ServerEndpoint不支持直接注入,需要通过静态变量
|
||
private static ChatMessageService chatMessageService;
|
||
|
||
private static final ObjectMapper objectMapper = new ObjectMapper();
|
||
|
||
@Resource
|
||
public void setChatMessageService(ChatMessageService chatMessageService) {
|
||
ChatWebSocket.chatMessageService = chatMessageService;
|
||
}
|
||
|
||
/**
|
||
* 连接建立成功调用的方法
|
||
*/
|
||
@OnOpen
|
||
public void onOpen(Session session, @PathParam("userId") Integer userId, @PathParam("userType") String userType) {
|
||
this.session = session;
|
||
this.userId = userId;
|
||
this.userType = userType;
|
||
|
||
// 将当前WebSocket对象加入到Map中
|
||
String key = userId + ":" + userType;
|
||
clients.put(key, this);
|
||
|
||
addOnlineCount();
|
||
log.info("有新连接加入,当前在线人数为:{}", getOnlineCount());
|
||
}
|
||
|
||
/**
|
||
* 连接关闭调用的方法
|
||
*/
|
||
@OnClose
|
||
public void onClose() {
|
||
// 从Map中移除
|
||
String key = userId + ":" + userType;
|
||
clients.remove(key);
|
||
|
||
subOnlineCount();
|
||
log.info("有一连接关闭,当前在线人数为:{}", getOnlineCount());
|
||
}
|
||
|
||
/**
|
||
* 收到客户端消息后调用的方法
|
||
*/
|
||
@OnMessage
|
||
public void onMessage(String message, Session session) {
|
||
log.info("收到来自用户{}:{}的消息:{}", userId, userType, message);
|
||
|
||
try {
|
||
// 解析消息
|
||
ChatMessage chatMessage = objectMapper.readValue(message, ChatMessage.class);
|
||
|
||
// 直接设置发送者信息,不再通过TokenUtils获取
|
||
chatMessage.setSenderId(userId);
|
||
chatMessage.setSenderType(userType);
|
||
chatMessage.setSendTime(new Date());
|
||
chatMessage.setIsRead(false);
|
||
|
||
// 保存临时ID,用于前端识别消息
|
||
String tempId = null;
|
||
try {
|
||
// 从消息中提取tempId字段
|
||
Map<String, Object> messageMap = objectMapper.readValue(message, Map.class);
|
||
if (messageMap.containsKey("tempId")) {
|
||
tempId = messageMap.get("tempId").toString();
|
||
}
|
||
} catch (Exception e) {
|
||
log.warn("提取tempId失败", e);
|
||
}
|
||
|
||
// 保存消息到数据库
|
||
try {
|
||
// 检查发送者和接收者是否有关联关系并保存消息
|
||
chatMessageService.directSaveMessage(chatMessage);
|
||
|
||
// 将保存后的消息转为Map,以便添加tempId
|
||
Map<String, Object> responseMap = objectMapper.convertValue(chatMessage, Map.class);
|
||
if (tempId != null) {
|
||
responseMap.put("tempId", tempId);
|
||
}
|
||
String responseJson = objectMapper.writeValueAsString(responseMap);
|
||
|
||
// 转发消息给接收者
|
||
String receiverKey = chatMessage.getReceiverId() + ":" + chatMessage.getReceiverType();
|
||
ChatWebSocket receiverSocket = clients.get(receiverKey);
|
||
if (receiverSocket != null) {
|
||
// 接收者在线,发送消息
|
||
receiverSocket.sendMessage(responseJson);
|
||
}
|
||
|
||
// 同时也返回给发送者
|
||
sendMessage(responseJson);
|
||
} catch (Exception e) {
|
||
log.error("处理消息时发生错误", e);
|
||
// 发送错误消息给客户端
|
||
ChatMessage errorMessage = new ChatMessage();
|
||
errorMessage.setSenderId(0);
|
||
errorMessage.setSenderType("system");
|
||
errorMessage.setReceiverId(userId);
|
||
errorMessage.setReceiverType(userType);
|
||
errorMessage.setContent("发送消息失败:" + e.getMessage());
|
||
errorMessage.setSendTime(new Date());
|
||
|
||
// 将错误消息转为Map,以便添加tempId
|
||
Map<String, Object> errorMap = objectMapper.convertValue(errorMessage, Map.class);
|
||
if (tempId != null) {
|
||
errorMap.put("tempId", tempId);
|
||
}
|
||
sendMessage(objectMapper.writeValueAsString(errorMap));
|
||
}
|
||
} catch (Exception e) {
|
||
log.error("处理消息时发生错误", e);
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 发生错误时调用
|
||
*/
|
||
@OnError
|
||
public void onError(Session session, Throwable error) {
|
||
log.error("发生错误", error);
|
||
}
|
||
|
||
/**
|
||
* 发送消息
|
||
*/
|
||
public void sendMessage(String message) {
|
||
try {
|
||
this.session.getBasicRemote().sendText(message);
|
||
} catch (IOException e) {
|
||
log.error("发送消息失败", e);
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 服务端主动推送消息
|
||
*/
|
||
public static void sendMessage(Integer userId, String userType, String message) {
|
||
String key = userId + ":" + userType;
|
||
ChatWebSocket socket = clients.get(key);
|
||
if (socket != null) {
|
||
socket.sendMessage(message);
|
||
}
|
||
}
|
||
|
||
public static synchronized int getOnlineCount() {
|
||
return onlineCount;
|
||
}
|
||
|
||
public static synchronized void addOnlineCount() {
|
||
ChatWebSocket.onlineCount++;
|
||
}
|
||
|
||
public static synchronized void subOnlineCount() {
|
||
ChatWebSocket.onlineCount--;
|
||
}
|
||
} |