|
@@ -0,0 +1,256 @@
|
|
|
+package com.webchat.search.service.voctor;
|
|
|
+
|
|
|
+import com.google.common.collect.Lists;
|
|
|
+import com.google.gson.Gson;
|
|
|
+import com.google.gson.JsonObject;
|
|
|
+import com.webchat.common.constants.MilvusConstants;
|
|
|
+import com.webchat.common.enums.EmbeddingModelEnum;
|
|
|
+import com.webchat.common.util.JsonUtil;
|
|
|
+import com.webchat.domain.dto.search.SyncSearchEngineDTO;
|
|
|
+import com.webchat.domain.dto.search.SyncSearchEngineListDTO;
|
|
|
+import com.webchat.domain.vo.response.search.ArticleMilvusSearchResponse;
|
|
|
+import io.milvus.v2.common.DataType;
|
|
|
+import io.milvus.v2.common.IndexParam;
|
|
|
+import io.milvus.v2.service.collection.request.AddFieldReq;
|
|
|
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
|
|
|
+import io.milvus.v2.service.vector.request.AnnSearchReq;
|
|
|
+import io.milvus.v2.service.vector.request.HybridSearchReq;
|
|
|
+import io.milvus.v2.service.vector.request.InsertReq;
|
|
|
+import io.milvus.v2.service.vector.request.data.FloatVec;
|
|
|
+import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
|
|
|
+import io.milvus.v2.service.vector.response.InsertResp;
|
|
|
+import io.milvus.v2.service.vector.response.SearchResp;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.apache.commons.collections4.CollectionUtils;
|
|
|
+import org.springframework.stereotype.Service;
|
|
|
+
|
|
|
+import java.lang.reflect.Array;
|
|
|
+import java.util.Arrays;
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.concurrent.ArrayBlockingQueue;
|
|
|
+
|
|
|
+
|
|
|
+@Slf4j
|
|
|
+@Service
|
|
|
+public class ArticleMilvusService extends AbstractMilvusQueueService<SyncSearchEngineListDTO, ArticleMilvusSearchResponse, SyncSearchEngineDTO> {
|
|
|
+
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 本地队列初始长度
|
|
|
+ */
|
|
|
+ private static final int CAPACITY = 2000;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 本地队列,用于同步数据入向量库
|
|
|
+ */
|
|
|
+ public ArrayBlockingQueue<SyncSearchEngineDTO> queue = new ArrayBlockingQueue(CAPACITY);
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 纬度,跟embedding模型纬度保持一致
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ protected int dimension() {
|
|
|
+ return EmbeddingModelEnum.ALIBABA.getDimension();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 集合定义
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ protected String collectionName() {
|
|
|
+
|
|
|
+ return MilvusConstants.CollectionName.COLLECTION_ARTICLE.name();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * schema 定义
|
|
|
+ *
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ protected CreateCollectionReq.CollectionSchema schema() {
|
|
|
+
|
|
|
+ CreateCollectionReq.CollectionSchema schema = client.createSchema();
|
|
|
+ schema.addField(AddFieldReq.builder()
|
|
|
+ .fieldName("id")
|
|
|
+ .dataType(DataType.VarChar)
|
|
|
+ .isPrimaryKey(true)
|
|
|
+ .autoID(true)
|
|
|
+ .build());
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 业务字段
|
|
|
+ */
|
|
|
+ schema.addField(AddFieldReq.builder()
|
|
|
+ .fieldName("pk")
|
|
|
+ .dataType(DataType.VarChar)
|
|
|
+ .maxLength(64)
|
|
|
+ .build());
|
|
|
+ schema.addField(AddFieldReq.builder()
|
|
|
+ .fieldName("source_type")
|
|
|
+ .dataType(DataType.VarChar)
|
|
|
+ .maxLength(64)
|
|
|
+ .build());
|
|
|
+ schema.addField(AddFieldReq.builder()
|
|
|
+ .fieldName("summary")
|
|
|
+ .dataType(DataType.VarChar)
|
|
|
+ .maxLength(500)
|
|
|
+ .build());
|
|
|
+ schema.addField(AddFieldReq.builder()
|
|
|
+ .fieldName("content")
|
|
|
+ .dataType(DataType.VarChar)
|
|
|
+ .maxLength(65535)
|
|
|
+ .build());
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 向量字段
|
|
|
+ */
|
|
|
+ schema.addField(AddFieldReq.builder()
|
|
|
+ .fieldName("summary_vector")
|
|
|
+ .dataType(DataType.FloatVector)
|
|
|
+ .dimension(this.dimension())
|
|
|
+ .build());
|
|
|
+ schema.addField(AddFieldReq.builder()
|
|
|
+ .fieldName("content_vector")
|
|
|
+ .dataType(DataType.FloatVector)
|
|
|
+ .dimension(this.dimension())
|
|
|
+ .build());
|
|
|
+ return schema;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 索引定义
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ protected List<IndexParam> indexParams() {
|
|
|
+
|
|
|
+ IndexParam summaryVectorIndex = IndexParam.builder()
|
|
|
+ .fieldName("summary_vector")
|
|
|
+ // 余弦相似度
|
|
|
+ .metricType(IndexParam.MetricType.COSINE)
|
|
|
+ .build();
|
|
|
+ IndexParam contentVectorIndex = IndexParam.builder()
|
|
|
+ .fieldName("content_vector")
|
|
|
+ // 余弦相似度
|
|
|
+ .metricType(IndexParam.MetricType.COSINE)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ return List.of(summaryVectorIndex, contentVectorIndex);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 入队业务逻辑实现,实际入队由父级putTaskQueue完成。
|
|
|
+ * @param syncData
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ protected boolean addTaskQueue(SyncSearchEngineListDTO syncData) {
|
|
|
+ if (CollectionUtils.isEmpty(syncData.getDataList())) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ // 文章同步任务加入队列
|
|
|
+ syncData.getDataList().forEach(super::putTaskQueue);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 基于相似度的topK条数据搜索
|
|
|
+ *
|
|
|
+ * 混合多向量搜索
|
|
|
+ *
|
|
|
+ * @param query
|
|
|
+ * @param topK
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ public List<ArticleMilvusSearchResponse> search(String query, int topK, float score) {
|
|
|
+
|
|
|
+ // 对用户query做embedding:这里embedding模型需要同同原数据embed采用同模型
|
|
|
+ float[] queryVector = super.embed(query);
|
|
|
+ AnnSearchReq summaryReq = AnnSearchReq.builder()
|
|
|
+ // 采用相似度检索算法:余弦相似度
|
|
|
+ .metricType(IndexParam.MetricType.COSINE)
|
|
|
+ .vectorFieldName("summary_vector")
|
|
|
+ .topK(topK)
|
|
|
+ .vectors(Collections.singletonList(new FloatVec(queryVector)))
|
|
|
+ .build();
|
|
|
+ AnnSearchReq contentReq = AnnSearchReq.builder()
|
|
|
+ // 采用相似度检索算法:余弦相似度
|
|
|
+ .metricType(IndexParam.MetricType.COSINE)
|
|
|
+ .vectorFieldName("content_vector")
|
|
|
+ .topK(topK)
|
|
|
+ .vectors(Collections.singletonList(new FloatVec(queryVector)))
|
|
|
+ .build();
|
|
|
+ HybridSearchReq request = HybridSearchReq.builder()
|
|
|
+ .collectionName(this.collectionName())
|
|
|
+ .searchRequests(Lists.newArrayList(summaryReq, contentReq))
|
|
|
+ .ranker(new WeightedRanker(Arrays.asList(0.6f, 0.4f)))
|
|
|
+ .topK(topK)
|
|
|
+ .outFields(Lists.newArrayList("pk", "summary", "content"))
|
|
|
+ .build();
|
|
|
+ SearchResp searchResp = client.hybridSearch(request);
|
|
|
+ List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
|
|
|
+ List<SearchResp.SearchResult> queryResults;
|
|
|
+ if (CollectionUtils.isEmpty(searchResults) || CollectionUtils.isEmpty(queryResults = searchResults.get(0))) {
|
|
|
+ return Collections.emptyList();
|
|
|
+ }
|
|
|
+ return queryResults.stream()
|
|
|
+ .filter(r -> r.getScore() > score)
|
|
|
+ .map(r -> {
|
|
|
+ Map<String, Object> resultMap = r.getEntity();
|
|
|
+ return new ArticleMilvusSearchResponse(Long.valueOf(String.valueOf(resultMap.get("pk"))),
|
|
|
+ String.valueOf(resultMap.getOrDefault("summary", "")),
|
|
|
+ String.valueOf(resultMap.getOrDefault("content", "")));
|
|
|
+ }).toList();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 队列
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ protected ArrayBlockingQueue<SyncSearchEngineDTO> queue() {
|
|
|
+ return queue;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 真正的数据写入实现
|
|
|
+ * @param data
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ public boolean doWriteCollection(SyncSearchEngineDTO data) {
|
|
|
+ JsonObject vector = new JsonObject();
|
|
|
+ String summary = data.getSummary();
|
|
|
+ String content = data.getContent();
|
|
|
+ // 对文章摘要做embed
|
|
|
+ float[] summaryVector = super.embed(summary);
|
|
|
+ float[] contentVector = super.embed(content);
|
|
|
+ /**
|
|
|
+ * 原数据写入
|
|
|
+ */
|
|
|
+ vector.addProperty("pk", data.getPk());
|
|
|
+ vector.addProperty("source_type", data.getSourceType());
|
|
|
+ vector.addProperty("summary", data.getSummary());
|
|
|
+ vector.addProperty("content", data.getContent());
|
|
|
+ /**
|
|
|
+ * 向量字段处理
|
|
|
+ */
|
|
|
+ Gson gson = new Gson();
|
|
|
+ vector.add("summary_vector", gson.toJsonTree(summaryVector));
|
|
|
+ vector.add("content_vector", gson.toJsonTree(contentVector));
|
|
|
+ InsertReq request = InsertReq.builder()
|
|
|
+ .collectionName(collectionName())
|
|
|
+ .data(Collections.singletonList(vector))
|
|
|
+ .build();
|
|
|
+ log.info("文章向量数据入库>>>> pk:{}", data.getPk());
|
|
|
+ InsertResp resp = client.insert(request);
|
|
|
+ log.info("文章向量数据入库完成>>>> pk:{}, resp:{}", data.getPk(), JsonUtil.toJsonString(resp));
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+}
|