[Beta] Text-to-SQL with PGVector¶
This notebook demonstrates how to perform text-to-SQL with pgvector. This allows us to jointly do both semantic search and structured querying, all within SQL!
This hypothetically enables more expressive queries than semantic search + metadata filters.
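For example, a single generated SQL statement can apply a structured filter and rank by vector similarity at the same time. The query below is a purely illustrative sketch (it refers to the sec_text_chunk table defined later in this notebook, and [query_vector] stands in for the question embedding):

SELECT page_label, text
FROM sec_text_chunk
WHERE page_label BETWEEN 10 AND 20
ORDER BY embedding <-> '[query_vector]'
LIMIT 5;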
NOTE: this is a beta feature, and the interface may change. But in the meantime we hope you find it useful!
NOTE: any text-to-SQL application should be aware that executing arbitrary SQL queries can be a security risk. It is recommended to take precautions as needed, such as using restricted roles, read-only databases, sandboxing, etc.
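As one example of such a precaution, the query engine could connect as a role that is only allowed to read. The statements below are a minimal sketch in plain PostgreSQL, assuming an illustrative role name readonly_user and the default public schema; adapt them to your own setup:

-- illustrative read-only role; the name and password are placeholders
CREATE ROLE readonly_user LOGIN PASSWORD 'change-me';
GRANT CONNECT ON DATABASE postgres TO readonly_user;
GRANT USAGE ON SCHEMA public TO readonly_user;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO readonly_user;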
Setup Data¶
Load Documents¶
Load in the Lyft 2021 10-K annual filing.
In [ ]:
%pip install llama-index-embeddings-huggingface
%pip install llama-index-readers-file
%pip install llama-index-llms-openai
In [ ]:
from llama_index.readers.file import PDFReader
In [ ]:
reader = PDFReader()
Download Data
In [ ]:
!mkdir -p 'data/10k/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'
In [ ]:
docs = reader.load_data("./data/10k/lyft_2021.pdf")
In [ ]:
from llama_index.core.node_parser import SentenceSplitter
node_parser = SentenceSplitter()
nodes = node_parser.get_nodes_from_documents(docs)
In [ ]:
print(nodes[8].get_content(metadata_mode="all"))
Insert data into Postgres + PGVector¶
Make sure you have all the necessary dependencies installed!
In [ ]:
!pip install psycopg2-binary pgvector asyncpg "sqlalchemy[asyncio]" greenlet
In [ ]:
from pgvector.sqlalchemy import Vector
from sqlalchemy import insert, create_engine, String, text, Integer
from sqlalchemy.orm import declarative_base, mapped_column
Establish Connection¶
In [ ]:
engine = create_engine("postgresql+psycopg2://localhost/postgres")
with engine.connect() as conn:
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    conn.commit()
Define Table Schema¶
We define the table as a Python class. Note that we store the page_label, the embedding, and the text.
In [ ]:
Base = declarative_base()


class SECTextChunk(Base):
    __tablename__ = "sec_text_chunk"
    id = mapped_column(Integer, primary_key=True)
    page_label = mapped_column(Integer)
    file_name = mapped_column(String)
    text = mapped_column(String)
    embedding = mapped_column(Vector(384))
In [ ]:
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
Generate embedding for each Node with a sentence_transformers model¶
In [ ]:
# get embeddings for each row
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")

for node in nodes:
    text_embedding = embed_model.get_text_embedding(node.get_content())
    node.embedding = text_embedding
Insert into Database¶
In [ ]:
# insert into database
for node in nodes:
    row_dict = {
        "text": node.get_content(),
        "embedding": node.embedding,
        **node.metadata,
    }
    stmt = insert(SECTextChunk).values(**row_dict)
    with engine.connect() as connection:
        cursor = connection.execute(stmt)
        connection.commit()
Define PGVectorSQLQueryEngine¶
Now that we've loaded the data into the database, we're ready to set up the query engine.
Define Prompt¶
We create a modified version of our default text-to-SQL prompt to inject awareness of the pgvector syntax. We also prompt it with some few-shot examples of how to use the syntax (<->).
NOTE: this prompt is included by default in the PGVectorSQLQueryEngine; we show it here mainly for visibility!
In [ ]:
from llama_index.core import PromptTemplate
text_to_sql_tmpl = """\
Given an input question, first create a syntactically correct {dialect} \
query to run, then look at the results of the query and return the answer. \
You can order the results by a relevant column to return the most \
interesting examples in the database.
Pay attention to use only the column names that you can see in the schema \
description. Be careful to not query for columns that do not exist. \
Pay attention to which column is in which table. Also, qualify column names \
with the table name when needed.
IMPORTANT NOTE: you can use specialized pgvector syntax (`<->`) to do nearest \
neighbors/semantic search to a given vector from an embeddings column in the table. \
The embeddings value for a given row typically represents the semantic meaning of that row. \
The vector represents an embedding representation \
of the question, given below. Do NOT fill in the vector values directly, but rather specify a \
`[query_vector]` placeholder. For instance, some select statement examples below \
(the name of the embeddings column is `embedding`):
SELECT * FROM items ORDER BY embedding <-> '[query_vector]' LIMIT 5;
SELECT * FROM items WHERE id != 1 ORDER BY embedding <-> (SELECT embedding FROM items WHERE id = 1) LIMIT 5;
SELECT * FROM items WHERE embedding <-> '[query_vector]' < 5;
You are required to use the following format, \
each taking one line:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use tables listed below.
{schema}
Question: {query_str}
SQLQuery: \
"""
text_to_sql_prompt = PromptTemplate(text_to_sql_tmpl)
Setup LLM, Embedding Model, and Misc.¶
Besides the LLM and embedding model, note that we also add annotations on the table itself. This helps the LLM better understand the column schema (e.g. by telling it what the embedding column represents), so that it can more effectively perform either tabular queries or semantic search over the table.
In [ ]:
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import PGVectorSQLQueryEngine
from llama_index.core import Settings
sql_database = SQLDatabase(engine, include_tables=["sec_text_chunk"])
Settings.llm = OpenAI(model="gpt-4")
Settings.embed_model = embed_model
table_desc = """\
This table represents text chunks from an SEC filing. Each row contains the following columns:
id: id of row
page_label: page number
file_name: top-level file name
text: all text chunk is here
embedding: the embeddings representing the text chunk
For most queries you should perform semantic search against the `embedding` column values, since \
that encodes the meaning of the text.
"""
context_query_kwargs = {"sec_text_chunk": table_desc}
Define Query Engine¶
In [ ]:
query_engine = PGVectorSQLQueryEngine(
    sql_database=sql_database,
    text_to_sql_prompt=text_to_sql_prompt,
    context_query_kwargs=context_query_kwargs,
)
Run Some Queries¶
Now we're ready to run some queries.
In [ ]:
response = query_engine.query(
    "Can you tell me about the risk factors described in page 6?",
)
In [ ]:
print(str(response))
Page 6 discusses the impact of the COVID-19 pandemic on the business. It mentions that the pandemic has affected communities in the United States, Canada, and globally. The pandemic has led to a significant decrease in the demand for ridesharing services, which has negatively impacted the company's financial performance. The page also discusses the company's efforts to adapt to the changing environment by focusing on the delivery of essential goods and services. Additionally, it mentions the company's transportation network, which offers riders seamless, personalized, and on-demand access to a variety of mobility options.
In [ ]:
print(response.metadata["sql_query"])
In [ ]:
response = query_engine.query(
    "Tell me more about Lyft's real estate operating leases",
)
In [ ]:
print(str(response))
Lyft's lease arrangements include vehicle rental programs, office space, and data centers. Leases that do not meet any specific criteria are accounted for as operating leases. The lease term begins when Lyft is available to use the underlying asset and ends upon the termination of the lease. The lease term includes any periods covered by an option to extend if Lyft is reasonably certain to exercise that option. Leasehold improvements are amortized on a straight-line basis over the shorter of the term of the lease, or the useful life of the assets.
In [ ]:
print(response.metadata["sql_query"][:300])
SELECT * FROM sec_text_chunk WHERE text LIKE '%Lyft%' AND text LIKE '%real estate%' AND text LIKE '%operating leases%' ORDER BY embedding <-> '[-0.007079003844410181, -0.04383348673582077, 0.02910166047513485, 0.02049737051129341, 0.009460929781198502, -0.017539210617542267, 0.04225028306245804, 0.0
In [ ]:
# look at the returned result
print(response.metadata["result"])
[(157, 93, 'lyft_2021.pdf', "Leases that do not meet any of the above criteria are accounted for as operating leases.Lessor\nThe\n Company's lease arrangements include vehicle re ... (4356 characters truncated) ... realized. Leasehold improvements are amortized on a straight-line basis over the shorter of the term of the lease, or the useful life of the assets.", '[0.017818017,-0.024016099,0.0042511695,0.03114478,0.003591422,-0.0097886855,0.02455732,0.013048866,0.018157514,-0.009401044,0.031699456,0.01678178,0. ... (4472 characters truncated) ... 6,0.01127416,0.045080125,-0.017046565,-0.028544193,-0.016320521,0.01062995,-0.021007432,-0.006999497,-0.08426073,-0.014918887,0.059064835,0.03307945]')]
In [ ]:
# structured query
response = query_engine.query(
    "Tell me about the max page number in this table",
)
In [ ]:
print(str(response))
The maximum page number in this table is 238.
In [ ]:
print(response.metadata["sql_query"][:300])
SELECT MAX(page_label) FROM sec_text_chunk;