从自然语言到 SQL:构建和追踪多语言查询引擎
如果您正在寻找一个多语言查询引擎,该引擎可以将自然语言转换为 SQL 生成与查询执行相结合,并充分利用 MLflow 的功能,那么这篇博文就是您的指南。我们将探讨如何利用 MLflow 的 **代码模型 (Models from Code)** 功能,实现对 AI 工作流的无缝跟踪和版本控制。此外,我们将深入研究 **MLflow 的跟踪 (Tracing)** 功能,该功能通过跟踪 AI 工作流的每个中间步骤的输入、输出和元数据,为 AI 工作流的众多不同组件引入了可观测性。
简介
SQL 是管理和访问关系数据库中数据的基本技能。然而,构建复杂的 SQL 查询来回答棘手的数据问题可能既困难又耗时。这种复杂性可能会阻碍有效利用数据的能力。自然语言到 SQL (NL2SQL) 系统通过提供从自然语言到 SQL 命令的转换来帮助解决这个问题,使得非技术人员能够与数据进行交互:用户只需用他们熟悉的自然语言提问,这些系统就会帮助他们返回相应的信息。
然而,在创建 NL2SQL 系统时,仍然存在语义歧义、模式映射或错误处理和用户反馈等一系列问题。因此,在构建此类系统时,我们必须设置一些防护措施,而不是完全依赖 LLM。
在这篇博文中,我们将引导您完成构建多语言查询引擎的过程。该引擎支持多种语言的自然语言输入,根据翻译后的用户输入生成 SQL 查询,并执行该查询。让我们来看一个例子:使用包含公司客户、产品和订单信息的数据库,用户可以用任何语言提出问题,例如“Quantos clientes temos por país?”(葡萄牙语,意为“我们有多少客户按国家/地区划分?”)。AI 工作流将输入翻译成英语,输出“How many customers do we have per country?”。然后,它会验证输入的安全性,检查是否可以使用数据库模式回答问题,生成相应的 SQL 查询(例如,SELECT COUNT(CustomerID) AS NumberOfCustomers, Country FROM Customers GROUP BY Country;),并验证查询以确保不包含任何有害命令(例如,DROP)。最后,它针对数据库执行查询以检索结果。
我们将从演示如何利用 LangGraph 的功能来构建动态 AI 工作流开始。此工作流集成了 OpenAI 和外部数据源,例如 Vector Store 和 SQLite 数据库,以处理用户输入、执行安全检查、查询数据库以及生成有意义的响应。
在本文中,我们将利用 **MLflow 的代码模型 (Models from Code)** 功能,实现对 AI 工作流的无缝跟踪和版本控制。此外,我们将深入研究 **MLflow 的跟踪 (Tracing)** 功能,该功能旨在通过跟踪 AI 工作流的每个中间步骤的输入、输出和相关元数据,增强 AI 工作流众多不同组件的可观测性。这使得轻松识别错误和意外行为成为可能,从而提高工作流的透明度。
先决条件
要设置和运行此项目,请确保已安装以下 **Python 包**
faiss-cpulangchainlangchain-corelangchain-openailanggraphlangchain-communitypydantic >=2typing_extensionspython-dotenv
此外,还需要 **MLflow 跟踪服务器 (MLflow Tracking Server)** 来有效记录和管理实验、模型和跟踪。有关本地设置,请参阅官方 MLflow 文档,了解有关 **配置简单的 MLflow 跟踪服务器** 的说明。
最后,请确保您的 OpenAI API 密钥已保存在项目目录的 .env 文件中。这允许应用程序安全地访问构建 AI 工作流所需的 OpenAI 服务。 .env 文件应包含类似以下内容的行:
OPENAI_API_KEY=your_openai_api_key
使用 LangGraph 的多语言查询引擎
多语言查询引擎利用 **LangGraph** 库,这是一个 AI 编排工具,用于为由 LLM 驱动的应用程序创建有状态、多代理和循环图架构。
与其他 AI 编排器相比,LangGraph 提供了三个核心优势:循环、可控性和持久性。它允许定义带有循环的 AI 工作流,这对于实现重试机制至关重要,例如多语言查询引擎中的 SQL 查询生成重试(如果验证失败,查询会回退以重新生成)。这使得 LangGraph 成为构建我们的多语言查询引擎的理想工具。
LangGraph 的关键功能:
-
有状态架构:引擎维护着图执行状态的动态快照。此快照充当节点之间的共享资源,从而在每次节点执行时实现高效的决策和实时更新。
-
多代理设计:AI 工作流在整个工作流中包含与 OpenAI 和其他外部工具的多次交互。
-
循环图结构:图的循环性质引入了强大的重试机制。当需要时,该机制通过循环回先前阶段来动态处理故障,从而确保图的连续执行。(此机制的细节将在后面讨论。)
AI 工作流概述
多语言查询引擎的高级 AI 工作流由相互连接的节点和边组成,每个节点和边代表一个关键阶段。
-
翻译节点:将用户的输入翻译成英语。
-
预安全检查:确保用户输入不包含有毒或不当内容,并且不包含有害的 SQL 命令(例如,
DELETE、DROP)。 -
数据库模式提取:检索目标数据库的模式,以了解其结构和可用数据。
-
相关性验证:将用户输入与数据库模式进行比较,以确保与数据库的上下文一致。
-
SQL 查询生成:根据用户输入和当前数据库模式生成 SQL 查询。
-
后安全检查:确保生成的 SQL 查询不包含有害的 SQL 命令(例如,
DELETE、DROP)。 -
SQL 查询验证:在回滚安全的(rollback-safe)环境中执行 SQL 查询,以在运行前确保其有效性。
-
动态状态评估:根据当前状态确定下一步操作。如果 SQL 查询验证失败,则循环回到第 5 阶段以重新生成查询。
-
查询执行和结果检索:执行 SQL 查询并返回结果(如果是
SELECT语句)。
重试机制在第 8 阶段引入,其中系统动态评估当前图状态。具体来说,当 SQL 查询验证节点(第 7 阶段)检测到问题时,状态会触发循环回到 SQL 生成节点(第 5 阶段)以进行新的 SQL 生成尝试(最多 3 次尝试)。
组件
多语言查询引擎与多个外部组件交互,以将自然语言用户输入转换为 SQL 查询,并以安全可靠的方式执行它们。在本节中,我们将详细介绍关键的 AI 工作流组件:OpenAI、Vector Store、SQLite 数据库和 SQL 生成链。
OpenAI
OpenAI,更具体地说,是 gpt-4o-mini 语言模型,在工作流的多个阶段起着至关重要的作用。它提供了执行以下操作所需的智能:
-
翻译:将用户输入翻译成英语。如果文本已经是英语,它将简单地重复输入。
-
安全检查:分析用户输入,确保其不包含有毒或不当内容。
-
相关性检查:评估用户的问题是否与数据库模式相关。
-
SQL 生成:根据用户输入、SQL 生成文档和数据库模式生成有效且可执行的 SQL 查询。
有关 OpenAI 实现的详细信息将在后面的 节点描述 部分提供。
FAISS Vector Store
为了构建一个有效的自然语言到 SQL 引擎,能够生成准确且可执行的 SQL 查询,我们利用 Langchain 的 FAISS Vector Store 功能。此设置允许系统从之前存储在 Vector 数据库中的 **W3Schools SQL 文档** 中搜索和提取 SQL 查询生成指南,从而提高 SQL 查询生成的成功率。
出于演示目的,我们使用的是 FAISS,这是一个内存中的向量存储,向量直接存储在 RAM 中。这提供了快速访问,但意味着数据在运行之间不会持久化。对于一个更具可扩展性的解决方案,允许跨多个项目存储和共享嵌入,我们建议使用 AWS OpenSearch、Vertex AI Vector Search、Azure Vector Search 或 Mosaic AI Vector Search 等替代方案。这些基于云的解决方案提供持久存储、自动缩放以及与其他云服务的无缝集成,非常适合大规模应用程序。
步骤 1:加载 SQL 文档
创建带有 SQL 查询生成指南的 FAISS Vector Store 的第一步是使用 LangChain 的 `RecursiveUrlLoader` 从 **W3Schools SQL 页面** 加载 SQL 文档。此工具检索文档,允许我们将其用作引擎的知识库。
步骤 2:将文本分割成可管理的块
加载的 SQL 文档是冗长的文本,很难被 LLM 有效地摄取。为解决此问题,下一步是使用 Langchain 的 `RecursiveCharacterTextSplitter` 将文本分割成更小、更易于管理的块。通过将文本分割成 500 个字符的块,并带有 50 个字符的重叠,我们确保语言模型具有足够的上下文,同时最大限度地降低丢失跨块重要信息的风险。`split_text` 方法应用此分割过程,将生成的片段存储在名为“documents”的列表中。
步骤 3:生成嵌入模型
第三步是创建一个模型,该模型将这些块转换为嵌入(每个文本块的矢量化数值表示)。嵌入使系统能够比较块与用户输入之间的相似性,从而便于检索与 SQL 查询生成最相关的匹配项。
步骤 4:创建嵌入并将它们存储在 FAISS Vector Store 中
最后,我们使用 FAISS 创建并存储嵌入。`FAISS.from_texts` 方法接受所有块,计算它们的嵌入,并将它们存储在高速可搜索的向量数据库中。这个可搜索的数据库允许引擎有效地检索相关的 SQL 指南,显著提高了可执行 SQL 查询生成的成功率。
import logging
import os
from bs4 import BeautifulSoup as Soup
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
def setup_vector_store(logger: logging.Logger):
"""Setup or load the vector store."""
if not os.path.exists("data"):
os.makedirs("data")
vector_store_dir = "data/vector_store"
if os.path.exists(vector_store_dir):
# Load the vector store from disk
logger.info("Loading vector store from disk...")
vector_store = FAISS.load_local(
vector_store_dir,
OpenAIEmbeddings(),
allow_dangerous_deserialization=True,
)
else:
logger.info("Creating new vector store...")
# Load SQL documentation
url = "https://w3schools.org.cn/sql/"
loader = RecursiveUrlLoader(
url=url, max_depth=2, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()
# Split documents into chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
)
documents = []
for doc in docs:
splits = text_splitter.split_text(doc.page_content)
for i, split in enumerate(splits):
documents.append(
{
"content": split,
"metadata": {"source": doc.metadata["source"], "chunk": i},
}
)
# Compute embeddings and create vector store
embedding_model = OpenAIEmbeddings()
vector_store = FAISS.from_texts(
[doc["content"] for doc in documents],
embedding_model,
metadatas=[doc["metadata"] for doc in documents],
)
# Save the vector store to disk
vector_store.save_local(vector_store_dir)
logger.info("Vector store created and saved to disk.")
return vector_store
SQLite 数据库
SQLite 数据库是多语言查询引擎的关键组成部分,充当结构化数据存储库。SQLite 提供了一个轻量级、快速且自包含的关系数据库引擎,无需服务器设置或安装。其紧凑的尺寸(不到 500KB)和零配置的性质使其极其易于使用,而其平台无关的数据库格式确保了跨不同系统的无缝可移植性。作为本地磁盘数据库,SQLite 是避免设置 MySQL 或 PostgreSQL 复杂性的理想选择,同时仍提供可靠、功能齐全且性能出色的 SQL 引擎。
SQLite 数据库通过支持以下功能,支持高效的 SQL 查询生成、验证和执行:
-
模式提取:为用户输入上下文验证(第 4 阶段)和可执行 SQL 查询生成(第 5 阶段)提供模式信息。
-
查询执行:在验证阶段(第 7 阶段)和查询执行阶段(第 9 阶段)中,在回滚安全的(rollback-safe)环境中执行 SQL 查询,为
SELECT语句获取结果,并为其他查询类型提交更改。
SQLite 数据库初始化
当 AI 工作流初始化时,使用 `setup_database` 函数来初始化数据库。此过程包括:
-
设置 SQLite 数据库连接:建立到 SQLite 数据库的连接,从而实现数据交互。
-
表创建:定义并创建 AI 工作流所需的数据库表。
-
数据填充:用示例数据填充表,以支持查询执行和验证阶段。
import logging
import os
import sqlite3
def create_connection(db_file="data/database.db"):
"""Create a database connection to the SQLite database."""
conn = sqlite3.connect(db_file)
return conn
def create_tables(conn):
"""Create tables in the database."""
cursor = conn.cursor()
# Create Customers table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Customers (
CustomerID INTEGER PRIMARY KEY,
CustomerName TEXT,
ContactName TEXT,
Address TEXT,
City TEXT,
PostalCode TEXT,
Country TEXT
)
"""
)
# Create Orders table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Orders (
OrderID INTEGER PRIMARY KEY,
CustomerID INTEGER,
OrderDate TEXT,
FOREIGN KEY (CustomerID) REFERENCES Customers (CustomerID)
)
"""
)
# Create OrderDetails table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS OrderDetails (
OrderDetailID INTEGER PRIMARY KEY,
OrderID INTEGER,
ProductID INTEGER,
Quantity INTEGER,
FOREIGN KEY (OrderID) REFERENCES Orders (OrderID),
FOREIGN KEY (ProductID) REFERENCES Products (ProductID)
)
"""
)
# Create Products table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Products (
ProductID INTEGER PRIMARY KEY,
ProductName TEXT,
Price REAL
)
"""
)
conn.commit()
def populate_tables(conn):
"""Populate tables with sample data if they are empty."""
cursor = conn.cursor()
# Populate Customers table if empty
cursor.execute("SELECT COUNT(*) FROM Customers")
if cursor.fetchone()[0] == 0:
customers = []
for i in range(1, 51):
customers.append(
(
i,
f"Customer {i}",
f"Contact {i}",
f"Address {i}",
f"City {i % 10}",
f"{10000 + i}",
f"Country {i % 5}",
)
)
cursor.executemany(
"""
INSERT INTO Customers (CustomerID, CustomerName, ContactName, Address, City, PostalCode, Country)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
customers,
)
# Populate Products table if empty
cursor.execute("SELECT COUNT(*) FROM Products")
if cursor.fetchone()[0] == 0:
products = []
for i in range(1, 51):
products.append((i, f"Product {i}", round(10 + i * 0.5, 2)))
cursor.executemany(
"""
INSERT INTO Products (ProductID, ProductName, Price)
VALUES (?, ?, ?)
""",
products,
)
# Populate Orders table if empty
cursor.execute("SELECT COUNT(*) FROM Orders")
if cursor.fetchone()[0] == 0:
orders = []
from datetime import datetime, timedelta
base_date = datetime(2023, 1, 1)
for i in range(1, 51):
order_date = base_date + timedelta(days=i)
orders.append(
(
i,
i % 50 + 1, # CustomerID between 1 and 50
order_date.strftime("%Y-%m-%d"),
)
)
cursor.executemany(
"""
INSERT INTO Orders (OrderID, CustomerID, OrderDate)
VALUES (?, ?, ?)
""",
orders,
)
# Populate OrderDetails table if empty
cursor.execute("SELECT COUNT(*) FROM OrderDetails")
if cursor.fetchone()[0] == 0:
order_details = []
for i in range(1, 51):
order_details.append(
(
i,
i % 50 + 1, # OrderID between 1 and 50
i % 50 + 1, # ProductID between 1 and 50
(i % 5 + 1) * 2, # Quantity between 2 and 10
)
)
cursor.executemany(
"""
INSERT INTO OrderDetails (OrderDetailID, OrderID, ProductID, Quantity)
VALUES (?, ?, ?, ?)
""",
order_details,
)
conn.commit()
def setup_database(logger: logging.Logger):
"""Setup the database and return the connection."""
db_file = "data/database.db"
if not os.path.exists("data"):
os.makedirs("data")
db_exists = os.path.exists(db_file)
conn = create_connection(db_file)
if not db_exists:
logger.info("Setting up the database...")
create_tables(conn)
populate_tables(conn)
else:
logger.info("Database already exists. Skipping setup.")
return conn
SQL 生成链
SQL 生成链(`sql_gen_chain`)是我们工作流中自动化 SQL 查询生成的支柱。该链利用 LangChain 的模块化功能和 OpenAI 的高级自然语言处理能力,将用户问题转化为精确且可执行的 SQL 查询。
核心功能:
-
提示驱动生成:以精心设计的提示开始,该提示集成了数据库模式和文档片段,确保查询在上下文上是准确的。
-
结构化响应:以预定义的格式提供输出,包括:
-
查询目的的 **描述**。
-
相应的 **SQL 代码**,可供执行。
-
-
适应性强且可靠:使用 `gpt-4o-mini` 进行强大的、一致的查询生成,最大限度地减少手动工作和错误。
该链是我们工作流中的关键组件,可实现 SQL 查询生成与下游流程的无缝集成,确保准确性,并显著提高效率。
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
class SQLQuery(BaseModel):
"""Schema for SQL query solutions to questions."""
description: str = Field(description="Description of the SQL query")
sql_code: str = Field(description="The SQL code block")
def get_sql_gen_chain():
"""Set up the SQL generation chain."""
sql_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL assistant with expertise in SQL query generation. \n
Answer the user's question based on the provided documentation snippets and the database schema provided below. Ensure any SQL query you provide is valid and executable. \n
Structure your answer with a description of the query, followed by the SQL code block. Here are the documentation snippets:\n{retrieved_docs}\n\nDatabase Schema:\n{database_schema}""",
),
("placeholder", "{messages}"),
]
)
# Initialize the OpenAI LLM
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Create the code generation chain
sql_gen_chain = sql_gen_prompt | llm.with_structured_output(SQLQuery)
return sql_gen_chain
工作流设置和初始化
在深入研究工作流节点之前,至关重要的是设置必要的组件并定义工作流的结构。本节解释了核心库、日志记录和自定义 `GraphState` 类的初始化,以及主工作流编译函数。
定义 `GraphState`
`GraphState` 类是一个自定义 `TypedDict`,可在工作流进行时维护状态信息。它充当节点之间的共享数据结构,确保连续性和一致性。关键字段包括:
- `error`:跟踪是否发生了错误。
- `messages`:存储用户和系统消息的列表。
- `generation`:保存生成的 SQL 查询。
- `iterations`:跟踪发生错误时的重试次数。
- `results`:存储 SQL 执行结果(如果有)。
- `no_records_found`:如果查询未返回任何记录,则标记。
- `translated_input`:包含用户翻译后的输入。
- `database_schema`:维护数据库模式以进行上下文验证。
import logging
import re
from typing import List, Optional
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from sql_generation import get_sql_gen_chain
from typing_extensions import TypedDict
# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)
class GraphState(TypedDict):
error: str # Tracks if an error has occurred
messages: List # List of messages (user input and assistant messages)
generation: Optional[str] # Holds the generated SQL query
iterations: int # Keeps track of how many times the workflow has retried
results: Optional[List] # Holds the results of SQL execution
no_records_found: bool # Flag for whether any records were found in the SQL result
translated_input: str # Holds the translated user input
database_schema: str # Holds the extracted database schema for context checking
工作流编译函数
主函数 `get_workflow` 负责定义和编译工作流。关键组件包括:
- `conn` 和 `cursor`:用于数据库连接和查询执行。
- `vector_store`:用于上下文检索的向量数据库。
- `max_iterations`:设置重试次数限制,以防止无限循环。
- `sql_gen_chain`:从 `sql_generation` 中检索 SQL 生成链,用于根据上下文输入生成 SQL 查询。
- `ChatOpenAI`:初始化 OpenAI `gpt-4o-mini` 模型,用于安全检查和查询翻译等任务。
def get_workflow(conn, cursor, vector_store):
"""Define and compile the LangGraph workflow."""
# Max iterations: defines how many times the workflow should retry in case of errors
max_iterations = 3
# SQL generation chain: this is a chain that will generate SQL based on retrieved docs
sql_gen_chain = get_sql_gen_chain()
# Initialize OpenAI LLM for translation and safety checks
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Define the individual nodes of the workflow
此函数作为使用 `StateGraph` 创建完整工作流的入口点。工作流中的各个节点将在后续部分中定义和连接。
节点描述
1. 翻译输入
`translate_input` 节点将用户查询翻译成英语,以标准化处理并确保与下游节点兼容。将用户输入作为 AI 工作流的第一步进行翻译,可以确保任务隔离并提高可观测性。任务隔离通过将翻译与用户输入安全验证和 SQL 生成等其他下游任务分开来简化工作流。提高的可观测性在 MLflow 中提供了清晰的跟踪,从而更易于调试和监控过程。
- 示例
- 输入:“Quantos pedidos foram realizados em Novembro?”
- 翻译:“How many orders were made in November?”
- 输入:“Combien de ventes avons-nous enregistrées en France ?”
- 翻译:“How many sales did we record in France?”
- 代码
def translate_input(state: GraphState) -> GraphState:
"""
Translates user input to English using an LLM. If the input is already in English,
it is returned as is. This ensures consistent input for downstream processing.
Args:
state (GraphState): The current graph state containing user messages.
Returns:
GraphState: The updated state with the translated input.
"""
_logger.info("Starting translation of user input to English.")
messages = state["messages"]
user_input = messages[-1][1] # Get the latest user input
# Translation prompt for the model
translation_prompt = f"""
Translate the following text to English. If the text is already in English, repeat it exactly without any additional explanation.
Text:
{user_input}
"""
# Call the OpenAI LLM to translate the text
translated_response = llm.invoke(translation_prompt)
translated_text = translated_response.content.strip() # Access the 'content' attribute and strip any extra spaces
# Update state with the translated input
state["translated_input"] = translated_text
_logger.info("Translation completed successfully. Translated input: %s", translated_text)
return state
2. 预安全检查
`pre_safety_check` 节点确保在用户输入中及早检测到不允许的 SQL 操作和不当内容。虽然在工作流的后期(在生成 SQL 查询之后)会再次进行有害 SQL 命令(例如,CREATE、DELETE、DROP、INSERT、UPDATE)的检查,但此预安全检查对于在输入阶段识别潜在问题至关重要。通过这样做,它可以防止不必要的计算,并立即向用户提供反馈。
虽然使用禁止列表来查找有害 SQL 操作是防止破坏性查询的一种快速方法,但在处理 T-SQL 等复杂 SQL 后端时,维护一个全面的禁止列表可能会变得难以管理。另一种方法是采用允许列表,将查询限制为仅允许安全的操作(例如,SELECT、JOIN)。此方法通过缩小允许的操作范围而不是尝试阻止每个风险命令来确保更 robust 的解决方案。
要实现企业级的解决方案,该项目可以利用 Unity Catalog 等框架,这些框架提供了一种集中且 robust 的方法来管理与安全相关的函数,例如 AI 工作流的 `pre_safety_check`。通过在此类框架中注册和管理可重用函数,可以跨所有 AI 工作流强制执行一致且可靠的行为,从而增强安全性和可扩展性。
此外,该节点利用 LLM 来分析输入中的攻击性或不当内容。如果检测到不安全查询或不当内容,状态将被标记为错误标志,并提供透明的反馈,从而及早保护工作流免受恶意或破坏性元素的侵害。
- 示例
-
禁止的操作
-
输入:“DROP TABLE customers;”
-
响应:“您的查询包含禁止的 SQL 操作,无法处理。”
-
输入:“SELECT _ FROM orders;”
-
响应:“允许查询。”
-
-
不当内容
- 输入:“显示客户姓名类似于‘John the Idiot’的订单;”
- 响应:“您的查询包含不当内容,无法处理。”
- 输入:“按地区查找总销售额。”
- 响应:“输入是安全的,可以处理。”
- 代码
def pre_safety_check(state: GraphState) -> GraphState:
"""
Perform safety checks on the user input to ensure that no dangerous SQL operations
or inappropriate content is present. The function checks for SQL operations like
DELETE, DROP, and others, and also evaluates the input for toxic or unsafe content.
Args:
state (GraphState): The current graph state containing the translated user input.
Returns:
GraphState: The updated state with error status and messages if any issues are found.
"""
_logger.info("Performing safety check.")
translated_input = state["translated_input"]
messages = state["messages"]
error = "no"
# List of disallowed SQL operations (e.g., DELETE, DROP)
disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)
# Check if the input contains disallowed SQL operations
if pattern.search(translated_input):
_logger.warning("Input contains disallowed SQL operations. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your query contains disallowed SQL operations and cannot be processed.")]
else:
# Check if the input contains inappropriate content
safety_prompt = f"""
Analyze the following input for any toxic or inappropriate content.
Respond with only "safe" or "unsafe", and nothing else.
Input:
{translated_input}
"""
safety_invoke = llm.invoke(safety_prompt)
safety_response = safety_invoke.content.strip().lower() # Access the 'content' attribute and strip any extra spaces
if safety_response == "safe":
_logger.info("Input is safe to process.")
else:
_logger.warning("Input contains inappropriate content. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your query contains inappropriate content and cannot be processed.")]
# Update state with error status and messages
state["error"] = error
state["messages"] = messages
return state
3. 模式提取
`schema_extract` 节点通过查询元数据动态检索数据库模式,包括表名和列详细信息。格式化的模式存储在状态中,从而在适应当前数据库结构的同时验证用户查询。
- 示例
- 输入:请求提取模式。
模式输出- Customers(CustomerID (INTEGER), CustomerName (TEXT), ContactName (TEXT), Address (TEXT), City (TEXT), PostalCode (TEXT), Country (TEXT))
- Orders(OrderID (INTEGER), CustomerID (INTEGER), OrderDate (TEXT))
- OrderDetails(OrderDetailID (INTEGER), OrderID (INTEGER), ProductID (INTEGER), Quantity (INTEGER))
- Products(ProductID (INTEGER), ProductName (TEXT), Price (REAL))
- 输入:请求提取模式。
- 代码
def schema_extract(state: GraphState) -> GraphState:
"""
Extracts the database schema, including all tables and their respective columns,
from the connected SQLite database. This function retrieves the list of tables and
iterates through each table to gather column definitions (name and data type).
Args:
state (GraphState): The current graph state, which will be updated with the database schema.
Returns:
GraphState: The updated state with the extracted database schema.
"""
_logger.info("Extracting database schema.")
# Extract the schema from the database
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
schema_details = []
# Loop through each table and retrieve column information
for table_name_tuple in tables:
table_name = table_name_tuple[0]
cursor.execute(f"PRAGMA table_info({table_name});")
columns = cursor.fetchall()
# Format column definitions
column_defs = ', '.join([f"{col[1]} ({col[2]})" for col in columns])
schema_details.append(f"- {table_name}({column_defs})")
# Save the schema in the state
database_schema = '\n'.join(schema_details)
state["database_schema"] = database_schema
_logger.info(f"Database schema extracted:\n{database_schema}")
return state
4. 上下文检查
`context_check` 节点通过将用户查询与提取的数据库模式进行比较来验证用户查询,以确保一致性和相关性。与模式不符的查询将被标记为不相关,从而防止资源浪费并允许用户反馈以重新制定查询。
- 示例
- 输入:“平均订单价值是多少?”
模式匹配:输入与数据库模式相关。 - 输入:“显示库存表中的数据。”
响应:“您的问题与数据库无关,无法处理。”
- 输入:“平均订单价值是多少?”
- 代码
def context_check(state: GraphState) -> GraphState:
"""
Checks whether the user's input is relevant to the database schema by comparing
the user's question with the database schema. Uses a language model to determine if
the question can be answered using the provided schema.
Args:
state (GraphState): The current graph state, which contains the translated input
and the database schema.
Returns:
GraphState: The updated state with error status and messages if the input is irrelevant.
"""
_logger.info("Performing context check.")
# Extract relevant data from the state
translated_input = state["translated_input"]
messages = state["messages"]
error = "no"
database_schema = state["database_schema"] # Get the schema from the state
# Use the LLM to determine if the input is relevant to the database schema
context_prompt = f"""
Determine whether the following user input is a question that can be answered using the database schema provided below.
Respond with only "relevant" if the input is relevant to the database schema, or "irrelevant" if it is not.
User Input:
{translated_input}
Database Schema:
{database_schema}
"""
# Call the LLM for context check
llm_invoke = llm.invoke(context_prompt)
llm_response = llm_invoke.content.strip().lower() # Access the 'content' attribute and strip any extra spaces and lower case
# Process the response from the LLM
if llm_response == "relevant":
_logger.info("Input is relevant to the database schema.")
else:
_logger.info("Input is not relevant. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your question is not related to the database and cannot be processed.")]
# Update the state with error and messages
state["error"] = error
state["messages"] = messages
return state
5. 生成
`generate` 节点通过从向量存储中检索相关文档并利用预定义的 SQL 生成链,从自然语言输入构建 SQL 查询。它使查询与用户的意图和模式上下文保持一致,并使用生成的 SQL 及其描述更新状态。
- 示例
- 输入:“查找总销售额。”
生成的 SQL:“SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;” - 输入:“列出纽约的所有客户。”
生成的 SQL:“SELECT name FROM customers WHERE location = 'New York';”
- 输入:“查找总销售额。”
- 代码
def generate(state: GraphState) -> GraphState:
"""
Generates an SQL query based on the user's input. The node retrieves relevant documents from
the vector store and uses a generation chain to produce an SQL query.
Args:
state (GraphState): The current graph state, which contains the translated input and
other relevant data such as messages and iteration count.
Returns:
GraphState: The updated state with the generated SQL query and related messages.
"""
_logger.info("Generating SQL query.")
# Extract relevant data from the state
messages = state["messages"]
iterations = state["iterations"]
translated_input = state["translated_input"]
database_schema = state["database_schema"]
# Retrieve relevant documents from the vector store based on the translated user input
docs = vector_store.similarity_search(translated_input, k=4)
retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
# Generate the SQL query using the SQL generation chain
sql_solution = sql_gen_chain.invoke(
{
"retrieved_docs": retrieved_docs,
"database_schema": database_schema,
"messages": [("user", translated_input)],
}
)
# Save the generated SQL query in the state
messages += [
(
"assistant",
f"{sql_solution.description}\nSQL Query:\n{sql_solution.sql_code}",
)
]
iterations += 1
# Log the generated SQL query
_logger.info("Generated SQL query:\n%s", sql_solution.sql_code)
# Update the state with the generated SQL query and updated message list
state["generation"] = sql_solution
state["messages"] = messages
state["iterations"] = iterations
return state
6. 后安全检查
`post_safety_check` 节点通过执行最终的有害 SQL 命令验证来确保生成的 SQL 查询是安全的。虽然之前的预安全检查可以识别用户输入中的禁止操作,但此后安全检查会验证生成后产生的 SQL 查询是否符合安全指南。这种两步方法可以确保即使在查询生成过程中无意中引入了禁止的操作,也可以被捕获并标记。如果检测到不安全查询,该节点将停止工作流,将状态更新为错误标志,并向用户提供反馈。
- 示例
- 禁止的操作
- 生成的查询:“DROP TABLE orders;”
- 响应:“生成的 SQL 查询包含禁止的 SQL 操作:DROP,无法处理。”
- 生成的查询:“SELECT name FROM customers;”
- 响应:“查询有效。”
- 代码
def post_safety_check(state: GraphState) -> GraphState:
"""
Perform safety checks on the generated SQL query to ensure that it doesn't contain disallowed operations
such as CREATE, DELETE, DROP, etc. This node checks the SQL query generated earlier in the workflow.
Args:
state (GraphState): The current graph state containing the generated SQL query.
Returns:
GraphState: The updated state with error status and messages if any issues are found.
"""
_logger.info("Performing post-safety check on the generated SQL query.")
# Retrieve the generated SQL query from the state
sql_solution = state.get("generation", {})
sql_query = sql_solution.get("sql_code", "").strip()
messages = state["messages"]
error = "no"
# List of disallowed SQL operations
disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)
# Check if the generated SQL query contains disallowed SQL operations
found_operations = pattern.findall(sql_query)
if found_operations:
_logger.warning(
"Generated SQL query contains disallowed SQL operations: %s. Halting the workflow.",
", ".join(set(found_operations))
)
error = "yes"
messages += [("assistant", f"The generated SQL query contains disallowed SQL operations: {', '.join(set(found_operations))} and cannot be processed.")]
else:
_logger.info("Generated SQL query passed the safety check.")
# Update state with error status and messages
state["error"] = error
state["messages"] = messages
return state
7. SQL 检查
`sql_check` 节点通过在事务性保存点(transactional savepoint)内执行生成的 SQL 查询,来确保其安全性和语法有效性。验证后,任何更改都将被回滚,错误会被标记,并提供详细反馈以维护查询的完整性。
- 示例
- 输入 SQL:“SELECT name FROM customers WHERE city = 'New York';”
验证:查询有效。 - 输入 SQL:“SELECT MONTH(date) AS month, SUM(total) AS total_sales FROM orders GROUP BY MONTH(date);”
响应:“您的 SQL 查询执行失败:没有这样的函数:MONTH。”
- 输入 SQL:“SELECT name FROM customers WHERE city = 'New York';”
- 代码
def sql_check(state: GraphState) -> GraphState:
"""
Validates the generated SQL query by attempting to execute it on the database.
If the query is valid, the changes are rolled back to ensure no data is modified.
If there is an error during execution, the error is logged and the state is updated accordingly.
Args:
state (GraphState): The current graph state, which contains the generated SQL query
and the messages to communicate with the user.
Returns:
GraphState: The updated state with error status and messages if the query is invalid.
"""
_logger.info("Validating SQL query.")
# Extract relevant data from the state
messages = state["messages"]
sql_solution = state["generation"]
error = "no"
sql_code = sql_solution.sql_code.strip()
try:
# Start a savepoint for the transaction to allow rollback
conn.execute('SAVEPOINT sql_check;')
# Attempt to execute the SQL query
cursor.execute(sql_code)
# Roll back to the savepoint to undo any changes
conn.execute('ROLLBACK TO sql_check;')
_logger.info("SQL query validation: success.")
except Exception as e:
# Roll back in case of error
conn.execute('ROLLBACK TO sql_check;')
_logger.error("SQL query validation failed. Error: %s", e)
messages += [("user", f"Your SQL query failed to execute: {e}")]
error = "yes"
# Update the state with the error status
state["error"] = error
state["messages"] = messages
return state
8. 运行查询
`run_query` 节点执行经过验证的 SQL 查询,连接到数据库以检索结果。它使用查询输出来更新状态,确保数据格式化以供进一步分析或报告,同时实施 robust 的错误处理。
- 示例
- 输入 SQL:“SELECT COUNT(*) FROM Customers WHERE City = 'New York';”
查询结果:“(0,)” - 输入 SQL:“SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;”
查询结果:“(6925.0,)”
- 输入 SQL:“SELECT COUNT(*) FROM Customers WHERE City = 'New York';”
- 代码
def run_query(state: GraphState) -> GraphState:
"""
Executes the generated SQL query on the database and retrieves the results if it is a SELECT query.
For non-SELECT queries, commits the changes to the database. If no records are found for a SELECT query,
the `no_records_found` flag is set to True.
Args:
state (GraphState): The current graph state, which contains the generated SQL query and other relevant data.
Returns:
GraphState: The updated state with the query results, or a flag indicating if no records were found.
"""
_logger.info("Running SQL query.")
# Extract the SQL query from the state
sql_solution = state["generation"]
sql_code = sql_solution.sql_code.strip()
results = None
no_records_found = False # Flag to indicate no records found
try:
# Execute the SQL query
cursor.execute(sql_code)
# For SELECT queries, fetch and store results
if sql_code.upper().startswith("SELECT"):
results = cursor.fetchall()
if not results:
no_records_found = True
_logger.info("SQL query execution: success. No records found.")
else:
_logger.info("SQL query execution: success.")
else:
# For non-SELECT queries, commit the changes
conn.commit()
_logger.info("SQL query execution: success. Changes committed.")
except Exception as e:
_logger.error("SQL query execution failed. Error: %s", e)
# Update the state with results and flag for no records found
state["results"] = results
state["no_records_found"] = no_records_found
return state
决策步骤:确定下一步操作
`decide_next_step` 函数充当工作流中的控制点,根据当前状态决定下一步应执行的操作。它评估 `error` 状态和迄今为止执行的迭代次数,以确定是应运行查询、结束工作流,还是系统应重试生成 SQL 查询。
-
流程
- 如果没有错误(`error == "no"`),系统将继续运行 SQL 查询。
- 如果已达到最大迭代次数(`max_iterations`),工作流将结束。
- 如果发生错误且未达到最大迭代次数,系统将重试查询生成。
-
示例工作流决策
- 无错误,继续执行查询:如果在先前步骤中未发现错误,工作流将继续运行查询。
- 达到最大迭代次数,结束工作流:如果工作流已尝试了设定的次数(`max_iterations`),则终止。
- 检测到错误,重试 SQL 生成:如果发生错误且系统尚未达到重试限制,它将尝试重新生成 SQL 查询。
-
代码
def decide_next_step(state: GraphState) -> str:
"""
Determines the next step in the workflow based on the current state, including whether the query
should be run, the workflow should be finished, or if the query generation needs to be retried.
Args:
state (GraphState): The current graph state, which contains error status and iteration count.
Returns:
str: The next step in the workflow, which can be "run_query", "generate", or END.
"""
_logger.info("Deciding next step based on current state.")
error = state["error"]
iterations = state["iterations"]
if error == "no":
_logger.info("Error status: no. Proceeding with running the query.")
return "run_query"
elif iterations >= max_iterations:
_logger.info("Maximum iterations reached. Ending the workflow.")
return END
else:
_logger.info("Error detected. Retrying SQL query generation.")
return "generate"
工作流编排和条件逻辑
为了定义我们系统中任务的编排,我们使用 `StateGraph` 类来构建工作流图。每个任务都表示为一个节点,任务之间的转换定义为边。条件边用于根据工作流的状态控制流程。
def get_workflow(conn, cursor, vector_store):
"""Define and compile the LangGraph workflow."""
# Max iterations: defines how many times the workflow should retry in case of errors
max_iterations = 3
# SQL generation chain: this is a chain that will generate SQL based on retrieved docs
sql_gen_chain = get_sql_gen_chain()
# Initialize OpenAI LLM for translation and safety checks
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Define the individual nodes of the workflow
... # Insert nodes code defined above here
# Build the workflow graph
workflow = StateGraph(GraphState)
# Define workflow nodes
workflow.add_node("translate_input", translate_input) # Translate user input to structured format
workflow.add_node("pre_safety_check", pre_safety_check) # Perform a pre-safety check on input
workflow.add_node("schema_extract", schema_extract) # Extract the database schema
workflow.add_node("context_check", context_check) # Validate input relevance to context
workflow.add_node("generate", generate) # Generate SQL query
workflow.add_node("post_safety_check", post_safety_check) # Perform a post-safety check on generated SQL query
workflow.add_node("sql_check", sql_check) # Validate the generated SQL query
workflow.add_node("run_query", run_query) # Execute the SQL query
# Define workflow edges
workflow.add_edge(START, "translate_input") # Start at the translation step
workflow.add_edge("translate_input", "pre_safety_check") # Move to safety checks
# Conditional edge after safety check
workflow.add_conditional_edges(
"pre_safety_check", # Start at the pre_safety_check node
lambda state: "schema_extract" if state["error"] == "no" else END, # Decide next step
{"schema_extract": "schema_extract", END: END} # Map states to nodes
)
workflow.add_edge("schema_extract", "context_check") # Proceed to context validation
# Conditional edge after context check
workflow.add_conditional_edges(
"context_check", # Start at the context_check node
lambda state: "generate" if state["error"] == "no" else END, # Decide next step
{"generate": "generate", END: END}
)
workflow.add_edge("generate", "post_safety_check") # Proceed to post-safety check
# Conditional edge after post-safety check
workflow.add_conditional_edges(
"post_safety_check", # Start at the post_safety_check node
lambda state: "sql_check" if state["error"] == "no" else END, # If no error, proceed to sql_check, else END
{"sql_check": "sql_check", END: END},
)
# Conditional edge after SQL validation
workflow.add_conditional_edges(
"sql_check", # Start at the sql_check node
decide_next_step, # Function to determine the next step
{
"run_query": "run_query", # If SQL is valid, execute the query
"generate": "generate", # If retry is needed, go back to generation
END: END # Otherwise, terminate the workflow
}
)
workflow.add_edge("run_query", END) # Final step is to end the workflow
# Compile and return the workflow application
app = workflow.compile()
return app
- 开始到 `translate_input`:
工作流首先将用户输入翻译成结构化格式。 - `translate_input` 到 `pre_safety_check`:
翻译后,工作流继续检查输入的安全性。 - `pre_safety_check` 条件规则:
- 如果输入通过了预安全检查(`state["error"] == "no"`),工作流将移至 `schema_extract`。
- 如果输入未能通过预安全检查,工作流将终止(`END`)。
- `schema_extract` 到 `context_check`:
模式被提取,然后工作流验证输入与数据库上下文的相关性。 - `context_check` 条件规则:
- 如果输入相关(`state["error"] == "no"`),工作流将移至 `generate`。
- 如果无关,工作流将终止(`END`)。
- `generate` 到 `post_safety_check`:
工作流生成 SQL 查询并将其发送进行验证。\ - `post_safety_check` 条件规则:- 如果输入通过了后安全检查(`state["error"] == "no"`),工作流将移至 `sql_check`。- 如果输入未能通过后安全检查,工作流将终止(`END`)。
- `sql_check` 条件规则:
- 如果查询有效,工作流将继续执行 `run_query`。
- 如果查询需要调整且迭代次数少于 3 次,工作流将循环回到 `generate`。
- 如果验证失败且迭代次数超过 3 次,工作流将终止(`END`)。
- `run_query` 到 `END`:
查询执行后,工作流结束。
上述图表提供了 LangGraph 节点和边的表示。
在 MLflow 中记录模型
现在我们已经使用 LangGraph 构建了一个多语言查询引擎,我们就可以使用 MLflow 的 从代码创建模型 功能来记录该模型。将模型记录到 MLflow 中,使我们可以将多语言查询引擎视为一个传统的机器学习模型,从而能够无缝地进行跟踪、版本控制和打包,以便在各种不同的服务基础架构上进行部署。MLflow 的从代码创建模型策略(我们记录代表模型的代码)与 MLflow 的基于对象的日志记录(创建、序列化模型对象并将其记录为 pickle 或 JSON 对象)形成对比。
步骤 1:创建我们的从代码创建模型文件
到目前为止,我们已经在名为 workflow.py 的文件中定义了 get_workflow 函数。在此步骤中,我们将创建一个新文件 sql_model.py,其中包含 SQLGenerator 类。此脚本将
- 从
workflow.py导入get_workflow函数。 - 将
SQLGenerator类定义为PythonModel,包括一个利用get_workflow函数初始化 LangGraph 工作流的 predict 方法。 - 使用
mlflow.models.set_model将SQLGenerator PythonModel类指定为 MLflow 的目标模型。
import mlflow
from definitions import REMOTE_SERVER_URI
from workflow import get_workflow
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
class SQLGenerator(mlflow.pyfunc.PythonModel):
def predict(self, context, model_input):
return get_workflow(
model_input["conn"], model_input["cursor"], model_input["vector_store"]
)
mlflow.models.set_model(SQLGenerator())
步骤 2:使用从代码创建模型功能进行记录
在定义了 SQLGenerator 自定义 Python 模型后,下一步是使用 MLflow 的从代码创建模型功能来记录它。这涉及使用 log model 标准 API,指定 sql_model.py 脚本的路径,并使用 code_paths 参数将 workflow.py 包含为依赖项。这种方法可确保在其他环境或另一台机器上加载模型时,所有必需的代码文件都可用。
import mlflow
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from mlflow import MlflowClient
client = MlflowClient(tracking_uri=REMOTE_SERVER_URI)
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
with mlflow.start_run():
logged_model_info = mlflow.pyfunc.log_model(
python_model="sql_model.py",
artifact_path="sql_generator",
registered_model_name=REGISTERED_MODEL_NAME,
code_paths=["workflow.py"],
)
client.set_registered_model_alias(
REGISTERED_MODEL_NAME, MODEL_ALIAS, logged_model_info.registered_model_version
)
在 MLflow UI 中,存储的模型将 sql_model.py 和 workflow.py 脚本作为构件包含在运行中。这种从代码记录功能不仅记录了模型的参数和指标,还捕获了定义其功能的代码。这确保了可观测性、无缝跟踪以及直接通过 UI 进行的简单调试。但是,至关重要的是要确保永远不要将 API 密钥或凭据等敏感元素硬编码到这些脚本中。由于代码是按原样存储的,因此包含的任何敏感信息都可能导致令牌泄露并带来安全风险。相反,应使用环境变量、机密管理系统或其他安全方法来安全地管理敏感数据。

