Просмотр исходного кода

文档向量化之前支持切片处理

wangqi49 2 недель назад
Родитель
Коммит
6a641c2f13

+ 1 - 1
webchat-aigc/src/main/java/com/webchat/aigc/llm/AiBotQAService.java

@@ -50,7 +50,7 @@ public class AiBotQAService {
     /**
      * Similarity limit condition (minimum score); how it is applied lies outside this hunk.
      */
-    private static final float MIN_SCORE = 0.8f;
+    private static final float MIN_SCORE = 0.4f;
 
     private static final String AIBOT_ROLE = "你是Chat4j项目的Ai助手,任务是根据用户给出的参考数据,基于输入query,精准的给出答案。"; // role text: tells the model it is the Chat4j AI assistant answering the query from user-supplied reference data
 

+ 8 - 3
webchat-aigc/src/main/java/com/webchat/aigc/llm/AlibabaEmbeddingModel.java

@@ -38,11 +38,16 @@ public class AlibabaEmbeddingModel extends AbstractEmbeddingModel {
     public List<float[]> embed(List<String> texts) throws Exception { // returns one float[] per embedding in the service result
         TextEmbeddingParam param = TextEmbeddingParam
                 .builder()
+                .apiKey(embeddingPropertiesConfig.getApiKey()) // API key is read from embeddingPropertiesConfig
                 .model(TextEmbedding.Models.TEXT_EMBEDDING_V3)
-                .texts(Arrays.asList("风急天高猿啸哀", "渚清沙白鸟飞回", "无边落木萧萧下", "不尽长江滚滚来")).build();
+                .texts(texts).build(); // embeds the caller-supplied texts
         TextEmbedding textEmbedding = new TextEmbedding();
         TextEmbeddingResult result = textEmbedding.call(param);
-        System.out.println(result);
-        return null;
+        return result.getOutput().getEmbeddings().stream().map(embed -> { // NOTE(review): no null check on result/getOutput() — confirm the SDK never returns null here
+            List<Double> embeddings = embed.getEmbedding();
+            List<Float> fEmbeddings = embeddings.stream().map(textEmbed ->
+                                                                    textEmbed.floatValue()).toList(); // narrows each Double component to Float
+            return ArrayUtils.toPrimitive(fEmbeddings.toArray(new Float[0])); // unboxes Float[] into the float[] the return type requires
+        }).toList();
     }
 }

Разница между файлами не показана из-за своего большого размера
+ 120 - 0
webchat-common/src/main/java/com/webchat/common/util/HtmlSplitter.java


+ 54 - 19
webchat-search/src/main/java/com/webchat/search/service/voctor/ArticleMilvusService.java

@@ -5,6 +5,7 @@ import com.google.gson.Gson;
 import com.google.gson.JsonObject;
 import com.webchat.common.constants.MilvusConstants;
 import com.webchat.common.enums.EmbeddingModelEnum;
+import com.webchat.common.util.HtmlSplitter;
 import com.webchat.common.util.JsonUtil;
 import com.webchat.domain.dto.search.SyncSearchEngineDTO;
 import com.webchat.domain.dto.search.SyncSearchEngineListDTO;
@@ -24,7 +25,7 @@ import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.CollectionUtils;
 import org.springframework.stereotype.Service;
 
-import java.lang.reflect.Array;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
@@ -189,7 +190,7 @@ public class ArticleMilvusService extends AbstractMilvusQueueService<SyncSearchE
         HybridSearchReq request = HybridSearchReq.builder()
                 .collectionName(this.collectionName())
                 .searchRequests(Lists.newArrayList(summaryReq, contentReq))
-                .ranker(new WeightedRanker(Arrays.asList(0.6f, 0.4f)))
+                .ranker(new WeightedRanker(Arrays.asList(0.5f, 0.5f)))
                 .topK(topK)
                 .outFields(Lists.newArrayList("pk", "summary", "content"))
                 .build();
@@ -225,32 +226,66 @@ public class ArticleMilvusService extends AbstractMilvusQueueService<SyncSearchE
      */
     @Override
     public boolean doWriteCollection(SyncSearchEngineDTO data) { // splits the article content into chunks, embeds them in batch, and inserts one row per chunk
-        JsonObject vector = new JsonObject();
-        String summary = data.getSummary();
+
         String content = data.getContent();
-        // Embed the article summary
-        float[] summaryVector = super.embed(summary);
-        float[] contentVector = super.embed(content);
-        /**
-         * Write the original (scalar) data
-         */
-        vector.addProperty("pk", data.getPk());
-        vector.addProperty("source_type", data.getSourceType());
-        vector.addProperty("summary", data.getSummary());
-        vector.addProperty("content", data.getContent());
+        // Slice the long rich-text HTML body into chunks
+        List<HtmlSplitter.SplitResult> contentObjs = HtmlSplitter.split(content, true);
+        if (CollectionUtils.isEmpty(contentObjs)) {
+            return false; // the splitter produced no chunks, so nothing is written
+        }
+        String pk = data.getPk();
+        String sourceType = data.getSourceType();
+        String summary = data.getSummary(); // NOTE(review): not referenced in the lines shown; the summaries mapping below calls data.getSummary() directly
+
+        List<String> contents = contentObjs.stream().map(HtmlSplitter.SplitResult::getContent).toList();
+        List<String> summaries = contentObjs.stream().map(c ->
+                                c.getTitle().concat("|").concat(data.getSummary())).toList(); // per-chunk text "title|summary"; NOTE(review): String.concat throws NullPointerException if the title or summary is null — confirm both are always set
+
+        // Batch embedding
+        List<float[]> contentVectors = super.embed(contents);
+        List<float[]> summaryVectors = super.embed(summaries);
+        List<JsonObject> vectors = new ArrayList<>();
+        for (int i = 0; i < contentVectors.size(); i++) { // NOTE(review): summaryVectors/summaries/contents are indexed by contentVectors' size — assumes embed returns one vector per input
+            vectors.add(
+                    buildMilvusData(
+                            pk,
+                            sourceType,
+                            summaries.get(i),
+                            contents.get(i),
+                            summaryVectors.get(i),
+                            contentVectors.get(i)));
+        }
+
         /**
-         * Vector field handling
+         * Batch-write the rows produced by slicing the original article content
          */
-        Gson gson = new Gson();
-        vector.add("summary_vector", gson.toJsonTree(summaryVector));
-        vector.add("content_vector", gson.toJsonTree(contentVector));
         InsertReq request = InsertReq.builder()
                 .collectionName(collectionName())
-                .data(Collections.singletonList(vector))
+                .data(vectors)
                 .build();
         log.info("文章向量数据入库>>>> pk:{}", data.getPk());
         InsertResp resp = client.insert(request);
         log.info("文章向量数据入库完成>>>> pk:{}, resp:{}", data.getPk(), JsonUtil.toJsonString(resp));
         return true; // resp is only logged; true is returned without inspecting it
     }
+
+    private JsonObject buildMilvusData(String pk, String sourceType,
+                                       String summary, String content,
+                                       float[] summaryVector, float[] contentVector) { // builds one JSON row holding the scalar fields and both vectors
+        JsonObject vector = new JsonObject();
+        /**
+         * Write the original (scalar) data
+         */
+        vector.addProperty("pk", pk);
+        vector.addProperty("source_type", sourceType);
+        vector.addProperty("summary", summary);
+        vector.addProperty("content", content);
+        /**
+         * Vector field handling
+         */
+        Gson gson = new Gson(); // a new Gson instance is created on every call
+        vector.add("summary_vector", gson.toJsonTree(summaryVector));
+        vector.add("content_vector", gson.toJsonTree(contentVector));
+        return vector;
+    }
 }

Некоторые файлы не были показаны из-за большого количества измененных файлов