如果您正在寻找构建一个结合了自然语言到 SQL 生成和查询执行功能,同时充分利用 MLflow 特性的多语言查询引擎,那么这篇博文将是您的指南。我们将探讨如何利用 MLflow 代码模型 (Models from Code) 实现 AI 工作流的无缝跟踪和版本控制。此外,我们将深入探讨 MLflow 跟踪 (Tracing) 特性,该特性通过跟踪每个中间步骤的输入、输出和元数据,为 AI 工作流的众多不同组件引入了可观察性。
介绍
SQL 是管理和访问关系数据库中数据的基本技能。然而,构建复杂的 SQL 查询以回答复杂的数据问题可能具有挑战性且耗时。这种复杂性可能使得有效充分利用数据变得困难。自然语言到 SQL (NL2SQL) 系统通过提供自然语言到 SQL 命令的翻译来帮助解决这个问题,允许非技术人员与数据交互:用户只需用他们熟悉的自然语言提问,这些系统就会帮助他们返回适当的信息。
然而,在创建 NL2SQL 系统时仍然存在许多问题,例如语义模糊、模式映射或错误处理和用户反馈。因此,在构建此类系统时,非常重要的是我们必须设置一些防护措施,而不是完全依赖于 LLM。
在这篇博文中,我们将引导您完成构建多语言查询引擎的过程。该引擎支持多种语言的自然语言输入,根据翻译后的用户输入生成 SQL 查询,并执行查询。让我们来看一个例子:使用一个包含公司客户、产品和订单信息的数据库,用户可以用任何语言提问,例如 “Quantos clientes temos por país?” (葡萄牙语,意为“我们每个国家有多少客户?”)。AI 工作流将输入翻译成英文,输出 “How many customers do we have per country?”。然后它对输入进行安全验证,检查是否可以使用数据库模式回答问题,生成适当的 SQL 查询(例如,SELECT COUNT(CustomerID) AS NumberOfCustomers, Country FROM Customers GROUP BY Country;
),并验证查询以确保没有有害命令(例如,DROP)。最后,它针对数据库执行查询以检索结果。
我们将首先演示如何利用 LangGraph 的能力构建一个动态的 AI 工作流。该工作流集成了 OpenAI 和外部数据源(例如向量存储和 SQLite 数据库),用于处理用户输入、执行安全检查、查询数据库以及生成有意义的响应。
在整篇文章中,我们将利用 MLflow 代码模型 (Models from Code) 特性,实现 AI 工作流的无缝跟踪和版本控制。此外,我们将深入探讨 MLflow 跟踪 (Tracing) 特性,该特性旨在通过跟踪与每个中间步骤关联的输入、输出和元数据,增强 AI 工作流众多不同组件的可观察性。这有助于轻松识别错误和意外行为,提供工作流更大的透明度。
先决条件
要设置和运行此项目,请确保安装了以下 Python 包:
faiss-cpu
langchain
langchain-core
langchain-openai
langgraph
langchain-community
pydantic >=2
typing_extensions
python-dotenv
此外,需要一个 MLflow 跟踪服务器 (Tracking Server) 来有效记录和管理实验、模型和跟踪。对于本地设置,请参阅 MLflow 官方文档中关于配置简单 MLflow 跟踪服务器的说明。
最后,确保您的 OpenAI API 密钥保存在项目目录中的 .env 文件中。这允许应用程序安全访问构建 AI 工作流所需的 OpenAI 服务。.env 文件应包含类似以下的一行:
OPENAI_API_KEY=your_openai_api_key
使用 LangGraph 的多语言查询引擎
多语言查询引擎利用 LangGraph 库,这是一个 AI 编排工具,旨在为由 LLMs 驱动的应用程序创建有状态、多代理和循环图架构。
与其他 AI 编排工具相比,LangGraph 提供三个核心优势:循环、可控性和持久性。它允许定义带有循环的 AI 工作流,这对于实现重试机制至关重要,例如多语言查询引擎中的 SQL 查询生成重试(当验证失败时,查询会循环回再次生成)。这使得 LangGraph 成为构建我们的多语言查询引擎的理想工具。
LangGraph 主要特性:
-
有状态架构 (Stateful Architecture):引擎维护图执行状态的动态快照。此快照作为跨节点共享的资源, enabling efficient decision-making and real-time updates at each node execution.
-
多代理设计 (Multi-Agent Design):AI 工作流在整个工作流中包括与 OpenAI 和其他外部工具的多次交互。
-
循环图结构 (Cyclical Graph Structure):图的循环性质引入了强大的重试机制。此机制通过在需要时循环回先前阶段来动态处理故障,确保图的持续执行。(此机制的详细信息将在后面讨论。)
AI 工作流概览
多语言查询引擎的高级 AI 工作流由相互连接的节点和边缘组成,每个节点和边缘都代表一个关键阶段:
-
翻译节点 (Translation Node):将用户输入转换为英文。
-
预安全检查 (Pre-safety Check):确保用户输入不含不良或不当内容,且不包含有害的 SQL 命令(例如
DELETE
、DROP
)。 -
数据库模式提取 (Database Schema Extraction):检索目标数据库的模式,以了解其结构和可用数据。
-
相关性验证 (Relevancy Validation):对照数据库模式验证用户输入,以确保与数据库上下文一致。
-
SQL 查询生成 (SQL Query Generation):根据用户输入和当前数据库模式生成 SQL 查询。
-
后安全检查 (Post-safety Check):确保生成的 SQL 查询不包含有害的 SQL 命令(例如
DELETE
、DROP
)。 -
动态状态评估 (Dynamic State Evaluation):根据当前状态确定下一步行动。如果 SQL 查询验证失败,它会循环回阶段 5 以重新生成查询。
-
查询执行和结果检索 (Query Execution and Result Retrieval):执行 SQL 查询,如果是
SELECT
语句,则返回结果。 -
重试机制在阶段 8 引入,系统在此动态评估当前图状态。具体来说,当 SQL 查询验证节点(阶段 7)检测到问题时,状态会触发循环回 SQL 生成节点(阶段 5),以便进行新的 SQL 生成尝试(最多 3 次尝试)。
重试机制在阶段8引入,系统在此动态评估当前图状态。具体来说,当SQL查询验证节点(阶段7)检测到问题时,状态会触发一个循环回到SQL生成节点(阶段5),以进行一次新的SQL生成尝试(最多尝试3次)。
组件
多语言查询引擎与多个外部组件交互,将自然语言用户输入转换为 SQL 查询并以安全可靠的方式执行。在本节中,我们将详细介绍关键的 AI 工作流组件:OpenAI、向量存储、SQLite 数据库和 SQL 生成链。
OpenAI
OpenAI,更具体地说是 gpt-4o-mini
语言模型,在工作流的多个阶段发挥着关键作用。它提供了所需智能,用于:
-
翻译 (Translation):将用户输入翻译成英文。如果文本已经是英文,它会简单地重复输入。
-
安全检查 (Safety Checks):分析用户输入,确保不包含不良或不当内容。
-
相关性检查 (Relevance Checks):评估用户的问题是否与数据库模式相关。
-
SQL 生成 (SQL Generation):根据用户输入、SQL 生成文档和数据库模式生成有效且可执行的 SQL 查询。
OpenAI 实现的详细信息将在后面的节点描述部分提供。
FAISS 向量存储
为了构建一个能够生成准确且可执行的 SQL 查询的有效自然语言到 SQL 引擎,我们利用了 Langchain 的 FAISS 向量存储特性。这种设置允许系统从先前存储在向量数据库中的W3Schools SQL 文档中搜索和提取 SQL 查询生成指南,从而提高 SQL 查询生成的成功率。
出于演示目的,我们使用的是 FAISS,一个内存向量存储,其中向量直接存储在 RAM 中。这提供了快速访问,但也意味着数据在运行之间不会持久化。对于更具可扩展性的解决方案,允许嵌入在多个项目之间存储和共享,我们推荐替代方案,例如 AWS OpenSearch、Vertex AI Vector Search、Azure Vector Search 或 Mosaic AI Vector Search。这些基于云的解决方案提供持久存储、自动伸缩以及与其他云服务的无缝集成,非常适合大规模应用程序。
步骤 1:加载 SQL 文档
创建带有 SQL 查询生成指南的 FAISS 向量存储的第一步是使用 LangChain 的 RecursiveUrlLoader
从 W3Schools SQL 页面加载 SQL 文档。此工具检索文档,允许我们将其用作引擎的知识库。
步骤 2:将文本分割成可管理的块
加载的 SQL 文档是一个长文本,难以被 LLM 有效地摄取。为了解决这个问题,下一步涉及使用 Langchain 的 RecursiveCharacterTextSplitter
将文本分割成更小的、可管理的块。通过将文本分割成 500 个字符的块并设置 50 个字符的重叠,我们确保语言模型有足够的上下文,同时最大程度地减少丢失跨块重要信息的风险。split_text
方法应用此分割过程,将结果片段存储在名为 'documents' 的列表中。
步骤 3:生成嵌入模型
第三步是创建一个模型,将这些块转换为嵌入(每个文本块的向量化数值表示)。嵌入使系统能够比较块与用户输入之间的相似性,从而便于检索最相关的匹配项以用于 SQL 查询生成。
步骤 4:在 FAISS 向量存储中创建和存储嵌入
最后,我们使用 FAISS 创建和存储嵌入。FAISS.from_texts
方法接受所有块,计算它们的嵌入,并将它们存储在一个高速可搜索的向量数据库中。这个可搜索的数据库允许引擎高效地检索相关的 SQL 指南,显著提高了可执行 SQL 查询生成的成功率。
import logging
import os
from bs4 import BeautifulSoup as Soup
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
def setup_vector_store(logger: logging.Logger):
"""Setup or load the vector store."""
if not os.path.exists("data"):
os.makedirs("data")
vector_store_dir = "data/vector_store"
if os.path.exists(vector_store_dir):
# Load the vector store from disk
logger.info("Loading vector store from disk...")
vector_store = FAISS.load_local(
vector_store_dir,
OpenAIEmbeddings(),
allow_dangerous_deserialization=True,
)
else:
logger.info("Creating new vector store...")
# Load SQL documentation
url = "https://w3schools.org.cn/sql/"
loader = RecursiveUrlLoader(
url=url, max_depth=2, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()
# Split documents into chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
)
documents = []
for doc in docs:
splits = text_splitter.split_text(doc.page_content)
for i, split in enumerate(splits):
documents.append(
{
"content": split,
"metadata": {"source": doc.metadata["source"], "chunk": i},
}
)
# Compute embeddings and create vector store
embedding_model = OpenAIEmbeddings()
vector_store = FAISS.from_texts(
[doc["content"] for doc in documents],
embedding_model,
metadatas=[doc["metadata"] for doc in documents],
)
# Save the vector store to disk
vector_store.save_local(vector_store_dir)
logger.info("Vector store created and saved to disk.")
return vector_store
SQLite 数据库
SQLite 数据库是多语言查询引擎的关键组件,充当结构化数据存储库。SQLite 提供了一个轻量级、快速且独立的关联数据库引擎,无需服务器设置或安装。其紧凑的大小(小于 500KB)和零配置特性使其非常易于使用,而其平台无关的数据库格式确保了跨不同系统的无缝可移植性。作为本地磁盘数据库,SQLite 是避免设置 MySQL 或 PostgreSQL 复杂性的理想选择,同时仍然提供可靠、功能齐全且性能出色的 SQL 引擎。
SQLite 数据库通过以下方式支持高效的 SQL 查询生成、验证和执行:
-
模式提取 (Schema Extraction):为用户输入上下文验证(阶段 4)和可执行 SQL 查询生成(阶段 5)提供模式信息。
-
查询执行 (Query Execution):在回滚安全环境中执行 SQL 查询,用于验证阶段(阶段 7)和查询执行阶段(阶段 9),获取
SELECT
语句的结果并为其他查询类型提交更改。
SQLite 数据库初始化
数据库在 AI 工作流初始化时使用 setup_database
函数进行初始化。此过程包括:
-
设置 SQLite 数据库连接 (Setting the SQLite Database Connection):建立与 SQLite 数据库的连接,启用数据交互。
-
表格创建 (Table Creation):为 AI 工作流定义并创建必要的数据库表格。
-
数据填充 (Data Population):使用示例数据填充表格,以支持查询执行和验证阶段。
import logging
import os
import sqlite3
def create_connection(db_file="data/database.db"):
"""Create a database connection to the SQLite database."""
conn = sqlite3.connect(db_file)
return conn
def create_tables(conn):
"""Create tables in the database."""
cursor = conn.cursor()
# Create Customers table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Customers (
CustomerID INTEGER PRIMARY KEY,
CustomerName TEXT,
ContactName TEXT,
Address TEXT,
City TEXT,
PostalCode TEXT,
Country TEXT
)
"""
)
# Create Orders table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Orders (
OrderID INTEGER PRIMARY KEY,
CustomerID INTEGER,
OrderDate TEXT,
FOREIGN KEY (CustomerID) REFERENCES Customers (CustomerID)
)
"""
)
# Create OrderDetails table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS OrderDetails (
OrderDetailID INTEGER PRIMARY KEY,
OrderID INTEGER,
ProductID INTEGER,
Quantity INTEGER,
FOREIGN KEY (OrderID) REFERENCES Orders (OrderID),
FOREIGN KEY (ProductID) REFERENCES Products (ProductID)
)
"""
)
# Create Products table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Products (
ProductID INTEGER PRIMARY KEY,
ProductName TEXT,
Price REAL
)
"""
)
conn.commit()
def populate_tables(conn):
"""Populate tables with sample data if they are empty."""
cursor = conn.cursor()
# Populate Customers table if empty
cursor.execute("SELECT COUNT(*) FROM Customers")
if cursor.fetchone()[0] == 0:
customers = []
for i in range(1, 51):
customers.append(
(
i,
f"Customer {i}",
f"Contact {i}",
f"Address {i}",
f"City {i % 10}",
f"{10000 + i}",
f"Country {i % 5}",
)
)
cursor.executemany(
"""
INSERT INTO Customers (CustomerID, CustomerName, ContactName, Address, City, PostalCode, Country)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
customers,
)
# Populate Products table if empty
cursor.execute("SELECT COUNT(*) FROM Products")
if cursor.fetchone()[0] == 0:
products = []
for i in range(1, 51):
products.append((i, f"Product {i}", round(10 + i * 0.5, 2)))
cursor.executemany(
"""
INSERT INTO Products (ProductID, ProductName, Price)
VALUES (?, ?, ?)
""",
products,
)
# Populate Orders table if empty
cursor.execute("SELECT COUNT(*) FROM Orders")
if cursor.fetchone()[0] == 0:
orders = []
from datetime import datetime, timedelta
base_date = datetime(2023, 1, 1)
for i in range(1, 51):
order_date = base_date + timedelta(days=i)
orders.append(
(
i,
i % 50 + 1, # CustomerID between 1 and 50
order_date.strftime("%Y-%m-%d"),
)
)
cursor.executemany(
"""
INSERT INTO Orders (OrderID, CustomerID, OrderDate)
VALUES (?, ?, ?)
""",
orders,
)
# Populate OrderDetails table if empty
cursor.execute("SELECT COUNT(*) FROM OrderDetails")
if cursor.fetchone()[0] == 0:
order_details = []
for i in range(1, 51):
order_details.append(
(
i,
i % 50 + 1, # OrderID between 1 and 50
i % 50 + 1, # ProductID between 1 and 50
(i % 5 + 1) * 2, # Quantity between 2 and 10
)
)
cursor.executemany(
"""
INSERT INTO OrderDetails (OrderDetailID, OrderID, ProductID, Quantity)
VALUES (?, ?, ?, ?)
""",
order_details,
)
conn.commit()
def setup_database(logger: logging.Logger):
"""Setup the database and return the connection."""
db_file = "data/database.db"
if not os.path.exists("data"):
os.makedirs("data")
db_exists = os.path.exists(db_file)
conn = create_connection(db_file)
if not db_exists:
logger.info("Setting up the database...")
create_tables(conn)
populate_tables(conn)
else:
logger.info("Database already exists. Skipping setup.")
return conn
SQL 生成链
SQL 生成链(sql_gen_chain
)是我们工作流中自动化 SQL 查询生成的支柱。此链利用 LangChain 的模块化能力和 OpenAI 的高级自然语言处理能力,将用户问题转换为精确且可执行的 SQL 查询。
核心特性:
-
Prompt 驱动生成 (Prompt-Driven Generation):始于精心设计的 Prompt,该 Prompt 集成了数据库模式和文档片段,确保查询的上下文准确性。
-
结构化响应 (Structured Responses):以预定义的格式交付输出,包括:
-
查询目的的描述 (description)。
-
可直接执行的相应 SQL 代码 (SQL code)。
-
-
可适应和可靠 (Adaptable and Reliable):使用
gpt-4o-mini
进行强大、一致的查询生成,最大程度地减少手动操作和错误。
此链是我们工作流中的关键组件,能够实现 SQL 查询生成与下游流程的无缝集成,确保准确性,并显著提高效率。
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
class SQLQuery(BaseModel):
"""Schema for SQL query solutions to questions."""
description: str = Field(description="Description of the SQL query")
sql_code: str = Field(description="The SQL code block")
def get_sql_gen_chain():
"""Set up the SQL generation chain."""
sql_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL assistant with expertise in SQL query generation. \n
Answer the user's question based on the provided documentation snippets and the database schema provided below. Ensure any SQL query you provide is valid and executable. \n
Structure your answer with a description of the query, followed by the SQL code block. Here are the documentation snippets:\n{retrieved_docs}\n\nDatabase Schema:\n{database_schema}""",
),
("placeholder", "{messages}"),
]
)
# Initialize the OpenAI LLM
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Create the code generation chain
sql_gen_chain = sql_gen_prompt | llm.with_structured_output(SQLQuery)
return sql_gen_chain
工作流设置与初始化
在深入探讨工作流节点之前,建立必要的组件并定义工作流的结构至关重要。本节解释了基本库的初始化、日志记录和自定义 GraphState
类,以及主要工作流编译函数。
定义 GraphState
GraphState
类是一个自定义的 TypedDict
,用于维护工作流进行过程中的状态信息。它充当跨节点的共享数据结构,确保连续性和一致性。关键字段包括:
error
:跟踪是否发生了错误。messages
:存储用户和系统消息的列表。generation
:保存生成的 SQL 查询。iterations
:跟踪错误情况下的重试尝试次数。results
:存储 SQL 执行结果(如果有)。no_records_found
:标记查询是否未返回记录。translated_input
:包含用户翻译后的输入。database_schema
:维护数据库模式以进行上下文验证。
import logging
import re
from typing import List, Optional
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from sql_generation import get_sql_gen_chain
from typing_extensions import TypedDict
# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)
class GraphState(TypedDict):
error: str # Tracks if an error has occurred
messages: List # List of messages (user input and assistant messages)
generation: Optional[str] # Holds the generated SQL query
iterations: int # Keeps track of how many times the workflow has retried
results: Optional[List] # Holds the results of SQL execution
no_records_found: bool # Flag for whether any records were found in the SQL result
translated_input: str # Holds the translated user input
database_schema: str # Holds the extracted database schema for context checking
工作流编译函数
主要函数 get_workflow
负责定义和编译工作流。关键组件包括:
conn
和cursor
:用于数据库连接和查询执行。vector_store
:用于上下文检索的向量数据库。max_iterations
:设置重试尝试的限制以防止无限循环。sql_gen_chain
:从sql_generation
中检索 SQL 生成链,用于根据上下文输入生成 SQL 查询。ChatOpenAI
:初始化 OpenAIgpt-4o-mini
模型,用于安全检查和查询翻译等任务。
def get_workflow(conn, cursor, vector_store):
"""Define and compile the LangGraph workflow."""
# Max iterations: defines how many times the workflow should retry in case of errors
max_iterations = 3
# SQL generation chain: this is a chain that will generate SQL based on retrieved docs
sql_gen_chain = get_sql_gen_chain()
# Initialize OpenAI LLM for translation and safety checks
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Define the individual nodes of the workflow
此函数充当使用 StateGraph
创建完整工作流的入口点。工作流中的各个节点将在后续部分定义并连接。
节点描述
1. 翻译输入
translate_input
节点将用户查询翻译成英文,以标准化处理并确保与下游节点的兼容性。将用户输入作为 AI 工作流的第一步进行翻译,确保了任务隔离并提高了可观察性。任务隔离通过将翻译与用户输入安全验证和 SQL 生成等其他下游任务分开,简化了工作流。可观察性提高则在 MLflow 中提供了清晰的跟踪,使得调试和监控过程更容易。
- 示例
- Input: "Quantos pedidos foram realizados em Novembro?"
- Translated: "How many orders were made in November?"
- Input: "Combien de ventes avons-nous enregistrées en France ?"
- Translated: "How many sales did we record in France?"
- 代码
def translate_input(state: GraphState) -> GraphState:
"""
Translates user input to English using an LLM. If the input is already in English,
it is returned as is. This ensures consistent input for downstream processing.
Args:
state (GraphState): The current graph state containing user messages.
Returns:
GraphState: The updated state with the translated input.
"""
_logger.info("Starting translation of user input to English.")
messages = state["messages"]
user_input = messages[-1][1] # Get the latest user input
# Translation prompt for the model
translation_prompt = f"""
Translate the following text to English. If the text is already in English, repeat it exactly without any additional explanation.
Text:
{user_input}
"""
# Call the OpenAI LLM to translate the text
translated_response = llm.invoke(translation_prompt)
translated_text = translated_response.content.strip() # Access the 'content' attribute and strip any extra spaces
# Update state with the translated input
state["translated_input"] = translated_text
_logger.info("Translation completed successfully. Translated input: %s", translated_text)
return state
2. 预安全检查
pre_safety_check
节点确保在用户输入中及早检测到不允许的 SQL 操作和不当内容。虽然对有害 SQL 命令(例如 CREATE
、DELETE
、DROP
、INSERT
、UPDATE
)的检查将在工作流后期,特别是在生成 SQL 查询后再次发生,但此预安全检查对于在输入阶段识别潜在问题至关重要。通过这样做,它防止了不必要的计算,并立即向用户提供了反馈。
虽然使用不允许列表来防范有害 SQL 操作提供了一种快速保护免受破坏性查询的方法,但当处理像 T-SQL 这样复杂的 SQL 后端时,维护全面的不允许列表可能变得难以管理。另一种方法是采用允许列表,将查询限制为仅允许安全操作(例如 SELECT
、JOIN
)。这种方法通过缩小允许的操作范围,而不是试图阻止所有有风险的命令,从而确保更健壮的解决方案。
要实现企业级解决方案,项目可以利用 Unity Catalog 等框架,这些框架为管理安全相关功能(例如 AI 工作流的 pre_safety_check
)提供了集中且可靠的方法。通过在此类框架内注册和管理可重用函数,您可以在所有 AI 工作流中强制执行一致且可靠的行为,从而增强安全性和可扩展性。
此外,该节点利用 LLM 分析输入是否包含冒犯性或不当内容。如果检测到不安全查询或不当内容,状态将更新错误标志并提供透明反馈,从而在早期保护工作流免受恶意或破坏性元素的侵害。
- 示例
-
禁止的操作
-
Input: "DROP TABLE customers;"
-
Response: "您的查询包含不允许的 SQL 操作,无法处理。"
-
Input: "SELECT * FROM orders;"
-
Response: "查询允许。"
-
-
不当内容
- Input: "Show me orders where customers have names like 'John the Idiot';"
- Response: "您的查询包含不当内容,无法处理。"
- Input: "Find total sales by region."
- Response: "输入安全,可以处理。"
- 代码
def pre_safety_check(state: GraphState) -> GraphState:
"""
Perform safety checks on the user input to ensure that no dangerous SQL operations
or inappropriate content is present. The function checks for SQL operations like
DELETE, DROP, and others, and also evaluates the input for toxic or unsafe content.
Args:
state (GraphState): The current graph state containing the translated user input.
Returns:
GraphState: The updated state with error status and messages if any issues are found.
"""
_logger.info("Performing safety check.")
translated_input = state["translated_input"]
messages = state["messages"]
error = "no"
# List of disallowed SQL operations (e.g., DELETE, DROP)
disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)
# Check if the input contains disallowed SQL operations
if pattern.search(translated_input):
_logger.warning("Input contains disallowed SQL operations. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your query contains disallowed SQL operations and cannot be processed.")]
else:
# Check if the input contains inappropriate content
safety_prompt = f"""
Analyze the following input for any toxic or inappropriate content.
Respond with only "safe" or "unsafe", and nothing else.
Input:
{translated_input}
"""
safety_invoke = llm.invoke(safety_prompt)
safety_response = safety_invoke.content.strip().lower() # Access the 'content' attribute and strip any extra spaces
if safety_response == "safe":
_logger.info("Input is safe to process.")
else:
_logger.warning("Input contains inappropriate content. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your query contains inappropriate content and cannot be processed.")]
# Update state with error status and messages
state["error"] = error
state["messages"] = messages
return state
3. 模式提取
schema_extract
节点通过查询元数据动态检索数据库模式,包括表格名称和列详细信息。格式化后的模式存储在状态中,从而能够验证用户查询,同时适应当前的数据库结构。
- 示例
- Input: 请求模式提取。
模式输出- Customers(CustomerID (INTEGER), CustomerName (TEXT), ContactName (TEXT), Address (TEXT), City (TEXT), PostalCode (TEXT), Country (TEXT))
- Orders(OrderID (INTEGER), CustomerID (INTEGER), OrderDate (TEXT))
- OrderDetails(OrderDetailID (INTEGER), OrderID (INTEGER), ProductID (INTEGER), Quantity (INTEGER))
- Products(ProductID (INTEGER), ProductName (TEXT), Price (REAL))
- Input: 请求模式提取。
- 代码
def schema_extract(state: GraphState) -> GraphState:
"""
Extracts the database schema, including all tables and their respective columns,
from the connected SQLite database. This function retrieves the list of tables and
iterates through each table to gather column definitions (name and data type).
Args:
state (GraphState): The current graph state, which will be updated with the database schema.
Returns:
GraphState: The updated state with the extracted database schema.
"""
_logger.info("Extracting database schema.")
# Extract the schema from the database
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
schema_details = []
# Loop through each table and retrieve column information
for table_name_tuple in tables:
table_name = table_name_tuple[0]
cursor.execute(f"PRAGMA table_info({table_name});")
columns = cursor.fetchall()
# Format column definitions
column_defs = ', '.join([f"{col[1]} ({col[2]})" for col in columns])
schema_details.append(f"- {table_name}({column_defs})")
# Save the schema in the state
database_schema = '\n'.join(schema_details)
state["database_schema"] = database_schema
_logger.info(f"Database schema extracted:\n{database_schema}")
return state
4. 上下文检查
context_check
节点通过将用户查询与提取的数据库模式进行比较来验证用户查询,以确保一致性和相关性。与模式不对应的查询会被标记为不相关,从而防止资源浪费,并允许用户获得反馈以重新表述查询。
- 示例
- Input: "What is the average order value?"
Schema Match: 输入与数据库模式相关。 - Input: "Show me data from the inventory table."
Response: "您的问题与数据库无关,无法处理。"
- Input: "What is the average order value?"
- 代码
def context_check(state: GraphState) -> GraphState:
"""
Checks whether the user's input is relevant to the database schema by comparing
the user's question with the database schema. Uses a language model to determine if
the question can be answered using the provided schema.
Args:
state (GraphState): The current graph state, which contains the translated input
and the database schema.
Returns:
GraphState: The updated state with error status and messages if the input is irrelevant.
"""
_logger.info("Performing context check.")
# Extract relevant data from the state
translated_input = state["translated_input"]
messages = state["messages"]
error = "no"
database_schema = state["database_schema"] # Get the schema from the state
# Use the LLM to determine if the input is relevant to the database schema
context_prompt = f"""
Determine whether the following user input is a question that can be answered using the database schema provided below.
Respond with only "relevant" if the input is relevant to the database schema, or "irrelevant" if it is not.
User Input:
{translated_input}
Database Schema:
{database_schema}
"""
# Call the LLM for context check
llm_invoke = llm.invoke(context_prompt)
llm_response = llm_invoke.content.strip().lower() # Access the 'content' attribute and strip any extra spaces and lower case
# Process the response from the LLM
if llm_response == "relevant":
_logger.info("Input is relevant to the database schema.")
else:
_logger.info("Input is not relevant. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your question is not related to the database and cannot be processed.")]
# Update the state with error and messages
state["error"] = error
state["messages"] = messages
return state
5. 生成
generate
节点通过从向量存储中检索相关文档并利用预定义的 SQL 生成链,从自然语言输入构建 SQL 查询。它将查询与用户的意图和模式上下文对齐,并使用生成的 SQL 及其描述更新状态。
- 示例
- Input: "Find total sales."
Generated SQL: "SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;" - Input: "List all customers in New York."
Generated SQL: "SELECT name FROM customers WHERE location = 'New York';"
- Input: "Find total sales."
- 代码
def generate(state: GraphState) -> GraphState:
"""
Generates an SQL query based on the user's input. The node retrieves relevant documents from
the vector store and uses a generation chain to produce an SQL query.
Args:
state (GraphState): The current graph state, which contains the translated input and
other relevant data such as messages and iteration count.
Returns:
GraphState: The updated state with the generated SQL query and related messages.
"""
_logger.info("Generating SQL query.")
# Extract relevant data from the state
messages = state["messages"]
iterations = state["iterations"]
translated_input = state["translated_input"]
database_schema = state["database_schema"]
# Retrieve relevant documents from the vector store based on the translated user input
docs = vector_store.similarity_search(translated_input, k=4)
retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
# Generate the SQL query using the SQL generation chain
sql_solution = sql_gen_chain.invoke(
{
"retrieved_docs": retrieved_docs,
"database_schema": database_schema,
"messages": [("user", translated_input)],
}
)
# Save the generated SQL query in the state
messages += [
(
"assistant",
f"{sql_solution.description}\nSQL Query:\n{sql_solution.sql_code}",
)
]
iterations += 1
# Log the generated SQL query
_logger.info("Generated SQL query:\n%s", sql_solution.sql_code)
# Update the state with the generated SQL query and updated message list
state["generation"] = sql_solution
state["messages"] = messages
state["iterations"] = iterations
return state
6. 后安全检查
post_safety_check
节点通过对生成的 SQL 查询进行最终验证以检查有害 SQL 命令来确保其安全。虽然早期的预安全检查会识别用户输入中不允许的操作,但此后安全检查会验证生成后产生的 SQL 查询是否符合安全准则。这种两步方法确保即使在查询生成过程中无意中引入了不允许的操作,也可以被捕获并标记出来。如果检测到不安全的查询,节点会停止工作流,更新状态中的错误标志,并向用户提供反馈。
- 示例
- 禁止的操作
- Generated Query: "DROP TABLE orders;"
- Response: "生成的 SQL 查询包含不允许的 SQL 操作:DROP,无法处理。"
- Generated Query: "SELECT name FROM customers;"
- Response: "查询有效。"
- 代码
def post_safety_check(state: GraphState) -> GraphState:
"""
Perform safety checks on the generated SQL query to ensure that it doesn't contain disallowed operations
such as CREATE, DELETE, DROP, etc. This node checks the SQL query generated earlier in the workflow.
Args:
state (GraphState): The current graph state containing the generated SQL query.
Returns:
GraphState: The updated state with error status and messages if any issues are found.
"""
_logger.info("Performing post-safety check on the generated SQL query.")
# Retrieve the generated SQL query from the state
sql_solution = state.get("generation", {})
sql_query = sql_solution.get("sql_code", "").strip()
messages = state["messages"]
error = "no"
# List of disallowed SQL operations
disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)
# Check if the generated SQL query contains disallowed SQL operations
found_operations = pattern.findall(sql_query)
if found_operations:
_logger.warning(
"Generated SQL query contains disallowed SQL operations: %s. Halting the workflow.",
", ".join(set(found_operations))
)
error = "yes"
messages += [("assistant", f"The generated SQL query contains disallowed SQL operations: {', '.join(set(found_operations))} and cannot be processed.")]
else:
_logger.info("Generated SQL query passed the safety check.")
# Update state with error status and messages
state["error"] = error
state["messages"] = messages
return state
7. SQL 检查
sql_check
节点通过在事务保存点内执行生成的 SQL 查询来确保其安全和语法有效。验证后会回滚所有更改,并标记错误并提供详细反馈,以保持查询完整性。
- 示例
- Input SQL: "SELECT name FROM customers WHERE city = 'New York';"
Validation: 查询有效。 - Input SQL: "SELECT MONTH(date) AS month, SUM(total) AS total_sales FROM orders GROUP BY MONTH(date);"
Response: "您的 SQL 查询执行失败:没有此函数:MONTH。"
- Input SQL: "SELECT name FROM customers WHERE city = 'New York';"
- 代码
def sql_check(state: GraphState) -> GraphState:
"""
Validates the generated SQL query by attempting to execute it on the database.
If the query is valid, the changes are rolled back to ensure no data is modified.
If there is an error during execution, the error is logged and the state is updated accordingly.
Args:
state (GraphState): The current graph state, which contains the generated SQL query
and the messages to communicate with the user.
Returns:
GraphState: The updated state with error status and messages if the query is invalid.
"""
_logger.info("Validating SQL query.")
# Extract relevant data from the state
messages = state["messages"]
sql_solution = state["generation"]
error = "no"
sql_code = sql_solution.sql_code.strip()
try:
# Start a savepoint for the transaction to allow rollback
conn.execute('SAVEPOINT sql_check;')
# Attempt to execute the SQL query
cursor.execute(sql_code)
# Roll back to the savepoint to undo any changes
conn.execute('ROLLBACK TO sql_check;')
_logger.info("SQL query validation: success.")
except Exception as e:
# Roll back in case of error
conn.execute('ROLLBACK TO sql_check;')
_logger.error("SQL query validation failed. Error: %s", e)
messages += [("user", f"Your SQL query failed to execute: {e}")]
error = "yes"
# Update the state with the error status
state["error"] = error
state["messages"] = messages
return state
8. 运行查询
run_query
节点执行已验证的 SQL 查询,连接到数据库检索结果。它使用查询输出更新状态,确保数据已格式化以便进行进一步分析或报告,同时实现了强大的错误处理。
- 示例
- Input SQL: "SELECT COUNT(*) FROM Customers WHERE City = 'New York';"
Query Result: "(0,)" - Input SQL: "SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;"
Query Result: "(6925.0,)"
- Input SQL: "SELECT COUNT(*) FROM Customers WHERE City = 'New York';"
- 代码
def run_query(state: GraphState) -> GraphState:
"""
Executes the generated SQL query on the database and retrieves the results if it is a SELECT query.
For non-SELECT queries, commits the changes to the database. If no records are found for a SELECT query,
the `no_records_found` flag is set to True.
Args:
state (GraphState): The current graph state, which contains the generated SQL query and other relevant data.
Returns:
GraphState: The updated state with the query results, or a flag indicating if no records were found.
"""
_logger.info("Running SQL query.")
# Extract the SQL query from the state
sql_solution = state["generation"]
sql_code = sql_solution.sql_code.strip()
results = None
no_records_found = False # Flag to indicate no records found
try:
# Execute the SQL query
cursor.execute(sql_code)
# For SELECT queries, fetch and store results
if sql_code.upper().startswith("SELECT"):
results = cursor.fetchall()
if not results:
no_records_found = True
_logger.info("SQL query execution: success. No records found.")
else:
_logger.info("SQL query execution: success.")
else:
# For non-SELECT queries, commit the changes
conn.commit()
_logger.info("SQL query execution: success. Changes committed.")
except Exception as e:
_logger.error("SQL query execution failed. Error: %s", e)
# Update the state with results and flag for no records found
state["results"] = results
state["no_records_found"] = no_records_found
return state
决策步骤:确定下一步行动
decide_next_step
函数作为工作流中的控制点,根据当前状态决定下一步应采取什么行动。它评估 error
状态和已执行的迭代次数,以确定是应该运行查询、结束工作流,还是系统应该重试生成 SQL 查询。
-
过程
- 如果没有错误(
error == "no"
),系统将继续运行 SQL 查询。 - 如果已达到最大迭代次数(
max_iterations
),工作流将结束。 - 如果发生了错误且尚未达到最大迭代次数,系统将重试查询生成。
- 如果没有错误(
-
示例工作流决策
- 无错误,继续查询 (No Error, Proceed with Query):如果在先前的步骤中没有发现错误,工作流将继续运行查询。
- 达到最大迭代次数,结束工作流 (Maximum Iterations Reached, End Workflow):如果工作流已尝试了设定的次数(
max_iterations
),它将终止。 - 检测到错误,重试 SQL 生成 (Error Detected, Retry SQL Generation):如果发生错误且系统尚未达到重试限制,它将尝试重新生成 SQL 查询。
-
代码
def decide_next_step(state: GraphState) -> str:
"""
Determines the next step in the workflow based on the current state, including whether the query
should be run, the workflow should be finished, or if the query generation needs to be retried.
Args:
state (GraphState): The current graph state, which contains error status and iteration count.
Returns:
str: The next step in the workflow, which can be "run_query", "generate", or END.
"""
_logger.info("Deciding next step based on current state.")
error = state["error"]
iterations = state["iterations"]
if error == "no":
_logger.info("Error status: no. Proceeding with running the query.")
return "run_query"
elif iterations >= max_iterations:
_logger.info("Maximum iterations reached. Ending the workflow.")
return END
else:
_logger.info("Error detected. Retrying SQL query generation.")
return "generate"
工作流编排和条件逻辑
为了定义系统中任务的编排,我们使用 StateGraph
类构建工作流图。每个任务表示为一个节点,任务之间的转换定义为边缘。条件边缘用于根据工作流的状态控制流程。
def get_workflow(conn, cursor, vector_store):
"""Define and compile the LangGraph workflow."""
# Max iterations: defines how many times the workflow should retry in case of errors
max_iterations = 3
# SQL generation chain: this is a chain that will generate SQL based on retrieved docs
sql_gen_chain = get_sql_gen_chain()
# Initialize OpenAI LLM for translation and safety checks
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Define the individual nodes of the workflow
... # Insert nodes code defined above here
# Build the workflow graph
workflow = StateGraph(GraphState)
# Define workflow nodes
workflow.add_node("translate_input", translate_input) # Translate user input to structured format
workflow.add_node("pre_safety_check", pre_safety_check) # Perform a pre-safety check on input
workflow.add_node("schema_extract", schema_extract) # Extract the database schema
workflow.add_node("context_check", context_check) # Validate input relevance to context
workflow.add_node("generate", generate) # Generate SQL query
workflow.add_node("post_safety_check", post_safety_check) # Perform a post-safety check on generated SQL query
workflow.add_node("sql_check", sql_check) # Validate the generated SQL query
workflow.add_node("run_query", run_query) # Execute the SQL query
# Define workflow edges
workflow.add_edge(START, "translate_input") # Start at the translation step
workflow.add_edge("translate_input", "pre_safety_check") # Move to safety checks
# Conditional edge after safety check
workflow.add_conditional_edges(
"pre_safety_check", # Start at the pre_safety_check node
lambda state: "schema_extract" if state["error"] == "no" else END, # Decide next step
{"schema_extract": "schema_extract", END: END} # Map states to nodes
)
workflow.add_edge("schema_extract", "context_check") # Proceed to context validation
# Conditional edge after context check
workflow.add_conditional_edges(
"context_check", # Start at the context_check node
lambda state: "generate" if state["error"] == "no" else END, # Decide next step
{"generate": "generate", END: END}
)
workflow.add_edge("generate", "post_safety_check") # Proceed to post-safety check
# Conditional edge after post-safety check
workflow.add_conditional_edges(
"post_safety_check", # Start at the post_safety_check node
lambda state: "sql_check" if state["error"] == "no" else END, # If no error, proceed to sql_check, else END
{"sql_check": "sql_check", END: END},
)
# Conditional edge after SQL validation
workflow.add_conditional_edges(
"sql_check", # Start at the sql_check node
decide_next_step, # Function to determine the next step
{
"run_query": "run_query", # If SQL is valid, execute the query
"generate": "generate", # If retry is needed, go back to generation
END: END # Otherwise, terminate the workflow
}
)
workflow.add_edge("run_query", END) # Final step is to end the workflow
# Compile and return the workflow application
app = workflow.compile()
return app
- Start to
translate_input
:
工作流首先将用户输入翻译成结构化格式。 translate_input
到pre_safety_check
:
翻译后,工作流继续检查输入的安全性。pre_safety_check
条件规则:- 如果输入通过预安全检查(
state["error"] == "no"
),工作流将转移到schema_extract
。 - 如果输入未能通过预安全检查,工作流将终止(
END
)。
- 如果输入通过预安全检查(
schema_extract
到context_check
:
提取模式后,工作流将验证输入与数据库上下文的相关性。context_check
条件规则:- 如果输入相关(
state["error"] == "no"
),工作流将转移到generate
。 - 否则,工作流将终止(
END
)。
- 如果输入相关(
generate
到post_safety_check
:
工作流生成 SQL 查询并将其发送进行验证。\post_safety_check
条件规则:- 如果输入通过后安全检查(state["error"] == "no"
),工作流将转移到sql_check
。- 如果输入未能通过后安全检查,工作流将终止(END
)。sql_check
条件规则:- 如果查询有效,工作流将继续到
run_query
。 - 如果查询需要调整且迭代次数少于 3 次,工作流将循环回
generate
。 - 如果验证失败且迭代次数多于 3 次,工作流将终止(
END
)。
- 如果查询有效,工作流将继续到
run_query
到END
:
查询执行后,工作流结束。
上面的模式表示了 LangGraph 节点和边缘

在 MLflow 中记录模型
现在我们已经使用 LangGraph 构建了一个多语言查询引擎,我们准备使用 MLflow 的代码模型 (Models from Code) 来记录模型。将模型记录到 MLflow 中,可以我们将多语言查询引擎视为传统的 ML 模型,实现无缝跟踪、版本控制和打包,以便部署到不同的服务基础设施。MLflow 的代码模型策略,即记录代表模型的代码,与 MLflow 基于对象的记录方式形成对比,后者会创建一个模型对象,将其序列化并记录为 pickle 或 JSON 对象。
步骤 1:创建代码模型文件
到目前为止,我们已经在名为 workflow.py
的文件中定义了 get_workflow
函数。在此步骤中,我们将创建一个新文件 sql_model.py
,其中引入了 SQLGenerator
类。此脚本将:
- 从
workflow.py
导入get_workflow
函数。 - 将
SQLGenerator
类定义为PythonModel
,包括一个利用get_workflow
函数初始化 LangGraph 工作流的 predict 方法。 - 使用
mlflow.models.set_model
将SQLGenerator PythonModel
类指定为 MLflow 感兴趣的模型。
import mlflow
from definitions import REMOTE_SERVER_URI
from workflow import get_workflow
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
class SQLGenerator(mlflow.pyfunc.PythonModel):
def predict(self, context, model_input):
return get_workflow(
model_input["conn"], model_input["cursor"], model_input["vector_store"]
)
mlflow.models.set_model(SQLGenerator())
步骤 2:使用 MLflow 代码模型特性进行记录
定义了 SQLGenerator 自定义 Python 模型后,下一步是使用 MLflow 的代码模型特性记录它。这涉及使用 log model 标准 API,指定 sql_model.py 脚本的路径,并使用 code_paths 参数将 workflow.py 作为依赖项包含进来。这种方法确保在不同的环境或另一台机器上加载模型时,所有必需的代码文件都可用。
import mlflow
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from mlflow import MlflowClient
client = MlflowClient(tracking_uri=REMOTE_SERVER_URI)
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
with mlflow.start_run():
logged_model_info = mlflow.pyfunc.log_model(
python_model="sql_model.py",
artifact_path="sql_generator",
registered_model_name=REGISTERED_MODEL_NAME,
code_paths=["workflow.py"],
)
client.set_registered_model_alias(
REGISTERED_MODEL_NAME, MODEL_ALIAS, logged_model_info.registered_model_version
)
在 MLflow UI 中,存储的模型包括 sql_model.py
和 workflow.py
脚本作为运行中的 artifacts。这种从代码记录的特性不仅记录了模型的参数和指标,还捕获了定义其功能的代码。这确保了可观察性、无缝跟踪和直接通过 UI 进行的简单调试。然而,至关重要的是要确保敏感元素,例如 API 密钥或凭据,绝不硬编码到这些脚本中。由于代码按原样存储,任何包含的敏感信息都可能导致令牌泄露并带来安全风险。相反,敏感数据应使用环境变量、秘密管理系统或其他安全方法进行安全管理。
在 main.py
中使用已记录的多语言查询引擎
记录模型后,可以使用模型 URI 和标准 mlflow.pyfunc.load_model
API 从 MLflow 跟踪服务器重新加载它。加载模型时,workflow.py
脚本将与 sql_model.py
脚本一起执行,确保在调用 predict
方法时 get_workflow
函数可用。
执行以下代码,我们演示了我们的多语言查询引擎能够执行自然语言到 SQL 的生成和查询执行。
import os
import logging
import mlflow
from database import setup_database
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from dotenv import load_dotenv
from vector_store import setup_vector_store
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.langchain.autolog()
# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)
def main():
# Load environment variables from .env file
load_dotenv()
# Access secrets using os.getenv
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
# Setup database and vector store
conn = setup_database()
cursor = conn.cursor()
vector_store = setup_vector_store()
# Load the model
model_uri = f"models:/{REGISTERED_MODEL_NAME}@{MODEL_ALIAS}"
model = mlflow.pyfunc.load_model(model_uri)
model_input = {"conn": conn, "cursor": cursor, "vector_store": vector_store}
app = model.predict(model_input)
# save image
app.get_graph().draw_mermaid_png(
output_file_path="sql_agent_with_safety_checks.png"
)
# Example user interaction
_logger.info("Welcome to the SQL Assistant!")
while True:
question = input("\nEnter your SQL question (or type 'exit' to quit): ")
if question.lower() == "exit":
break
# Initialize the state with all required keys
initial_state = {
"messages": [("user", question)],
"iterations": 0,
"error": "",
"results": None,
"generation": None,
"no_records_found": False,
"translated_input": "", # Initialize translated_input
}
solution = app.invoke(initial_state)
# Check if an error was set during the safety check
if solution["error"] == "yes":
_logger.info("\nAssistant Message:\n")
_logger.info(solution["messages"][-1][1]) # Display the assistant's message
continue # Skip to the next iteration
# Extract the generated SQL query from solution["generation"]
sql_query = solution["generation"].sql_code
_logger.info("\nGenerated SQL Query:\n")
_logger.info(sql_query)
# Extract and display the query results
if solution.get("no_records_found"):
_logger.info("\nNo records found matching your query.")
elif "results" in solution and solution["results"] is not None:
_logger.info("\nQuery Results:\n")
for row in solution["results"]:
_logger.info(row)
else:
_logger.info("\nNo results returned or query did not execute successfully.")
_logger.info("Goodbye!")
if __name__ == "__main__":
main()
项目文件结构
项目遵循简单的文件结构:

MLflow 跟踪
MLflow 自动化跟踪 (Automated Tracing) 为各种 GenAI 库(如 LangChain、OpenAI、LlamaIndex、DSPy 和 AutoGen)提供了完全自动化的集成。由于我们的 AI 工作流是使用 LangGraph 构建的,我们可以通过启用 mlflow.langchain.autolog()
来激活 LangChain 自动化跟踪。
启用 LangChain 自动化日志记录后,每当在链上调用 invocation API 时,跟踪信息会自动记录到当前的 MLflow 实验中。这种无缝集成确保每次交互都被捕获,以便进行监控和分析。
在 MLflow 中查看跟踪
通过导航到感兴趣的 MLflow 实验并单击“Tracing”(跟踪)选项卡,可以轻松访问跟踪信息。进入后,选择特定跟踪可以查看详细的执行信息。
每个跟踪都包含:
- 执行图 (Execution Graphs):工作流步骤的可视化。
- 输入和输出 (Inputs and Outputs):每个步骤处理的数据的详细日志。
通过利用 MLflow 跟踪,我们可以对整个图执行获得细粒度的可见性。AI 工作流图通常感觉像一个黑匣子,难以调试和理解每个步骤中发生的情况。然而,只需一行代码启用跟踪,MLflow 就能提供清晰详细的工作流洞察,使开发人员能够有效调试、监控和优化图中的每个节点,确保我们的多语言查询引擎保持透明、可审计和可扩展。
结论
在整篇文章中,我们探讨了使用 LangGraph 和 MLflow 构建和管理多语言查询引擎的过程。通过将 LangGraph 的动态 AI 工作流与 MLflow 强大的生命周期管理和跟踪特性相结合,我们创建了一个系统,它不仅提供准确高效的自然语言到 SQL 生成和执行功能,而且还具有透明性和可扩展性。