在 main.py 中使用记录的多语言查询引擎
记录模型后,可以使用模型 URI 和标准的 mlflow.pyfunc.load_model API 从 MLflow 跟踪服务器加载它。加载模型时,将执行 workflow.py 脚本和 sql_model.py 脚本,确保在调用 predict 方法时 get_workflow 函数可用。
通过执行下面的代码,我们演示了我们的多语言查询引擎能够执行自然语言到 SQL 的生成和查询。
import os
import logging
import mlflow
from database import setup_database
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from dotenv import load_dotenv
from vector_store import setup_vector_store
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.langchain.autolog()
# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)
def main():
# Load environment variables from .env file
load_dotenv()
# Access secrets using os.getenv
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
# Setup database and vector store
conn = setup_database()
cursor = conn.cursor()
vector_store = setup_vector_store()
# Load the model
model_uri = f"models:/{REGISTERED_MODEL_NAME}@{MODEL_ALIAS}"
model = mlflow.pyfunc.load_model(model_uri)
model_input = {"conn": conn, "cursor": cursor, "vector_store": vector_store}
app = model.predict(model_input)
# save image
app.get_graph().draw_mermaid_png(
output_file_path="sql_agent_with_safety_checks.png"
)
# Example user interaction
_logger.info("Welcome to the SQL Assistant!")
while True:
question = input("\nEnter your SQL question (or type 'exit' to quit): ")
if question.lower() == "exit":
break
# Initialize the state with all required keys
initial_state = {
"messages": [("user", question)],
"iterations": 0,
"error": "",
"results": None,
"generation": None,
"no_records_found": False,
"translated_input": "", # Initialize translated_input
}
solution = app.invoke(initial_state)
# Check if an error was set during the safety check
if solution["error"] == "yes":
_logger.info("\nAssistant Message:\n")
_logger.info(solution["messages"][-1][1]) # Display the assistant's message
continue # Skip to the next iteration
# Extract the generated SQL query from solution["generation"]
sql_query = solution["generation"].sql_code
_logger.info("\nGenerated SQL Query:\n")
_logger.info(sql_query)
# Extract and display the query results
if solution.get("no_records_found"):
_logger.info("\nNo records found matching your query.")
elif "results" in solution and solution["results"] is not None:
_logger.info("\nQuery Results:\n")
for row in solution["results"]:
_logger.info(row)
else:
_logger.info("\nNo results returned or query did not execute successfully.")
_logger.info("Goodbye!")
if __name__ == "__main__":
main()
项目文件结构
项目遵循简单的文件结构
MLflow 跟踪
MLflow 自动跟踪 为 LangChain、OpenAI、LlamaIndex、DSPy 和 AutoGen 等各种 GenAI 库提供了完全自动化的集成。由于我们的 AI 工作流是使用 LangGraph 构建的,我们可以通过启用 mlflow.langchain.autolog() 来激活 LangChain 的自动跟踪。
通过 LangChain 自动日志记录,每当在链上调用调用 API 时,跟踪都会自动记录到活动的 MLflow 实验中。这种无缝集成可确保捕获每次交互,以便进行监控和分析。
在 MLflow 中查看跟踪
可以通过导航到目标 MLflow 实验并单击“跟踪”选项卡来轻松访问跟踪。进入后,选择特定的跟踪即可提供详细的执行信息。
每个跟踪包括
- 执行图:工作流步骤的可视化。
- 输入和输出:每个步骤处理的数据的详细日志。
通过利用 MLflow 跟踪,我们可以深入了解整个图的执行情况。AI 工作流图通常感觉像一个黑匣子,使得调试和理解每个步骤发生的事情变得很困难。但是,只需一行代码即可启用跟踪,MLflow 即可提供对工作流的清晰详细的洞察,使开发人员能够有效地调试、监控和优化图的每个节点,确保我们的多语言查询引擎保持透明、可审计和可扩展。

结论
在本文中,我们探讨了使用 LangGraph 和 MLflow 构建和管理多语言查询引擎的过程。通过将 LangGraph 的动态 AI 工作流与 MLflow 强大的生命周期管理和跟踪功能相结合,我们创建了一个系统,该系统不仅能够提供准确高效的自然语言到 SQL 的生成和执行,而且还透明且可扩展。


