Building a (Very Simple) Vector Store from Scratch¶
In this tutorial, we show you how to build a simple in-memory vector store that can store documents along with metadata. It will also expose a query interface that supports a variety of queries:
- semantic search (with embedding similarity)
- metadata filtering
NOTE: Obviously this is not meant to replace any actual vector store (e.g. Pinecone, Weaviate, Chroma, Qdrant, Milvus, or any of our many other vector store integrations). It is more to teach some key retrieval concepts, such as top-k embedding search and metadata filtering.
We won't cover advanced query/retrieval concepts such as approximate nearest neighbors or sparse/hybrid search, nor any of the system design concepts that go into building an actual database.
Setup¶
We load in some documents and parse them into Node objects: chunks that are ready to be inserted into a vector store.
Load Documents¶
%pip install llama-index-readers-file pymupdf
%pip install llama-index-embeddings-openai
!mkdir data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
from pathlib import Path
from llama_index.readers.file import PyMuPDFReader
loader = PyMuPDFReader()
documents = loader.load(file_path="./data/llama2.pdf")
Parse into Nodes¶
from llama_index.core.node_parser import SentenceSplitter
node_parser = SentenceSplitter(chunk_size=256)
nodes = node_parser.get_nodes_from_documents(documents)
Generate Embeddings for Each Node¶
from llama_index.embeddings.openai import OpenAIEmbedding
embed_model = OpenAIEmbedding()
for node in nodes:
    node_embedding = embed_model.get_text_embedding(
        node.get_content(metadata_mode="all")
    )
    node.embedding = node_embedding
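As a side note, embedding one node per call works but can be slow; here is a quick sketch of the same step using batched calls (assuming your embedding model exposes get_text_embedding_batch, which OpenAIEmbedding does):
# Optional: embed all node texts in one batched call instead of one call per node.
texts = [node.get_content(metadata_mode="all") for node in nodes]
embeddings = embed_model.get_text_embedding_batch(texts, show_progress=True)
for node, embedding in zip(nodes, embeddings):
    node.embedding = embedding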
Build a Simple In-Memory Vector Store¶
Now we'll build our in-memory vector store. We'll store the Nodes in a simple Python dictionary. We'll start by implementing embedding search, and add metadata filtering afterwards.
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.core.vector_stores import (
VectorStoreQuery,
VectorStoreQueryResult,
)
from typing import List, Any, Optional, Dict
from llama_index.core.schema import TextNode, BaseNode
import os
class BaseVectorStore(BasePydanticVectorStore):
    """Simple custom Vector Store.

    Stores documents in a simple in-memory dict.
    """

    stores_text: bool = True

    def get(self, text_id: str) -> List[float]:
        """Get embedding."""
        pass

    def add(
        self,
        nodes: List[BaseNode],
    ) -> List[str]:
        """Add nodes to index."""
        pass

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
        """
        pass

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Get nodes for response."""
        pass

    def persist(self, persist_path, fs=None) -> None:
        """Persist the SimpleVectorStore to a directory.

        NOTE: we are not implementing this for now.
        """
        pass
At a high level, we subclass our base vector store abstraction (BasePydanticVectorStore). There is no inherent need to do this if you are just building a vector store from scratch; we subclass it so that it is easy to plug into our downstream abstractions later.
Let's look at some of the classes defined here:
- BaseNode is simply the parent class of our core Node modules. Each Node represents a text chunk plus associated metadata.
- We also use some lower-level constructs, for instance VectorStoreQuery and VectorStoreQueryResult. These are lightweight dataclass containers that just represent queries and results. We look at their dataclass fields below.
from dataclasses import fields
{f.name: f.type for f in fields(VectorStoreQuery)}
{'query_embedding': typing.Optional[typing.List[float]],
'similarity_top_k': int,
'doc_ids': typing.Optional[typing.List[str]],
'node_ids': typing.Optional[typing.List[str]],
'query_str': typing.Optional[str],
'output_fields': typing.Optional[typing.List[str]],
'embedding_field': typing.Optional[str],
'mode': <enum 'VectorStoreQueryMode'>,
'alpha': typing.Optional[float],
'filters': typing.Optional[llama_index.vector_stores.types.MetadataFilters],
'mmr_threshold': typing.Optional[float],
'sparse_top_k': typing.Optional[int]}
{f.name: f.type for f in fields(VectorStoreQueryResult)}
{'nodes': typing.Optional[typing.Sequence[llama_index.schema.BaseNode]],
'similarities': typing.Optional[typing.List[float]],
'ids': typing.Optional[typing.List[str]]}
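For example, a query is built by filling in just the fields it needs; a minimal illustration (the embedding below is a made-up placeholder, not a real OpenAI vector):
# Hypothetical placeholder embedding, just to show how the container is constructed.
example_query = VectorStoreQuery(
    query_embedding=[0.1, 0.2, 0.3],
    similarity_top_k=2,
)
print(example_query.similarity_top_k)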
from llama_index.core.bridge.pydantic import Field
class VectorStore2(BaseVectorStore):
    """VectorStore2 (add/get/delete implemented)."""

    stores_text: bool = True
    node_dict: Dict[str, BaseNode] = Field(default_factory=dict)

    def get(self, text_id: str) -> List[float]:
        """Get embedding."""
        return self.node_dict[text_id]

    def add(
        self,
        nodes: List[BaseNode],
    ) -> List[str]:
        """Add nodes to index."""
        for node in nodes:
            self.node_dict[node.node_id] = node
        return [node.node_id for node in nodes]

    def delete(self, node_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using node_id.

        Args:
            node_id: str
        """
        del self.node_dict[node_id]
Let's run some basic tests just to show that it works well.
test_node = TextNode(id_="id1", text="hello world")
test_node2 = TextNode(id_="id2", text="foo bar")
test_nodes = [test_node, test_node2]
vector_store = VectorStore2()
vector_store.add(test_nodes)
node = vector_store.get("id1")
print(str(node))
Node ID: id1 Text: hello world
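We can also exercise delete on the same toy store:
# Delete the second test node and confirm it is gone from the store.
vector_store.delete("id2")
print("id2" in vector_store.node_dict)  # expected: False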
3.a Defining query (semantic search)¶
We implement a basic version of top-k semantic search. This simply iterates over all document embeddings, computes the cosine similarity with the query embedding, and returns the top-k documents by cosine similarity.
Cosine similarity: for every document embedding $\vec{d}$ and query embedding $\vec{q}$, compute $\dfrac{\vec{d}\cdot\vec{q}}{|\vec{d}||\vec{q}|}$.
NOTE: The top-k value is contained in the VectorStoreQuery container.
NOTE: Similar to the above, we define another subclass just so we don't have to reimplement the functions above (not because this is actually good code practice).
from typing import Tuple
import numpy as np
def get_top_k_embeddings(
    query_embedding: List[float],
    doc_embeddings: List[List[float]],
    doc_ids: List[str],
    similarity_top_k: int = 5,
) -> Tuple[List[float], List]:
    """Get top nodes by similarity to the query."""
    # dimensions: D
    qembed_np = np.array(query_embedding)
    # dimensions: N x D
    dembed_np = np.array(doc_embeddings)
    # dimensions: N
    dproduct_arr = np.dot(dembed_np, qembed_np)
    # dimensions: N
    norm_arr = np.linalg.norm(qembed_np) * np.linalg.norm(
        dembed_np, axis=1, keepdims=False
    )
    # dimensions: N
    cos_sim_arr = dproduct_arr / norm_arr
    # now we have the N cosine similarities for each document
    # sort by top k cosine similarity, and return ids
    tups = [(cos_sim_arr[i], doc_ids[i]) for i in range(len(doc_ids))]
    sorted_tups = sorted(tups, key=lambda t: t[0], reverse=True)
    sorted_tups = sorted_tups[:similarity_top_k]
    result_similarities = [s for s, _ in sorted_tups]
    result_ids = [n for _, n in sorted_tups]
    return result_similarities, result_ids
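To see the math in action, here is a quick toy run of get_top_k_embeddings with made-up 2D vectors; "a" points in the same direction as the query, so it should rank first:
# Toy sanity check with hypothetical 2D embeddings.
toy_query = [1.0, 0.0]
toy_docs = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
sims, ids = get_top_k_embeddings(
    toy_query, toy_docs, ["a", "b", "c"], similarity_top_k=2
)
print(ids)   # expected: ['a', 'c']
print(sims)  # expected: [1.0, ~0.707]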
from typing import cast
class VectorStore3A(VectorStore2):
    """Implements semantic/dense search."""

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Get nodes for response."""
        query_embedding = cast(List[float], query.query_embedding)
        doc_embeddings = [n.embedding for n in self.node_dict.values()]
        doc_ids = [n.node_id for n in self.node_dict.values()]

        similarities, node_ids = get_top_k_embeddings(
            query_embedding,
            doc_embeddings,
            doc_ids,
            similarity_top_k=query.similarity_top_k,
        )
        result_nodes = [self.node_dict[node_id] for node_id in node_ids]

        return VectorStoreQueryResult(
            nodes=result_nodes, similarities=similarities, ids=node_ids
        )
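VectorStore3A can already be queried the same way we query the final store in section 4; a minimal sketch (reusing the embedded nodes and embed_model from above, with a made-up example question):
# Minimal sketch: dense search over the Llama 2 nodes with VectorStore3A.
vector_store_3a = VectorStore3A()
vector_store_3a.add(nodes)
sample_query = VectorStoreQuery(
    query_embedding=embed_model.get_query_embedding("How was Llama 2 pretrained?"),
    similarity_top_k=2,
)
print(vector_store_3a.query(sample_query).ids)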
3.b. Supporting Metadata Filtering¶
The next extension is to add metadata filter support. This means we first filter the candidate set of nodes by the metadata filters, and then perform the semantic query.
For simplicity we use metadata filters for exact matching, combined with an AND condition.
from llama_index.core.vector_stores import MetadataFilters
from llama_index.core.schema import BaseNode
from typing import cast
def filter_nodes(nodes: List[BaseNode], filters: MetadataFilters):
    filtered_nodes = []
    for node in nodes:
        matches = True
        for f in filters.filters:
            if f.key not in node.metadata:
                matches = False
                continue
            if f.value != node.metadata[f.key]:
                matches = False
                continue
        if matches:
            filtered_nodes.append(node)
    return filtered_nodes
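Here is a quick illustration of filter_nodes with hypothetical metadata, just to show the exact-match AND behavior:
# Hypothetical nodes/metadata; only n1 matches the filter.
toy_nodes = [
    TextNode(id_="n1", text="chunk one", metadata={"source": "24"}),
    TextNode(id_="n2", text="chunk two", metadata={"source": "25"}),
]
toy_filters = MetadataFilters.from_dict({"source": "24"})
print([n.node_id for n in filter_nodes(toy_nodes, toy_filters)])  # expected: ['n1']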
Before running the semantic search, we first narrow the nodes down with filter_nodes.
def dense_search(query: VectorStoreQuery, nodes: List[BaseNode]):
    """Dense search."""
    query_embedding = cast(List[float], query.query_embedding)
    doc_embeddings = [n.embedding for n in nodes]
    doc_ids = [n.node_id for n in nodes]
    return get_top_k_embeddings(
        query_embedding,
        doc_embeddings,
        doc_ids,
        similarity_top_k=query.similarity_top_k,
    )
class VectorStore3B(VectorStore2):
    """Implements Metadata Filtering."""

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Get nodes for response."""
        # 1. First filter by metadata
        nodes = self.node_dict.values()
        if query.filters is not None:
            nodes = filter_nodes(nodes, query.filters)
        if len(nodes) == 0:
            result_nodes = []
            similarities = []
            node_ids = []
        else:
            # 2. Then perform semantic search
            similarities, node_ids = dense_search(query, nodes)
            result_nodes = [self.node_dict[node_id] for node_id in node_ids]
        return VectorStoreQueryResult(
            nodes=result_nodes, similarities=similarities, ids=node_ids
        )
4. Load Data into our Vector Store¶
Let's load our text chunks into the vector store, and run it on different types of queries: dense search, with metadata filters, and more.
vector_store = VectorStore3B()
# load data into the vector stores
vector_store.add(nodes)
Define an example question and embed it.
query_str = "Can you tell me about the key concepts for safety finetuning"
query_embedding = embed_model.get_query_embedding(query_str)
Query the vector store with dense search¶
query_obj = VectorStoreQuery(
    query_embedding=query_embedding, similarity_top_k=2
)
query_result = vector_store.query(query_obj)
for similarity, node in zip(query_result.similarities, query_result.nodes):
    print(
        "\n----------------\n"
        f"[Node ID {node.node_id}] Similarity: {similarity}\n\n"
        f"{node.get_content(metadata_mode='all')}"
        "\n----------------\n\n"
    )
---------------- [Node ID 3f74fdf4-0e2e-473e-9b07-10c51eb62794] Similarity: 0.835677131511819 total_pages: 77 file_path: ./data/llama2.pdf source: 23 Specifically, we use the following techniques in safety fine-tuning: 1. Supervised Safety Fine-Tuning: We initialize by gathering adversarial prompts and safe demonstra- tions that are then included in the general supervised fine-tuning process (Section 3.1). This teaches the model to align with our safety guidelines even before RLHF, and thus lays the foundation for high-quality human preference data annotation. 2. Safety RLHF: Subsequently, we integrate safety in the general RLHF pipeline described in Sec- tion 3.2.2. This includes training a safety-specific reward model and gathering more challenging adversarial prompts for rejection sampling style fine-tuning and PPO optimization. 3. Safety Context Distillation: Finally, we refine our RLHF pipeline with context distillation (Askell et al., 2021b). ---------------- ---------------- [Node ID 5ad5efb3-8442-4e8a-b35a-cc3a10551dc9] Similarity: 0.827877930608312 total_pages: 77 file_path: ./data/llama2.pdf source: 23 Benchmarks give a summary view of model capabilities and behaviors that allow us to understand general patterns in the model, but they do not provide a fully comprehensive view of the impact the model may have on people or real-world outcomes; that would require study of end-to-end product deployments. Further testing and mitigation should be done to understand bias and other social issues for the specific context in which a system may be deployed. For this, it may be necessary to test beyond the groups available in the BOLD dataset (race, religion, and gender). As LLMs are integrated and deployed, we look forward to continuing research that will amplify their potential for positive impact on these important social issues. 4.2 Safety Fine-Tuning In this section, we describe our approach to safety fine-tuning, including safety categories, annotation guidelines, and the techniques we use to mitigate safety risks. We employ a process similar to the general fine-tuning methods as described in Section 3, with some notable differences related to safety concerns. ----------------
Query the vector store with dense search + metadata filters¶
# filters = MetadataFilters(
# filters=[
# ExactMatchFilter(key="page", value=3)
# ]
# )
filters = MetadataFilters.from_dict({"source": "24"})
query_obj = VectorStoreQuery(
    query_embedding=query_embedding, similarity_top_k=2, filters=filters
)
query_result = vector_store.query(query_obj)
for similarity, node in zip(query_result.similarities, query_result.nodes):
    print(
        "\n----------------\n"
        f"[Node ID {node.node_id}] Similarity: {similarity}\n\n"
        f"{node.get_content(metadata_mode='all')}"
        "\n----------------\n\n"
    )
---------------- [Node ID efe54bc0-4f9f-49ad-9dd5-900395a092fa] Similarity: 0.8190195580569283 total_pages: 77 file_path: ./data/llama2.pdf source: 24 4.2.2 Safety Supervised Fine-Tuning In accordance with the established guidelines from Section 4.2.1, we gather prompts and demonstrations of safe model responses from trained annotators, and use the data for supervised fine-tuning in the same manner as described in Section 3.1. An example can be found in Table 5. The annotators are instructed to initially come up with prompts that they think could potentially induce the model to exhibit unsafe behavior, i.e., perform red teaming, as defined by the guidelines. Subsequently, annotators are tasked with crafting a safe and helpful response that the model should produce. 4.2.3 Safety RLHF We observe early in the development of Llama 2-Chat that it is able to generalize from the safe demonstrations in supervised fine-tuning. The model quickly learns to write detailed safe responses, address safety concerns, explain why the topic might be sensitive, and provide additional helpful information. ---------------- ---------------- [Node ID 619c884b-cdbc-44b2-aec0-2692b44740ee] Similarity: 0.8010811332867503 total_pages: 77 file_path: ./data/llama2.pdf source: 24 In particular, when the model outputs safe responses, they are often more detailed than what the average annotator writes. Therefore, after gathering only a few thousand supervised demonstrations, we switched entirely to RLHF to teach the model how to write more nuanced responses. Comprehensive tuning with RLHF has the added benefit that it may make the model more robust to jailbreak attempts (Bai et al., 2022a). We conduct RLHF by first collecting human preference data for safety similar to Section 3.2.2: annotators write a prompt that they believe can elicit unsafe behavior, and then compare multiple model responses to the prompts, selecting the response that is safest according to a set of guidelines. We then use the human preference data to train a safety reward model (see Section 3.2.2), and also reuse the adversarial prompts to sample from the model during the RLHF stage. Better Long-Tail Safety Robustness without Hurting Helpfulness Safety is inherently a long-tail problem, where the challenge comes from a small number of very specific cases. ----------------
Build a RAG System with the Vector Store¶
Now that we've built our vector store, let's plug it into our downstream abstractions and use it to power a simple RAG pipeline!
from llama_index.core import VectorStoreIndex
index = VectorStoreIndex.from_vector_store(vector_store)
query_engine = index.as_query_engine()
query_str = "Can you tell me about the key concepts for safety finetuning"
response = query_engine.query(query_str)
print(str(response))
The key concepts for safety fine-tuning include supervised safety fine-tuning, safety RLHF (Reinforcement Learning from Human Feedback), and safety context distillation. Supervised safety fine-tuning involves gathering adversarial prompts and safe demonstrations to align the model with safety guidelines before RLHF. Safety RLHF integrates safety into the RLHF pipeline by training a safety-specific reward model and gathering more challenging adversarial prompts for fine-tuning and optimization. Finally, safety context distillation is used to refine the RLHF pipeline. These techniques aim to mitigate safety risks and ensure that the model aligns with safety guidelines.
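The same index can also be used as a plain retriever when you only want the top nodes back rather than a synthesized answer:
# Retrieve the raw top-k nodes (with similarity scores) for the same query.
retriever = index.as_retriever(similarity_top_k=2)
for n in retriever.retrieve(query_str):
    print(n.score, n.node.node_id)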
Conclusion¶
That's it! We built a simple in-memory vector store that supports very basic inserts, gets, and deletes, plus dense search and metadata filtering. It can then be plugged into the rest of the LlamaIndex abstractions.
It doesn't support sparse search yet, and it's obviously not meant to be used in any real application. But hopefully this exposes some of what's going on under the hood!