// SseEmitterHelper.java (3.5 KB)
  1. package com.webchat.common.helper;
  2. import lombok.extern.slf4j.Slf4j;
  3. import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
  4. import java.io.IOException;
  5. import java.util.concurrent.ConcurrentHashMap;
  6. import java.util.function.Consumer;
  7. /**
  8. * @Author 程序员七七
  9. * @webSite https://www.coderutil.com
  10. * @Date 2024/10/29 23:27
  11. * @description
  12. */
  13. @Slf4j
  14. public class SseEmitterHelper {
  15. /**
  16. * 维护用户对话的SSE
  17. */
  18. private static ConcurrentHashMap<String, ConcurrentHashMap<String, SseEmitter>> sseEmitterMap = new ConcurrentHashMap<>();
  19. /**
  20. * 判断当前用户SSE链接是否在当前节点
  21. *
  22. *
  23. * @param biz
  24. * @param userId
  25. * @return
  26. */
  27. public static boolean isExist(String biz, String userId) {
  28. ConcurrentHashMap<String, SseEmitter> userSseEmitter = sseEmitterMap.get(biz);
  29. return userSseEmitter.get(userId) != null;
  30. }
  31. /**
  32. * 获取用户的 SseEmitter 对象,如果不存在重新创建一个
  33. *
  34. * @param userId
  35. * @return
  36. */
  37. public static SseEmitter get(String biz, String userId) {
  38. ConcurrentHashMap<String, SseEmitter> userSseEmitter = sseEmitterMap.get(biz);
  39. if (userSseEmitter == null) {
  40. userSseEmitter = new ConcurrentHashMap<>();
  41. }
  42. SseEmitter sseEmitter = userSseEmitter.get(userId);
  43. if (sseEmitter == null) {
  44. sseEmitter = create(biz, userId);
  45. }
  46. return sseEmitter;
  47. }
  48. /**
  49. * 删除用户 SseEmitter 对象
  50. *
  51. * @param userId
  52. */
  53. public static void remove(String biz, String userId) {
  54. ConcurrentHashMap<String, SseEmitter> userSseEmitter = sseEmitterMap.get(biz);
  55. userSseEmitter.remove(userId);
  56. }
  57. /**
  58. * 创建SseEmitter
  59. *
  60. * @param userId
  61. * @return
  62. */
  63. private static SseEmitter create(String biz, String userId) {
  64. SseEmitter sseEmitter = new SseEmitter();
  65. ConcurrentHashMap<String, SseEmitter> userSseEmitter = sseEmitterMap.get(biz);
  66. if (userSseEmitter == null) {
  67. userSseEmitter = new ConcurrentHashMap<>();
  68. }
  69. userSseEmitter.put(userId, sseEmitter);
  70. sseEmitterMap.put(biz, userSseEmitter);
  71. sseEmitter.onCompletion(completionCallBack(biz, userId));
  72. sseEmitter.onError(errorCallBack(biz, userId));
  73. sseEmitter.onTimeout(timeoutCallBack(biz, userId));
  74. log.info("SSE Connection created =====> biz={}, userId={}", biz, userId);
  75. return sseEmitter;
  76. }
  77. private static Runnable completionCallBack(String biz, String userId) {
  78. return () -> {
  79. log.info("结束连接=====> userId={}", userId);
  80. remove(biz, userId);
  81. };
  82. }
  83. private static Runnable timeoutCallBack(String biz, String userId){
  84. return ()->{
  85. log.info("连接超时=====> userId={}", userId);
  86. remove(biz, userId);
  87. };
  88. }
  89. private static Consumer<Throwable> errorCallBack(String biz, String userId){
  90. return throwable -> {
  91. log.info("连接失败=====> userId={}", userId);
  92. remove(biz, userId);
  93. };
  94. }
  95. /**
  96. * sse 消息推送
  97. *
  98. * @param biz
  99. * @param userId
  100. * @param message
  101. */
  102. public static void send(String biz, String userId, String message) {
  103. try {
  104. SseEmitter sseEmitter = get(biz, userId);
  105. sseEmitter.send(message);
  106. sseEmitter.send("finished");
  107. } catch (IOException ex) {
  108. throw new RuntimeException(ex);
  109. }
  110. }
  111. }