Text-to-SQL Guide (Query Engine + Retriever)¶
This is a basic guide to the text-to-SQL capabilities of LlamaIndex.
- First, we show how to perform text-to-SQL over a toy dataset: this covers both "retrieval" (running a SQL query against the database) and "synthesis".
- Next, we show how to build a TableIndex so that relevant table schemas can be retrieved dynamically at query time.
- We then show how to use query-time row/column retrievers to enrich the context for text-to-SQL.
- Finally, we show how to define a text-to-SQL retriever on its own.
NOTE: Any text-to-SQL application should be aware that executing arbitrary SQL queries can be a security risk. It is recommended to take precautions as needed, such as using restricted roles, read-only databases, sandboxing, etc.
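As an illustration of one such precaution, the sketch below (plain sqlite3, independent of the LlamaIndex setup in this guide; the `demo_city.db` filename is made up) opens a database through a read-only URI connection, so any write that a generated query attempts is rejected:

```python
import sqlite3

# Create a throwaway database file, then reopen it read-only via a URI.
# (Illustrative only; a production setup would more likely use a restricted
# database role or a sandboxed replica.)
rw = sqlite3.connect("demo_city.db")
rw.execute(
    "CREATE TABLE IF NOT EXISTS city_stats (city_name TEXT, population INTEGER)"
)
rw.commit()
rw.close()

ro = sqlite3.connect("file:demo_city.db?mode=ro", uri=True)
try:
    ro.execute("INSERT INTO city_stats VALUES ('X', 1)")
    write_allowed = True
except sqlite3.OperationalError:
    write_allowed = False  # the read-only connection rejects the write
ro.close()
print(write_allowed)  # False
```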
If you're opening this Notebook on Colab, you will probably need to install LlamaIndex 🦙.
%pip install llama-index-core llama-index-llms-openai llama-index-embeddings-openai
import os
import openai
os.environ["OPENAI_API_KEY"] = "sk-.."
# import logging
# import sys
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
from IPython.display import Markdown, display
Create Database Schema¶
We use sqlalchemy, a popular SQL database toolkit, to create an empty city_stats table.
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
Define SQL Database¶
We first define our SQLDatabase abstraction (a light wrapper around SQLAlchemy).
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
llm = OpenAI(temperature=0.1, model="gpt-4.1-mini")
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
We add some testing data to our SQL database.
from sqlalchemy import insert
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "New York", "population": 8258000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
    {"city_name": "Busan", "population": 3334000, "country": "South Korea"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)
# view current table
stmt = select(
    city_stats_table.c.city_name,
    city_stats_table.c.population,
    city_stats_table.c.country,
).select_from(city_stats_table)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
    print(results)
[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('New York', 8258000, 'United States'), ('Seoul', 9776000, 'South Korea'), ('Busan', 3334000, 'South Korea')]
Query Index¶
We first show how to execute a raw SQL query, which runs directly against the table.
from sqlalchemy import text
with engine.connect() as con:
    rows = con.execute(text("SELECT city_name from city_stats"))
    for row in rows:
        print(row)
('Busan',)
('Chicago',)
('New York',)
('Seoul',)
('Tokyo',)
('Toronto',)
Part 1: Text-to-SQL Query Engine¶
Once we have constructed our SQL database, we can use the NLSQLTableQueryEngine to construct natural language queries that are synthesized into SQL queries.
Note that we need to specify the tables we want to use with this query engine. If we don't, the query engine will pull all the schema context, which could overflow the context window of the LLM.
from llama_index.core.query_engine import NLSQLTableQueryEngine
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["city_stats"], llm=llm
)
query_str = "Which city has the highest population?"
response = query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))
Tokyo has the highest population among all cities, with a population of 13,960,000.
This query engine should be used in any case where you can specify the tables you want to query over beforehand, or where the total size of all the table schemas plus the rest of the prompt fits within your context window.
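As a rough sanity check before choosing this engine, you can estimate whether all of your table schemas fit the prompt budget. The sketch below uses a crude ~4-characters-per-token heuristic (an assumption, not a real tokenizer) over hand-written DDL strings, and the `token_budget` value is a made-up placeholder:

```python
# Crude context-budget check: approximate token counts of table DDL strings
# using ~4 characters per token (a rough heuristic, not an exact tokenizer).
schemas = {
    "city_stats": (
        "CREATE TABLE city_stats (city_name VARCHAR(16) PRIMARY KEY, "
        "population INTEGER, country VARCHAR(16) NOT NULL)"
    ),
}
approx_tokens = sum(len(ddl) // 4 for ddl in schemas.values())
token_budget = 8_000  # hypothetical budget left over for schema context
fits = approx_tokens < token_budget
print(approx_tokens, fits)
```

If the estimate overflows the budget, the query-time schema retrieval in Part 2 is the better fit.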
Part 2: Query-Time Retrieval of Table Schemas for Text-to-SQL¶
If we don't know ahead of time which tables we want to use, and the total size of the table schemas overflows your context window size, we should store the table schemas in an index so that during query time we can retrieve the right schema.
The way we can do this is by using the SQLTableNodeMapping object, which takes in a SQLDatabase and produces a Node object for each SQLTableSchema object passed into the ObjectIndex constructor.
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.openai import OpenAIEmbedding
# set Logging to DEBUG for more detailed outputs
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name="city_stats")
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
    embed_model=OpenAIEmbedding(model="text-embedding-3-small"),
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)
Now we can take our SQLTableRetrieverQueryEngine and query it for a response.
response = query_engine.query("Which city has the highest population?")
display(Markdown(f"<b>{response}</b>"))
Tokyo has the highest population among all cities, with a population of 13,960,000.
# you can also fetch the raw result from SQLAlchemy!
response.metadata["result"]
[('Tokyo', 13960000)]
You can also add additional context information for each table schema you define.
# manually set context text
city_stats_text = (
    "This table gives information regarding the population and country of a"
    " given city.\nThe user will query with codewords, where 'foo' corresponds"
    " to population and 'bar' corresponds to city."
)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name="city_stats", context_str=city_stats_text)
]
Part 3: Query-Time Rows and Columns Retrieval for Text-to-SQL¶
One challenge arises when asking a question like "How many cities are in the US?": the generated query may only search for cities where the country is labeled "US", missing entries labeled "United States". To address this, you can apply query-time row retrieval, query-time column retrieval, or both.
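The mismatch is easy to reproduce with plain SQL (a stand-alone sqlite3 sketch here, independent of the SQLAlchemy engine built above): filtering on 'US' finds nothing, because the stored value is 'United States'.

```python
import sqlite3

# Minimal reproduction of the value-mismatch problem: the data stores
# "United States", so a generated query filtering on "US" returns nothing.
con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE city_stats (city_name TEXT, population INTEGER, country TEXT)"
)
con.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [
        ("Chicago", 2679000, "United States"),
        ("New York", 8258000, "United States"),
    ],
)
n_us = con.execute(
    "SELECT COUNT(*) FROM city_stats WHERE country = 'US'"
).fetchone()[0]
n_full = con.execute(
    "SELECT COUNT(*) FROM city_stats WHERE country = 'United States'"
).fetchone()[0]
print(n_us, n_full)  # 0 2
con.close()
```

The retrievers below close this gap by surfacing the actual stored values ("United States") as context for SQL generation.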
Query-Time Rows Retrieval¶
In query-time rows retrieval, we embed the rows of each table, producing one index per table.
from llama_index.core.schema import TextNode

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()

city_nodes = [TextNode(text=str(t)) for t in results]

city_rows_index = VectorStoreIndex(
    city_nodes, embed_model=OpenAIEmbedding(model="text-embedding-3-small")
)
city_rows_retriever = city_rows_index.as_retriever(similarity_top_k=1)
city_rows_retriever.retrieve("US")
[NodeWithScore(node=TextNode(id_='8ae10176-afd8-40ee-a97b-b24f66235489', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, metadata_template='{key}: {value}', metadata_separator='\n', text="('Chicago', 2679000, 'United States')", mimetype='text/plain', start_char_idx=None, end_char_idx=None, metadata_seperator='\n', text_template='{metadata_str}\n\n{content}'), score=0.7843469586763699)]
Then, a rows retriever for each table can be provided to the SQLTableRetrieverQueryEngine.
rows_retrievers = {
    "city_stats": city_rows_retriever,
}

query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
    rows_retrievers=rows_retrievers,
)
During query time, the row retriever identifies the rows most semantically similar to the input query. The retrieved rows are then incorporated as context to improve the performance of text-to-SQL generation.
response = query_engine.query("How many cities are in the US?")
display(Markdown(f"<b>{response}</b>"))
There are 2 cities in the United States according to the data in the city_stats table.
Query-Time Columns Retrieval¶
While query-time rows retrieval enhances text-to-SQL generation, it embeds each row individually, even when many rows contain duplicated values (e.g. in categorical data). This can lead to inefficient token usage and unnecessary overhead. Additionally, in tables with many columns, the retriever may return only a subset of relevant values, potentially missing others that are critical for generating accurate queries.
To address this, query-time columns retrieval can be used. This method indexes each unique value within the selected columns, creating a separate index for each column of the table.
city_cols_retrievers = {}
for column_name in ["city_name", "country"]:
    stmt = select(city_stats_table.c[column_name]).distinct()
    with engine.connect() as connection:
        values = connection.execute(stmt).fetchall()
    nodes = [TextNode(text=t[0]) for t in values]

    column_index = VectorStoreIndex(
        nodes, embed_model=OpenAIEmbedding(model="text-embedding-3-small")
    )
    column_retriever = column_index.as_retriever(similarity_top_k=1)
    city_cols_retrievers[column_name] = column_retriever
Then, a columns retriever for each table can be provided to the SQLTableRetrieverQueryEngine.
cols_retrievers = {
    "city_stats": city_cols_retrievers,
}

query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
    rows_retrievers=rows_retrievers,
    cols_retrievers=cols_retrievers,
    llm=llm,
)
During query time, the column retriever identifies the column values most semantically similar to the input query. The retrieved values are then incorporated as context to improve the performance of text-to-SQL generation.
response = query_engine.query("How many cities are in the US?")
display(Markdown(f"<b>{response}</b>"))
There are 2 cities in the United States.
Part 4: Text-to-SQL Retriever¶
So far our text-to-SQL capability has been packaged in a query engine, consisting of both retrieval and synthesis.
You can use the SQL retriever on its own. We show some parameters you can try out, and also show how to plug it into our RetrieverQueryEngine to get roughly the same results.
from llama_index.core.retrievers import NLSQLRetriever
# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], llm=llm, return_raw=True
)

results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)
from llama_index.core.response.notebook_utils import display_source_node
for n in results:
    display_source_node(n)
Node ID: f640a54f-7413-4dc0-9135-cd63c7ca8f45
Similarity: None
Text: [('Tokyo', 13960000), ('Seoul', 9776000), ('New York', 8258000), ('Busan', 3334000), ('Toronto', ...
# default retrieval (return_raw=False)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], llm=llm, return_raw=False
)

results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)
# NOTE: all the content is in the metadata
for n in results:
    display_source_node(n, show_source_metadata=True)
Node ID: 05c61a90-598e-4c29-a6b4-b27f2579819e
Similarity: None
Text:
Metadata: {'city_name': 'Tokyo', 'population': 13960000}
Node ID: c7f5fc4c-9754-4946-92c6-54a0d2b40fd9
Similarity: None
Text:
Metadata: {'city_name': 'Seoul', 'population': 9776000}
Node ID: 3a00e201-f3b5-430e-af0e-aa4c34a71131
Similarity: None
Text:
Metadata: {'city_name': 'New York', 'population': 8258000}
Node ID: ee911f7f-8aae-4bad-a52d-c0bdfab63942
Similarity: None
Text:
Metadata: {'city_name': 'Busan', 'population': 3334000}
Node ID: dca6b482-52e4-41e0-992f-a58109e6f3f6
Similarity: None
Text:
Metadata: {'city_name': 'Toronto', 'population': 2930000}
Plugging it into our RetrieverQueryEngine¶
We compose the SQL retriever with our standard RetrieverQueryEngine to synthesize a response. The result is roughly similar to our packaged Text-to-SQL query engine.
from llama_index.core.query_engine import RetrieverQueryEngine
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, llm=llm)
response = query_engine.query(
"Return the top 5 cities (along with their populations) with the highest population."
)
print(str(response))
The top 5 cities with the highest populations are:
1. Tokyo - 13,960,000
2. Seoul - 9,776,000
3. New York - 8,258,000
4. Busan - 3,334,000
5. Toronto - 2,930,000