In two earlier articles, BGE M3-Embedding 模型介绍 and Sparse稀疏检索介绍与实践, we introduced sparse retrieval. Today let's look at how to build an engineered system for sparse vector search.

As mentioned before, the latest Milvus release, V2.4, supports sparse retrieval, so let's first look at how Milvus implements it.

Milvus's sparse retrieval implementation

The retrieval engine underneath Milvus is Knowhere, and the sparse-index code lives mainly in src/index/sparse.

First, the SparseRow data structure represents a sparse vector; it supports float-typed values only:

    template <typename T>
    class SparseRow {
        static_assert(std::is_same_v<T, fp32>, "SparseRow supports float only");

     public:
        // construct an SparseRow with memory allocated to hold `count` elements.
        SparseRow(size_t count = 0)
            : data_(count ? new uint8_t[count * element_size()] : nullptr), count_(count), own_data_(true) {
        }

        SparseRow(size_t count, uint8_t* data, bool own_data) : data_(data), count_(count), own_data_(own_data) {
        }

        // copy constructor and copy assignment operator perform deep copy
        SparseRow(const SparseRow<T>& other) : SparseRow(other.count_) {
            std::memcpy(data_, other.data_, data_byte_size());
        }

        SparseRow(SparseRow<T>&& other) noexcept : SparseRow() {
            swap(*this, other);
        }

        SparseRow&
        operator=(const SparseRow<T>& other) {
            if (this != &other) {
                SparseRow<T> tmp(other);
                swap(*this, tmp);
            }
            return *this;
        }

        SparseRow&
        operator=(SparseRow<T>&& other) noexcept {
            swap(*this, other);
            return *this;
        }

        ~SparseRow() {
            if (own_data_ && data_ != nullptr) {
                delete[] data_;
                data_ = nullptr;
            }
        }

        size_t
        size() const {
            return count_;
        }

        size_t
        memory_usage() const {
            return data_byte_size() + sizeof(*this);
        }

        // return the number of bytes used by the underlying data array.
        size_t
        data_byte_size() const {
            return count_ * element_size();
        }

        void*
        data() {
            return data_;
        }

        const void*
        data() const {
            return data_;
        }

        // dim of a sparse vector is the max index + 1, or 0 for an empty vector.
        int64_t
        dim() const {
            if (count_ == 0) {
                return 0;
            }
            auto* elem = reinterpret_cast<const ElementProxy*>(data_) + count_ - 1;
            return elem->index + 1;
        }

        SparseIdVal<T>
        operator[](size_t i) const {
            auto* elem = reinterpret_cast<const ElementProxy*>(data_) + i;
            return {elem->index, elem->value};
        }

        void
        set_at(size_t i, table_t index, T value) {
            auto* elem = reinterpret_cast<ElementProxy*>(data_) + i;
            elem->index = index;
            elem->value = value;
        }

        float
        dot(const SparseRow<T>& other) const {
            float product_sum = 0.0f;
            size_t i = 0;
            size_t j = 0;
            // TODO: improve with _mm_cmpistrm or the AVX512 alternative.
            while (i < count_ && j < other.count_) {
                auto* left = reinterpret_cast<const ElementProxy*>(data_) + i;
                auto* right = reinterpret_cast<const ElementProxy*>(other.data_) + j;
                if (left->index < right->index) {
                    ++i;
                } else if (left->index > right->index) {
                    ++j;
                } else {
                    product_sum += left->value * right->value;
                    ++i;
                    ++j;
                }
            }
            return product_sum;
        }

        friend void
        swap(SparseRow<T>& left, SparseRow<T>& right) {
            using std::swap;
            swap(left.count_, right.count_);
            swap(left.data_, right.data_);
            swap(left.own_data_, right.own_data_);
        }

        static inline size_t
        element_size() {
            return sizeof(table_t) + sizeof(T);
        }

     private:
        // ElementProxy is used to access elements in the data_ array and should
        // never be actually constructed.
        struct __attribute__((packed)) ElementProxy {
            table_t index;
            T value;
            ElementProxy() = delete;
            ElementProxy(const ElementProxy&) = delete;
        };

        // data_ must be sorted by column id. use raw pointer for easy mmap and zero
        // copy.
        uint8_t* data_;
        size_t count_;
        bool own_data_;
    };

The index itself is implemented in the InvertedIndex class (sparse_inverted_index.h). Let's start with some of its private fields:

    std::vector<SparseRow<T>> raw_data_;
    mutable std::shared_mutex mu_;
    std::unordered_map<table_t, std::vector<SparseIdVal<T>>> inverted_lut_;
    bool use_wand_ = false;
    // If we want to drop small values during build, we must first train the
    // index with all the data to compute value_threshold_.
    bool drop_during_build_ = false;
    // when drop_during_build_ is true, any value smaller than value_threshold_
    // will not be added to inverted_lut_. value_threshold_ is set to the
    // drop_ratio_build-th percentile of all absolute values in the index.
    T value_threshold_ = 0.0f;
    std::unordered_map<table_t, T> max_in_dim_;
    size_t max_dim_ = 0;

  • raw_data_ holds the original sparse vectors.
  • inverted_lut_ can be understood as an inverted index.
  • use_wand_ controls whether the WAND algorithm is used at query time. WAND is a classic query-optimization algorithm that skips some documents in a skip-list-like fashion, reducing computation and speeding up queries.
  • max_in_dim_ records each term's maximum value and exists to serve WAND (see the sketch after this list).
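
To make the role of these fields concrete, here is a minimal Go sketch (the names Posting, rawData, invertedLUT and maxInDim are mine, a simplified analogue rather than the knowhere code) that scatters two documents' sparse vectors into per-term posting lists and tracks the per-term maximum weight that WAND later uses as an upper bound:

    package main

    import "fmt"

    // Posting mirrors SparseIdVal<T>: a (doc id, weight) pair inside one term's posting list.
    type Posting struct {
        DocID  int32
        Weight float64
    }

    func main() {
        // raw_data_: one sparse vector per document, indexed by doc id (term id -> weight).
        rawData := []map[int32]float64{
            {101: 0.5, 150: 0.3}, // doc 0
            {101: 0.2, 190: 0.9}, // doc 1
        }
        // inverted_lut_: term id -> list of (doc id, weight), i.e. the inverted index.
        invertedLUT := map[int32][]Posting{}
        // max_in_dim_: per-term maximum weight, the upper bound WAND uses to skip documents.
        maxInDim := map[int32]float64{}
        for docID, vec := range rawData {
            for term, w := range vec {
                invertedLUT[term] = append(invertedLUT[term], Posting{int32(docID), w})
                if w > maxInDim[term] {
                    maxInDim[term] = w
                }
            }
        }
        fmt.Println(invertedLUT[101], maxInDim[101]) // [{0 0.5} {1 0.2}] 0.5
    }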

Index construction

For building, the index mainly exposes an Add method for ingesting data:

    Status
    Add(const SparseRow<T>* data, size_t rows, int64_t dim) {
        std::unique_lock<std::shared_mutex> lock(mu_);
        auto current_rows = n_rows_internal();
        if (current_rows > 0 && drop_during_build_) {
            LOG_KNOWHERE_ERROR_ << "Not allowed to add data to a built index with drop_ratio_build > 0.";
            return Status::invalid_args;
        }
        if ((size_t)dim > max_dim_) {
            max_dim_ = dim;
        }

        raw_data_.insert(raw_data_.end(), data, data + rows);
        for (size_t i = 0; i < rows; ++i) {
            add_row_to_index(data[i], current_rows + i);
        }
        return Status::success;
    }

This updates max_dim_, appends the data to raw_data_, and then calls add_row_to_index, which inserts the new doc into inverted_lut_ and updates max_in_dim_, the per-term maximum used by WAND queries to skip computation:

    inline void
    add_row_to_index(const SparseRow<T>& row, table_t id) {
        for (size_t j = 0; j < row.size(); ++j) {
            auto [idx, val] = row[j];
            // Skip values close enough to zero(which contributes little to
            // the total IP score).
            if (drop_during_build_ && fabs(val) < value_threshold_) {
                continue;
            }
            if (inverted_lut_.find(idx) == inverted_lut_.end()) {
                inverted_lut_[idx];
                if (use_wand_) {
                    max_in_dim_[idx] = 0;
                }
            }
            inverted_lut_[idx].emplace_back(id, val);
            if (use_wand_) {
                max_in_dim_[idx] = std::max(max_in_dim_[idx], val);
            }
        }
    }

Index save and load

Saving writes a custom binary format:

    Status
    Save(MemoryIOWriter& writer) {
        /**
         * zero copy is not yet implemented, now serializing in a zero copy
         * compatible way while still copying during deserialization.
         *
         * Layout:
         *
         * 1. int32_t rows, sign indicates whether to use wand
         * 2. int32_t cols
         * 3. for each row:
         *     1. int32_t len
         *     2. for each non-zero value:
         *         1. table_t idx
         *         2. T val
         *     With zero copy deserialization, each SparseRow object should
         *     reference(not owning) the memory address of the first element.
         *
         * inverted_lut_ and max_in_dim_ not serialized, they will be
         * constructed dynamically during deserialization.
         *
         * Data are densely packed in serialized bytes and no padding is added.
         */
        std::shared_lock<std::shared_mutex> lock(mu_);
        writeBinaryPOD(writer, n_rows_internal() * (use_wand_ ? 1 : -1));
        writeBinaryPOD(writer, n_cols_internal());
        writeBinaryPOD(writer, value_threshold_);
        for (size_t i = 0; i < n_rows_internal(); ++i) {
            auto& row = raw_data_[i];
            writeBinaryPOD(writer, row.size());
            if (row.size() == 0) {
                continue;
            }
            writer.write(row.data(), row.size() * SparseRow<T>::element_size());
        }
        return Status::success;
    }

The index file layout:

  1. int32_t rows: total number of rows; the ± sign encodes whether WAND is used (see the small sketch after this list)
  2. int32_t cols: number of columns (dimensions)
  3. T value_threshold: written by Save and read back by Load (the layout comment above omits it)
  4. for each row:
     1. int32_t len: number of non-zero entries
     2. for each non-zero value:
        1. table_t idx: the term id
        2. T val: the term's weight
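
As a small illustration of the ± trick in the header (illustrative only; the actual code simply writes and reads the signed value with writeBinaryPOD/readBinaryPOD), the row count and the use-wand flag could be folded into and recovered from one signed integer like this:

    package main

    import "fmt"

    // encodeRows folds the use-wand flag into the sign of the row count,
    // matching the layout described above.
    func encodeRows(rows int32, useWand bool) int32 {
        if useWand {
            return rows
        }
        return -rows
    }

    // decodeRows recovers the row count and the use-wand flag.
    func decodeRows(v int32) (rows int32, useWand bool) {
        if v > 0 {
            return v, true
        }
        return -v, false
    }

    func main() {
        v := encodeRows(1000, false)
        rows, useWand := decodeRows(v)
        fmt.Println(v, rows, useWand) // -1000 1000 false
    }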

Note that the inverted_lut_ inverted index is not stored; it is rebuilt at load time, so Load is simply the reverse process:

    Status
    Load(MemoryIOReader& reader) {
        std::unique_lock<std::shared_mutex> lock(mu_);
        int64_t rows;
        readBinaryPOD(reader, rows);
        use_wand_ = rows > 0;
        rows = std::abs(rows);
        readBinaryPOD(reader, max_dim_);
        readBinaryPOD(reader, value_threshold_);
        raw_data_.reserve(rows);

        for (int64_t i = 0; i < rows; ++i) {
            size_t count;
            readBinaryPOD(reader, count);
            raw_data_.emplace_back(count);
            if (count == 0) {
                continue;
            }
            reader.read(raw_data_[i].data(), count * SparseRow<T>::element_size());
            add_row_to_index(raw_data_[i], i);
        }

        return Status::success;
    }

Search flow

Recall that compute_lexical_matching_score simply multiplies the weights of the terms shared by query and document and sums the products. So brute-force retrieval is roughly: take the union of the docs containing any query term, compute the lexical matching score for each, and keep the top-k.
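
As a minimal standalone illustration (my own sketch, not the Milvus code), the lexical matching score of two sparse vectors stored as term-to-weight maps is just a sparse dot product:

    package main

    import "fmt"

    // lexicalMatchingScore multiplies the weights of the terms shared by both
    // sparse vectors and sums the products, i.e. a sparse dot product.
    func lexicalMatchingScore(a, b map[int32]float64) float64 {
        score := 0.0
        for term, wa := range a {
            if wb, ok := b[term]; ok {
                score += wa * wb
            }
        }
        return score
    }

    func main() {
        q := map[int32]float64{101: 0.4, 150: 0.2}
        d := map[int32]float64{101: 0.5, 190: 0.7}
        // only term 101 is shared: 0.4 * 0.5 = 0.2
        fmt.Println(lexicalMatchingScore(q, d)) // 0.2
    }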

Now let's look at Milvus's implementation, starting with the brute-force search:

    // find the top-k candidates using brute force search, k as specified by the capacity of the heap.
    // any value in q_vec that is smaller than q_threshold and any value with dimension >= n_cols() will be ignored.
    // TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed.
    void
    search_brute_force(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
        auto scores = compute_all_distances(q_vec, q_threshold);
        for (size_t i = 0; i < n_rows_internal(); ++i) {
            if ((bitset.empty() || !bitset.test(i)) && scores[i] != 0) {
                heap.push(i, scores[i]);
            }
        }
    }

    std::vector<float>
    compute_all_distances(const SparseRow<T>& q_vec, T q_threshold) const {
        std::vector<float> scores(n_rows_internal(), 0.0f);
        for (size_t idx = 0; idx < q_vec.size(); ++idx) {
            auto [i, v] = q_vec[idx];
            if (v < q_threshold || i >= n_cols_internal()) {
                continue;
            }
            auto lut_it = inverted_lut_.find(i);
            if (lut_it == inverted_lut_.end()) {
                continue;
            }
            // TODO: improve with SIMD
            auto& lut = lut_it->second;
            for (size_t j = 0; j < lut.size(); j++) {
                auto [idx, val] = lut[j];
                scores[idx] += v * float(val);
            }
        }
        return scores;
    }

  • The core is compute_all_distances: it walks q_vec to get each term id, looks the term up in inverted_lut_ to get its doc list, computes the partial scores, and accumulates scores that share the same doc id.
  • Finally, a MaxMinHeap is used to take the top-k.

Brute-force search guarantees exact results, but it is relatively inefficient. Let's look at the WAND-optimized search:

    // any value in q_vec that is smaller than q_threshold will be ignored.
    void
    search_wand(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
        auto q_dim = q_vec.size();
        std::vector<std::shared_ptr<Cursor<std::vector<SparseIdVal<T>>>>> cursors(q_dim);
        auto valid_q_dim = 0;
        // build one cursor per query term, wrapping the term's posting list
        for (size_t i = 0; i < q_dim; ++i) {
            // idx is the term id
            auto [idx, val] = q_vec[i];
            if (std::abs(val) < q_threshold || idx >= n_cols_internal()) {
                continue;
            }
            auto lut_it = inverted_lut_.find(idx);
            if (lut_it == inverted_lut_.end()) {
                continue;
            }
            auto& lut = lut_it->second;
            // max_in_dim_ records the maximum score in this term's posting list
            cursors[valid_q_dim++] = std::make_shared<Cursor<std::vector<SparseIdVal<T>>>>(
                lut, n_rows_internal(), max_in_dim_.find(idx)->second * val, val, bitset);
        }
        if (valid_q_dim == 0) {
            return;
        }
        cursors.resize(valid_q_dim);
        auto sort_cursors = [&cursors] {
            std::sort(cursors.begin(), cursors.end(),
                      [](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); });
        };
        sort_cursors();
        // the heap is not full yet, or the new score beats the heap top
        auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().val; };
        while (true) {
            // running upper bound
            float upper_bound = 0;
            // pivot: index of the first posting list whose prefix upper bound qualifies
            size_t pivot;
            bool found_pivot = false;
            for (pivot = 0; pivot < cursors.size(); ++pivot) {
                // a posting list is exhausted
                if (cursors[pivot]->is_end()) {
                    break;
                }
                upper_bound += cursors[pivot]->max_score();
                if (score_above_threshold(upper_bound)) {
                    found_pivot = true;
                    break;
                }
            }
            if (!found_pivot) {
                break;
            }
            // pivot_id: the doc id at the pivot cursor whose upper bound qualifies
            table_t pivot_id = cursors[pivot]->cur_vec_id();
            // if the first posting list's current vec_id (doc_id) equals pivot_id,
            // we can start scoring from posting list 0
            if (pivot_id == cursors[0]->cur_vec_id()) {
                float score = 0;
                // walk the cursors positioned at pivot_id and accumulate the score
                for (auto& cursor : cursors) {
                    if (cursor->cur_vec_id() != pivot_id) {
                        break;
                    }
                    score += cursor->cur_distance() * cursor->q_value();
                    // advance this posting list
                    cursor->next();
                }
                // push into the heap
                heap.push(pivot_id, score);
                // re-sort the cursors so the smallest vec_id comes first
                sort_cursors();
            } else {
                // the first posting list's current vec_id differs from pivot_id, so pivot >= 1;
                // scan backwards from pivot (the qualifying posting list) for a list
                // whose cur_vec_id != pivot_id
                size_t next_list = pivot;
                for (; cursors[next_list]->cur_vec_id() == pivot_id; --next_list) {
                }
                // next_list's cur_vec_id is not necessarily pivot_id; seek it to pivot_id.
                // after the seek, cursors[next_list].cur_vec_id() >= pivot_id, so some vec ids are skipped
                cursors[next_list]->seek(pivot_id);
                // starting from next_list + 1
                for (size_t i = next_list + 1; i < cursors.size(); ++i) {
                    // stop once the current cur_vec_id >= the previous one
                    if (cursors[i]->cur_vec_id() >= cursors[i - 1]->cur_vec_id()) {
                        break;
                    }
                    // otherwise swap, so posting lists positioned at pivot_id move to the front
                    std::swap(cursors[i], cursors[i - 1]);
                }
            }
        }
    }

  • First, the posting lists are pulled out into cursors, and the cursors are sorted by vec_id so the one with the smallest vec_id comes first.
  • score_above_threshold is used while scanning the cursors for a qualifying cursor index pivot: "qualifying" means the heap is not full yet, or the accumulated upper bound beats the heap top, which lets WAND skip candidates whose score is too small (see the small sketch after this list).
  • The pivot cursor gives pivot_id, i.e. a doc id; then check whether the first cursor's cur_vec_id equals pivot_id:
    • If it does, walk the posting lists, compute pivot_id's score, push it into the min-heap, and re-sort the posting lists.
    • If it does not, move a posting list whose cur_vec_id == pivot_id toward the front, skipping (pruning) entries whose vec_id < pivot_id along the way.
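
To isolate the pruning condition, here is a tiny sketch with made-up numbers (not taken from the code): the cursors' per-list upper bounds are accumulated in order, and the first prefix that can still beat the current heap threshold determines the pivot.

    package main

    import "fmt"

    // findPivot mimics the pivot search: maxScores holds each cursor's upper bound
    // (cursors assumed sorted by current doc id) and threshold is the score of the
    // current heap top. It returns the first index whose prefix sum can beat threshold.
    func findPivot(maxScores []float64, threshold float64) (pivot int, found bool) {
        upperBound := 0.0
        for i, ms := range maxScores {
            upperBound += ms
            if upperBound > threshold {
                return i, true
            }
        }
        return 0, false
    }

    func main() {
        // Upper bounds 0.1, 0.2, 0.9 and a full heap whose smallest kept score is 0.4:
        // lists 0 and 1 together can contribute at most 0.3, so the pivot is list 2,
        // and every doc id smaller than list 2's current doc id can be skipped.
        pivot, found := findPivot([]float64{0.1, 0.2, 0.9}, 0.4)
        fmt.Println(pivot, found) // 2 true
    }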

A lightweight sparse vector search in Go

Following an approach similar to Milvus, here is a simple Go implementation:

    package main

    import (
        "container/heap"
        "encoding/binary"
        "fmt"
        "io"
        "math/rand"
        "os"
        "sort"
        "time"
    )

    type Cursor struct {
        docIDs     []int32
        weights    []float64
        maxScore   float64
        termWeight float64
        currentIdx int
    }

    func NewCursor(docIDs []int32, weights []float64, maxScore float64, weight float64) *Cursor {
        return &Cursor{
            docIDs:     docIDs,
            weights:    weights,
            maxScore:   maxScore,
            termWeight: weight,
            currentIdx: 0,
        }
    }

    func (c *Cursor) Next() {
        c.currentIdx++
    }

    func (c *Cursor) Seek(docId int32) {
        for {
            if c.IsEnd() {
                break
            }
            if c.CurrentDocID() < docId {
                c.Next()
            } else {
                break
            }
        }
    }

    func (c *Cursor) IsEnd() bool {
        return c.currentIdx >= len(c.docIDs)
    }

    func (c *Cursor) CurrentDocID() int32 {
        return c.docIDs[c.currentIdx]
    }

    func (c *Cursor) CurrentDocWeight() float64 {
        return c.weights[c.currentIdx]
    }

    // DocVectors type will map docID to its vector
    type DocVectors map[int32]map[int32]float64

    // InvertedIndex type will map termID to sorted list of docIDs
    type InvertedIndex map[int32][]int32

    // TermMaxScores will keep track of maximum scores for terms
    type TermMaxScores map[int32]float64

    // SparseIndex class struct
    type SparseIndex struct {
        docVectors    DocVectors
        invertedIndex InvertedIndex
        termMaxScores TermMaxScores
        dim           int32
    }

    // NewSparseIndex initializes a new SparseIndex with empty structures
    func NewSparseIndex() *SparseIndex {
        return &SparseIndex{
            docVectors:    make(DocVectors),
            invertedIndex: make(InvertedIndex),
            termMaxScores: make(TermMaxScores),
            dim:           0,
        }
    }

    // Add method for adding documents to the sparse index
    func (index *SparseIndex) Add(docID int32, vec map[int32]float64) {
        index.docVectors[docID] = vec

        for termID, score := range vec {
            index.invertedIndex[termID] = append(index.invertedIndex[termID], docID)

            // Track max score for each term
            if maxScore, ok := index.termMaxScores[termID]; !ok || score > maxScore {
                index.termMaxScores[termID] = score
            }

            if termID > index.dim {
                index.dim = termID
            }
        }
    }

    // Save index to file
    func (index *SparseIndex) Save(filename string) error {
        file, err := os.Create(filename)
        if err != nil {
            return err
        }
        defer file.Close()

        // Write the dimension
        binary.Write(file, binary.LittleEndian, index.dim)

        // Write each document vector
        for docID, vec := range index.docVectors {
            binary.Write(file, binary.LittleEndian, docID)
            vecSize := int32(len(vec))
            binary.Write(file, binary.LittleEndian, vecSize)
            for termID, score := range vec {
                binary.Write(file, binary.LittleEndian, termID)
                binary.Write(file, binary.LittleEndian, score)
            }
        }

        return nil
    }

    // Load index from file
    func (index *SparseIndex) Load(filename string) error {
        file, err := os.Open(filename)
        if err != nil {
            return err
        }
        defer file.Close()

        var dim int32
        binary.Read(file, binary.LittleEndian, &dim)
        index.dim = dim

        for {
            var docID int32
            err := binary.Read(file, binary.LittleEndian, &docID)
            if err == io.EOF {
                break // End of file
            } else if err != nil {
                return err // Some other error
            }

            var vecSize int32
            binary.Read(file, binary.LittleEndian, &vecSize)
            vec := make(map[int32]float64)
            for i := int32(0); i < vecSize; i++ {
                var termID int32
                var score float64
                binary.Read(file, binary.LittleEndian, &termID)
                binary.Read(file, binary.LittleEndian, &score)
                vec[termID] = score
            }

            index.Add(docID, vec) // Rebuild the index
        }

        return nil
    }

    func (index *SparseIndex) bruteSearch(queryVec map[int32]float64, K int) []int32 {
        scores := computeAllDistances(queryVec, index)

        // keep the top K
        docHeap := &DocScoreHeap{}
        for docID, score := range scores {
            if docHeap.Len() < K {
                heap.Push(docHeap, &DocScore{docID, score})
            } else if (*docHeap)[0].score < score {
                heap.Pop(docHeap)
                heap.Push(docHeap, &DocScore{docID, score})
            }
        }

        topDocs := make([]int32, 0, K)
        for docHeap.Len() > 0 {
            el := heap.Pop(docHeap).(*DocScore)
            topDocs = append(topDocs, el.docID)
        }
        sort.Slice(topDocs, func(i, j int) bool {
            return topDocs[i] < topDocs[j]
        })

        return topDocs
    }

    func computeAllDistances(queryVec map[int32]float64, index *SparseIndex) map[int32]float64 {
        scores := make(map[int32]float64)
        for term, qWeight := range queryVec {
            if postingList, exists := index.invertedIndex[term]; exists {
                for _, docID := range postingList {
                    docVec := index.docVectors[docID]
                    docWeight, exists := docVec[term]
                    if !exists {
                        continue
                    }
                    score := qWeight * docWeight
                    if _, ok := scores[docID]; !ok {
                        scores[docID] = score
                    } else {
                        scores[docID] += score
                    }
                }
            }
        }

        return scores
    }

    // WandSearch retrieves the top K documents nearest to the query vector using WAND
    func (index *SparseIndex) WandSearch(queryVec map[int32]float64, K int) []int32 {
        docHeap := &DocScoreHeap{}

        // one cursor (posting list) per query term that exists in the index
        postingLists := make([]*Cursor, len(queryVec))
        idx := 0
        for term, termWeight := range queryVec {
            if postingList, exists := index.invertedIndex[term]; exists {
                // docs containing the term, and the term's weight in each doc
                weights := make([]float64, len(postingList))
                for i, docID := range postingList {
                    weights[i] = index.docVectors[docID][term]
                }
                postingLists[idx] = NewCursor(postingList, weights, index.termMaxScores[term]*termWeight, termWeight)
                idx += 1
            }
        }
        // keep only the cursors that were actually created; query terms missing
        // from the index would otherwise leave nil entries behind
        postingLists = postingLists[:idx]

        sortPostings := func() {
            for i := range postingLists {
                // stop sorting once any list is exhausted (CurrentDocID would be out of range)
                if postingLists[i].IsEnd() {
                    return
                }
            }
            // sort the posting lists by their current doc id
            sort.Slice(postingLists, func(i, j int) bool {
                return postingLists[i].CurrentDocID() < postingLists[j].CurrentDocID()
            })
        }
        sortPostings()

        // the heap is not full yet, or the new score beats the heap top
        scoreAboveThreshold := func(value float64) bool {
            return docHeap.Len() < K || (*docHeap)[0].score < value
        }

        for {
            upperBound := 0.0
            foundPivot := false
            pivot := 0
            for idx := range postingLists {
                if postingLists[idx].IsEnd() {
                    break
                }
                upperBound += postingLists[idx].maxScore
                if scoreAboveThreshold(upperBound) {
                    foundPivot = true
                    pivot = idx
                    break
                }
            }
            if !foundPivot {
                break
            }

            // pivotId: the doc id at the pivot cursor whose upper bound qualifies
            pivotId := postingLists[pivot].CurrentDocID()
            if pivotId == postingLists[0].CurrentDocID() {
                // the first posting list's current doc id equals pivotId,
                // so we can start scoring from list 0
                score := 0.0
                // walk the cursors positioned at pivotId and accumulate the score
                for idx := range postingLists {
                    cursor := postingLists[idx]
                    if cursor.CurrentDocID() != pivotId {
                        break
                    }
                    score += cursor.CurrentDocWeight() * cursor.termWeight
                    // advance to the next doc id
                    postingLists[idx].Next()
                }
                // push into the heap
                if docHeap.Len() < K {
                    heap.Push(docHeap, &DocScore{pivotId, score})
                } else if (*docHeap)[0].score < score {
                    heap.Pop(docHeap)
                    heap.Push(docHeap, &DocScore{pivotId, score})
                }
                // re-sort the cursors so the smallest doc id comes first
                sortPostings()
            } else {
                // the first posting list's current doc id differs from pivotId, so pivot >= 1;
                // scan backwards from pivot for a list whose current doc id != pivotId
                nextList := pivot
                for ; postingLists[nextList].CurrentDocID() == pivotId; nextList-- {
                }
                // nextList's current doc id is not necessarily pivotId; seek it to pivotId.
                // after the seek its current doc id is >= pivotId, so some doc ids are skipped
                postingLists[nextList].Seek(pivotId)
                // starting from nextList + 1
                for i := nextList + 1; i < len(postingLists); i++ {
                    // stop once the current doc id >= the previous one
                    if postingLists[i].CurrentDocID() >= postingLists[i-1].CurrentDocID() {
                        break
                    }
                    // otherwise swap, so lists positioned at pivotId move to the front
                    temp := postingLists[i]
                    postingLists[i] = postingLists[i-1]
                    postingLists[i-1] = temp
                }
            }
        }

        topDocs := make([]int32, 0, K)
        for docHeap.Len() > 0 {
            el := heap.Pop(docHeap).(*DocScore)
            topDocs = append(topDocs, el.docID)
        }
        sort.Slice(topDocs, func(i, j int) bool {
            return topDocs[i] < topDocs[j]
        })

        return topDocs
    }

    // Helper structure to manage the priority queue for the top-K documents
    type DocScore struct {
        docID int32
        score float64
    }

    type DocScoreHeap []*DocScore

    func (h DocScoreHeap) Len() int           { return len(h) }
    func (h DocScoreHeap) Less(i, j int) bool { return h[i].score < h[j].score }
    func (h DocScoreHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

    func (h *DocScoreHeap) Push(x interface{}) {
        *h = append(*h, x.(*DocScore))
    }

    func (h *DocScoreHeap) Pop() interface{} {
        old := *h
        n := len(old)
        x := old[n-1]
        *h = old[0 : n-1]
        return x
    }

    func main() {
        index := NewSparseIndex()
        rand.Seed(time.Now().UnixNano())
        // Add document vectors as needed
        for i := 1; i <= 1000; i++ {
            index.Add(int32(i), map[int32]float64{
                101: rand.Float64(),
                150: rand.Float64(),
                190: rand.Float64(),
                500: rand.Float64(),
            })
        }
        //index.Save("index.bin")
        //index.Load("index.bin")
        topDocs := index.WandSearch(map[int32]float64{
            101: rand.Float64(),
            150: rand.Float64(),
            190: rand.Float64(),
            500: rand.Float64(),
        }, 10)
        fmt.Println("Top Docs:", topDocs)
    }

  • The code implements index construction, saving and loading; for search, it implements both brute-force and WAND search.
  • Note that documents must be added in increasing docID order. In a real system the engine can maintain a mapping from real ids to monotonically increasing doc ids (a minimal sketch of such a mapping follows this list).
  • The code is commented, so the details are not repeated here; note that it has not been thoroughly tested and may contain bugs.
  • The implementation keeps the entire inverted index in memory, which is fast but memory-hungry.
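
For the ordered-docID requirement mentioned above, a minimal sketch of such a mapping (a hypothetical helper, not part of the code above) could look like this:

    package main

    import "fmt"

    // IDMapper assigns monotonically increasing internal doc ids to external keys,
    // so documents can be added to the index in increasing doc id order.
    type IDMapper struct {
        next       int32
        toInternal map[string]int32
        toExternal []string
    }

    func NewIDMapper() *IDMapper {
        return &IDMapper{toInternal: make(map[string]int32)}
    }

    // Assign returns the internal id for an external key, allocating the next id
    // if the key has not been seen before.
    func (m *IDMapper) Assign(key string) int32 {
        if id, ok := m.toInternal[key]; ok {
            return id
        }
        id := m.next
        m.next++
        m.toInternal[key] = id
        m.toExternal = append(m.toExternal, key)
        return id
    }

    // External maps an internal id back to the original key, e.g. when returning
    // search results.
    func (m *IDMapper) External(id int32) string { return m.toExternal[id] }

    func main() {
        m := NewIDMapper()
        fmt.Println(m.Assign("doc-abc"), m.Assign("doc-xyz"), m.External(1)) // 0 1 doc-xyz
    }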

Summary

Sparse retrieval is, overall, very similar to traditional text retrieval, so classic engineering optimizations from text search carry over to it. This article analyzed Milvus's implementation and built a Go version of sparse retrieval.
