从自然语言到 SQL:构建和追踪多语言查询引擎
如果您希望构建一个结合自然语言到 SQL 生成和查询执行的多语言查询引擎,并充分利用 MLflow 的功能,那么本博客文章是您的指南。我们将探讨如何利用 **MLflow 代码中的模型 (MLflow Models from Code)** 来实现人工智能工作流的无缝跟踪和版本控制。此外,我们将深入研究 **MLflow 的跟踪 (Tracing)** 功能,该功能通过跟踪人工智能工作流中每个中间步骤的输入、输出和元数据,为这些组件带来可观测性。
引言
SQL 是管理和访问关系数据库中数据的基本技能。然而,构建复杂的 SQL 查询以回答棘手的数据问题可能既具有挑战性又耗时。这种复杂性使得有效利用数据变得困难。自然语言到 SQL (NL2SQL) 系统通过提供从自然语言到 SQL 命令的翻译来帮助解决这个问题,使非技术人员能够与数据交互:用户只需使用他们熟悉的自然语言提问,这些系统就会帮助他们返回所需的信息。
然而,在创建 NL2SQL 系统时仍然存在一些问题,例如语义模糊、模式映射或错误处理和用户反馈。因此,在构建此类系统时,我们必须设置一些护栏,而不是完全依赖 LLM。
在本博客文章中,我们将引导您完成构建多语言查询引擎的过程。该引擎支持多种语言的自然语言输入,根据翻译后的用户输入生成 SQL 查询,并执行该查询。让我们来看一个例子:使用包含有关公司客户、产品和订单信息的数据库,用户可以用任何语言提问,例如“Quantos clientes temos por país?”(葡萄牙语,意为“我们有多少客户按国家/地区划分?”)。AI 工作流将输入翻译成英文,输出“How many customers do we have per country?”。然后,它会验证输入是否安全,检查问题是否可以使用数据库模式来回答,生成适当的 SQL 查询(例如,`SELECT COUNT(CustomerID) AS NumberOfCustomers, Country FROM Customers GROUP BY Country;`),并验证查询以确保不存在有害命令(例如,DROP)。最后,它针对数据库执行查询以检索结果。
我们将首先演示如何利用 LangGraph 的功能来构建动态 AI 工作流。此工作流集成了 OpenAI 和外部数据源,例如向量存储 (Vector Store) 和 SQLite 数据库,以处理用户输入、执行安全检查、查询数据库并生成有意义的响应。
在整篇博客中,我们将利用 **MLflow 的代码中的模型 (MLflow’s Models from Code)** 功能来实现人工智能工作流的无缝跟踪和版本控制。此外,我们将深入研究 **MLflow 的跟踪 (Tracing)** 功能,该功能旨在通过跟踪与每个中间步骤相关的输入、输出和元数据,来增强人工智能工作流中各个组件的可观测性。这使得更容易识别错误和意外行为,从而提高工作流的透明度。
先决条件
要设置和运行此项目,请确保安装了以下 **Python 包**
faiss-cpulangchainlangchain-corelangchain-openailanggraphlangchain-communitypydantic >=2typing_extensionspython-dotenv
此外,需要一个 **MLflow 跟踪服务器 (MLflow Tracking Server)** 来有效地记录和管理实验、模型和跟踪。对于本地设置,请参阅官方 MLflow 文档,了解有关 配置简单的 MLflow 跟踪服务器的说明。
最后,请确保您的 OpenAI API 密钥保存在项目目录中的 .env 文件中。这使得应用程序能够安全地访问构建人工智能工作流所需的 OpenAI 服务。 .env 文件应包含如下一行
OPENAI_API_KEY=your_openai_api_key
使用 LangGraph 的多语言查询引擎
多语言查询引擎利用 LangGraph 库,这是一个 AI 编排工具,旨在为由 LLM 驱动的应用程序创建有状态、多代理和循环图架构。
与其他 AI 编排器相比,LangGraph 提供了三个核心优势:循环、可控性和持久性。它允许定义带有循环的人工智能工作流,这对于在多语言查询引擎中实现重试机制(例如 SQL 查询生成重试,如果验证失败,查询会循环返回重新生成)至关重要。这使得 LangGraph 成为构建多语言查询引擎的理想工具。
关键 LangGraph 特性:
-
**有状态架构 (Stateful Architecture)**:引擎维护图执行状态的动态快照。此快照充当节点之间的共享资源,在每次节点执行时实现高效决策和实时更新。
-
**多代理设计 (Multi-Agent Design)**:人工智能工作流在工作流中包含与 OpenAI 和其他外部工具的多次交互。
-
**循环图结构 (Cyclical Graph Structure)**:图的循环特性引入了强大的重试机制。当需要时,此机制通过循环回到先前的阶段来动态解决故障,确保图的持续执行。(此机制的详细信息将在后面讨论。)
人工智能工作流概述
多语言查询引擎的高级人工智能工作流由相互连接的节点和边组成,每个都代表一个关键阶段
-
**翻译节点 (Translation Node)**:将用户输入转换为英文。
-
**预安全检查 (Pre-safety Check)**:确保用户输入不含不良或不当内容,并且不包含有害的 SQL 命令(例如,`DELETE`、`DROP`)。
-
**数据库模式提取 (Database Schema Extraction)**:检索目标数据库的模式以了解其结构和可用数据。
-
**相关性验证 (Relevancy Validation)**:根据数据库模式验证用户输入,确保与数据库上下文保持一致。
-
**SQL 查询生成 (SQL Query Generation)**:根据用户输入和当前数据库模式生成 SQL 查询。
-
**后安全检查 (Post-safety Check)**:确保生成的 SQL 查询不包含有害的 SQL 命令(例如,`DELETE`、`DROP`)。
-
**SQL 查询验证 (SQL Query Validation)**:在回滚安全的环境中执行 SQL 查询,以确保其在运行前有效性。
-
**动态状态评估 (Dynamic State Evaluation)**:根据当前状态确定后续步骤。如果 SQL 查询验证失败,它会循环回到第 5 步重新生成查询。
-
**查询执行和结果检索 (Query Execution and Result Retrieval)**:执行 SQL 查询并返回结果(如果是 `SELECT` 语句)。
重试机制在第 8 步中引入,系统在此处动态评估当前的图状态。具体来说,当 SQL 查询验证节点(第 7 步)检测到问题时,状态会触发循环返回到 SQL 生成节点(第 5 步)以进行新的 SQL 生成尝试(最多 3 次尝试)。
组件
多语言查询引擎与多个外部组件交互,以将自然语言用户输入转换为 SQL 查询并在安全稳健的方式下执行它们。在本节中,我们将详细了解关键的人工智能工作流组件:OpenAI、向量存储 (Vector Store)、SQLite 数据库和 SQL 生成链 (SQL Generation Chain)。
OpenAI
OpenAI,更具体地说是 `gpt-4o-mini` 语言模型,在工作流的多个阶段起着至关重要的作用。它提供了所需的情报来完成
-
**翻译**:将用户输入翻译成英文。如果文本已经是英文,它会简单地重复输入。
-
**安全检查**:分析用户输入,确保其中不包含不良或不当内容。
-
**相关性检查**:评估用户的问题相对于数据库模式是否相关。
-
**SQL 生成**:根据用户输入、SQL 生成文档和数据库模式生成有效且可执行的 SQL 查询。
关于 OpenAI 实现的详细信息将在后面的 **节点描述 (Node Descriptions)** 部分提供。
FAISS 向量存储 (FAISS Vector Store)
为了构建一个有效的自然语言到 SQL 引擎,使其能够生成准确且可执行的 SQL 查询,我们利用 Langchain 的 FAISS 向量存储功能。这种设置允许系统从先前存储在向量数据库中的 **W3Schools SQL 文档**中搜索和提取 SQL 查询生成指南,从而提高 SQL 查询生成的成功率。
为演示目的,我们使用的是 FAISS,它是一个内存中的向量存储,向量直接存储在 RAM 中。这提供了快速访问,但意味着数据在运行之间不会持久化。对于更具可扩展性的解决方案,该方案可以实现跨多个项目存储和共享嵌入,我们推荐使用替代方案,例如 AWS OpenSearch、Vertex AI 向量搜索、Azure 向量搜索 或 Mosaic AI 向量搜索。这些基于云的解决方案提供持久化存储、自动扩展以及与其他云服务的无缝集成,非常适合大规模应用。
步骤 1:加载 SQL 文档
创建带有 SQL 查询生成指南的 FAISS 向量存储的第一步是使用 LangChain 的 `RecursiveUrlLoader` 从 **W3Schools SQL 页面**加载 SQL 文档。此工具会检索文档,允许我们将其用作引擎的知识库。
步骤 2:将文本分割成可管理的块
加载的 SQL 文档是冗长的文本,难以被 LLM 有效地摄取。为解决此问题,下一步是使用 Langchain 的 `RecursiveCharacterTextSplitter` 将文本分割成更小、可管理的块。通过将文本分割成 500 个字符的块并具有 50 个字符的重叠,我们确保语言模型具有足够的上下文,同时最大限度地减少跨块丢失重要信息的风险。`split_text` 方法应用此分割过程,将结果片段存储在一个名为 'documents' 的列表中。
步骤 3:生成嵌入模型
第三步是创建一个模型,将这些块转换为嵌入(每个文本块的矢量化数值表示)。嵌入使系统能够比较块与用户输入之间的相似性,从而有助于检索最相关的匹配项以进行 SQL 查询生成。
步骤 4:在 FAISS 向量存储中创建和存储嵌入
最后,我们使用 FAISS 创建并存储嵌入。`FAISS.from_texts` 方法接收所有块,计算它们的嵌入,并将它们存储在一个高速可搜索的向量数据库中。这个可搜索的数据库允许引擎有效地检索相关的 SQL 指南,显著提高了可执行 SQL 查询生成的成功率。
import logging
import os
from bs4 import BeautifulSoup as Soup
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
def setup_vector_store(logger: logging.Logger):
"""Setup or load the vector store."""
if not os.path.exists("data"):
os.makedirs("data")
vector_store_dir = "data/vector_store"
if os.path.exists(vector_store_dir):
# Load the vector store from disk
logger.info("Loading vector store from disk...")
vector_store = FAISS.load_local(
vector_store_dir,
OpenAIEmbeddings(),
allow_dangerous_deserialization=True,
)
else:
logger.info("Creating new vector store...")
# Load SQL documentation
url = "https://w3schools.org.cn/sql/"
loader = RecursiveUrlLoader(
url=url, max_depth=2, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()
# Split documents into chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
)
documents = []
for doc in docs:
splits = text_splitter.split_text(doc.page_content)
for i, split in enumerate(splits):
documents.append(
{
"content": split,
"metadata": {"source": doc.metadata["source"], "chunk": i},
}
)
# Compute embeddings and create vector store
embedding_model = OpenAIEmbeddings()
vector_store = FAISS.from_texts(
[doc["content"] for doc in documents],
embedding_model,
metadatas=[doc["metadata"] for doc in documents],
)
# Save the vector store to disk
vector_store.save_local(vector_store_dir)
logger.info("Vector store created and saved to disk.")
return vector_store
SQLite 数据库
SQLite 数据库是多语言查询引擎的关键组成部分,用作结构化数据存储库。SQLite 提供了一个轻量级、快速且自包含的关系数据库引擎,无需服务器设置或安装。其紧凑的大小(小于 500KB)和零配置特性使其易于使用,而其平台无关的数据库格式确保了跨不同系统的无缝可移植性。作为一个本地磁盘数据库,SQLite 是避免设置 MySQL 或 PostgreSQL 复杂性的理想选择,同时仍提供一个可靠的、功能齐全的 SQL 引擎,具有出色的性能。
SQLite 数据库通过以下方式支持高效的 SQL 查询生成、验证和执行
-
**模式提取**:为用户的输入上下文验证(第 4 步)和可执行的 SQL 查询生成(第 5 步)提供模式信息。
-
**查询执行**:在验证阶段(第 7 步)和查询执行阶段(第 9 步)在回滚安全的环境中执行 SQL 查询,为 `SELECT` 语句获取结果,并为其他查询类型提交更改。
SQLite 数据库初始化
当初始化 AI 工作流时,使用 `setup_database` 函数来初始化数据库。此过程包括
-
**设置 SQLite 数据库连接**:建立到 SQLite 数据库的连接,启用数据交互。
-
**表创建**:为 AI 工作流定义和创建必要的数据库表。
-
**数据填充**:用示例数据填充表,以支持查询执行和验证阶段。
import logging
import os
import sqlite3
def create_connection(db_file="data/database.db"):
"""Create a database connection to the SQLite database."""
conn = sqlite3.connect(db_file)
return conn
def create_tables(conn):
"""Create tables in the database."""
cursor = conn.cursor()
# Create Customers table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Customers (
CustomerID INTEGER PRIMARY KEY,
CustomerName TEXT,
ContactName TEXT,
Address TEXT,
City TEXT,
PostalCode TEXT,
Country TEXT
)
"""
)
# Create Orders table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Orders (
OrderID INTEGER PRIMARY KEY,
CustomerID INTEGER,
OrderDate TEXT,
FOREIGN KEY (CustomerID) REFERENCES Customers (CustomerID)
)
"""
)
# Create OrderDetails table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS OrderDetails (
OrderDetailID INTEGER PRIMARY KEY,
OrderID INTEGER,
ProductID INTEGER,
Quantity INTEGER,
FOREIGN KEY (OrderID) REFERENCES Orders (OrderID),
FOREIGN KEY (ProductID) REFERENCES Products (ProductID)
)
"""
)
# Create Products table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Products (
ProductID INTEGER PRIMARY KEY,
ProductName TEXT,
Price REAL
)
"""
)
conn.commit()
def populate_tables(conn):
"""Populate tables with sample data if they are empty."""
cursor = conn.cursor()
# Populate Customers table if empty
cursor.execute("SELECT COUNT(*) FROM Customers")
if cursor.fetchone()[0] == 0:
customers = []
for i in range(1, 51):
customers.append(
(
i,
f"Customer {i}",
f"Contact {i}",
f"Address {i}",
f"City {i % 10}",
f"{10000 + i}",
f"Country {i % 5}",
)
)
cursor.executemany(
"""
INSERT INTO Customers (CustomerID, CustomerName, ContactName, Address, City, PostalCode, Country)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
customers,
)
# Populate Products table if empty
cursor.execute("SELECT COUNT(*) FROM Products")
if cursor.fetchone()[0] == 0:
products = []
for i in range(1, 51):
products.append((i, f"Product {i}", round(10 + i * 0.5, 2)))
cursor.executemany(
"""
INSERT INTO Products (ProductID, ProductName, Price)
VALUES (?, ?, ?)
""",
products,
)
# Populate Orders table if empty
cursor.execute("SELECT COUNT(*) FROM Orders")
if cursor.fetchone()[0] == 0:
orders = []
from datetime import datetime, timedelta
base_date = datetime(2023, 1, 1)
for i in range(1, 51):
order_date = base_date + timedelta(days=i)
orders.append(
(
i,
i % 50 + 1, # CustomerID between 1 and 50
order_date.strftime("%Y-%m-%d"),
)
)
cursor.executemany(
"""
INSERT INTO Orders (OrderID, CustomerID, OrderDate)
VALUES (?, ?, ?)
""",
orders,
)
# Populate OrderDetails table if empty
cursor.execute("SELECT COUNT(*) FROM OrderDetails")
if cursor.fetchone()[0] == 0:
order_details = []
for i in range(1, 51):
order_details.append(
(
i,
i % 50 + 1, # OrderID between 1 and 50
i % 50 + 1, # ProductID between 1 and 50
(i % 5 + 1) * 2, # Quantity between 2 and 10
)
)
cursor.executemany(
"""
INSERT INTO OrderDetails (OrderDetailID, OrderID, ProductID, Quantity)
VALUES (?, ?, ?, ?)
""",
order_details,
)
conn.commit()
def setup_database(logger: logging.Logger):
"""Setup the database and return the connection."""
db_file = "data/database.db"
if not os.path.exists("data"):
os.makedirs("data")
db_exists = os.path.exists(db_file)
conn = create_connection(db_file)
if not db_exists:
logger.info("Setting up the database...")
create_tables(conn)
populate_tables(conn)
else:
logger.info("Database already exists. Skipping setup.")
return conn
SQL 生成链 (SQL Generation Chain)
**SQL 生成链** (`sql_gen_chain`) 是我们工作流中自动化 SQL 查询生成的骨干。此链利用 LangChain 的模块化功能和 OpenAI 的高级自然语言处理能力,将用户问题转化为精确且可执行的 SQL 查询。
核心特性:
-
**提示驱动生成 (Prompt-Driven Generation)**:从经过深思熟虑设计的提示开始,该提示集成了数据库模式和文档片段,确保查询在上下文上准确。
-
**结构化响应 (Structured Responses)**:以预定义格式提供输出,包括
-
查询目的的**描述**。
-
随时可执行的相应**SQL 代码**。
-
-
**适应性强且可靠**:使用 `gpt-4o-mini` 进行稳健、一致的查询生成,最大限度地减少手动工作和错误。
此链是我们工作流中的一个关键组件,它实现了 SQL 查询生成与下游流程的无缝集成,确保了准确性,并显著提高了效率。
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
class SQLQuery(BaseModel):
"""Schema for SQL query solutions to questions."""
description: str = Field(description="Description of the SQL query")
sql_code: str = Field(description="The SQL code block")
def get_sql_gen_chain():
"""Set up the SQL generation chain."""
sql_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL assistant with expertise in SQL query generation. \n
Answer the user's question based on the provided documentation snippets and the database schema provided below. Ensure any SQL query you provide is valid and executable. \n
Structure your answer with a description of the query, followed by the SQL code block. Here are the documentation snippets:\n{retrieved_docs}\n\nDatabase Schema:\n{database_schema}""",
),
("placeholder", "{messages}"),
]
)
# Initialize the OpenAI LLM
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Create the code generation chain
sql_gen_chain = sql_gen_prompt | llm.with_structured_output(SQLQuery)
return sql_gen_chain
工作流设置和初始化
在深入研究工作流节点之前,设置必要的组件并定义工作流的结构至关重要。本节将介绍基本库、日志记录和自定义 `GraphState` 类的初始化,以及主要的工作流编译函数。
定义 `GraphState`
`GraphState` 类是一个自定义的 `TypedDict`,用于在工作流进行时维护状态信息。它充当节点之间的共享数据结构,确保连续性和一致性。关键字段包括
- **`error`**:跟踪是否发生错误。
- **`messages`**:存储用户和系统消息的列表。
- **`generation`**:保存生成的 SQL 查询。
- **`iterations`**:在发生错误时跟踪重试尝试的次数。
- **`results`**:存储 SQL 执行结果(如果有)。
- **`no_records_found`**:如果在查询中未返回任何记录,则设置标志。
- **`translated_input`**:包含用户的翻译输入。
- **`database_schema`**:维护数据库模式以进行上下文验证。
import logging
import re
from typing import List, Optional
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from sql_generation import get_sql_gen_chain
from typing_extensions import TypedDict
# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)
class GraphState(TypedDict):
error: str # Tracks if an error has occurred
messages: List # List of messages (user input and assistant messages)
generation: Optional[str] # Holds the generated SQL query
iterations: int # Keeps track of how many times the workflow has retried
results: Optional[List] # Holds the results of SQL execution
no_records_found: bool # Flag for whether any records were found in the SQL result
translated_input: str # Holds the translated user input
database_schema: str # Holds the extracted database schema for context checking
工作流编译函数
主函数 `get_workflow` 负责定义和编译工作流。关键组件包括
- **`conn` 和 `cursor`**:用于数据库连接和查询执行。
- **`vector_store`**:用于上下文检索的向量数据库。
- **`max_iterations`**:设置重试尝试的限制,以防止无限循环。
- **`sql_gen_chain`**:从 `sql_generation` 检索 SQL 生成链,以根据上下文输入生成 SQL 查询。
- **`ChatOpenAI`**:初始化 OpenAI `gpt-4o-mini` 模型,用于安全检查和查询翻译等任务。
def get_workflow(conn, cursor, vector_store):
"""Define and compile the LangGraph workflow."""
# Max iterations: defines how many times the workflow should retry in case of errors
max_iterations = 3
# SQL generation chain: this is a chain that will generate SQL based on retrieved docs
sql_gen_chain = get_sql_gen_chain()
# Initialize OpenAI LLM for translation and safety checks
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Define the individual nodes of the workflow
此函数是使用 `StateGraph` 创建完整工作流的入口点。工作流中各个节点的定义和连接将在后续部分进行。
节点描述
1. 翻译输入
`translate_input` 节点将用户查询翻译成英文,以标准化处理并确保与下游节点的兼容性。将用户输入翻译为 AI 工作流的第一步,确保了任务的隔离,并提高了可观测性。任务隔离通过将翻译与其他下游任务(如用户输入安全验证和 SQL 生成)隔离开来,简化了工作流。提高的可观测性在 MLflow 中提供了清晰的跟踪记录,使调试和监控过程更容易。
- 示例
- 输入:“Quantos pedidos foram realizados em Novembro?”
- 翻译后:“How many orders were made in November?”
- 输入:“Combien de ventes avons-nous enregistrées en France ?”
- 翻译后:“How many sales did we record in France?”
- 代码
def translate_input(state: GraphState) -> GraphState:
"""
Translates user input to English using an LLM. If the input is already in English,
it is returned as is. This ensures consistent input for downstream processing.
Args:
state (GraphState): The current graph state containing user messages.
Returns:
GraphState: The updated state with the translated input.
"""
_logger.info("Starting translation of user input to English.")
messages = state["messages"]
user_input = messages[-1][1] # Get the latest user input
# Translation prompt for the model
translation_prompt = f"""
Translate the following text to English. If the text is already in English, repeat it exactly without any additional explanation.
Text:
{user_input}
"""
# Call the OpenAI LLM to translate the text
translated_response = llm.invoke(translation_prompt)
translated_text = translated_response.content.strip() # Access the 'content' attribute and strip any extra spaces
# Update state with the translated input
state["translated_input"] = translated_text
_logger.info("Translation completed successfully. Translated input: %s", translated_text)
return state
2. 预安全检查 (Pre-safety Check)
`pre_safety_check` 节点确保及早检测到用户输入中不允许的 SQL 操作和不当内容。虽然对有害 SQL 命令(例如 `CREATE`、`DELETE`、`DROP`、`INSERT`、`UPDATE`)的检查稍后会在工作流中再次发生,特别是在生成 SQL 查询之后,但此预安全检查对于在输入阶段识别潜在问题至关重要。通过这样做,它可以防止不必要的计算,并向用户提供即时反馈。
虽然使用禁止列表来防止有害的 SQL 操作提供了一种快速防范破坏性查询的方法,但当处理复杂的 SQL 后端(如 T-SQL)时,维护全面的禁止列表可能会变得难以管理。另一种方法是采用允许列表,将查询限制为仅允许安全操作(例如 `SELECT`、`JOIN`)。这种方法通过缩小允许的操作范围而不是尝试阻止每个有风险的命令,确保了更稳健的解决方案。
为实现企业级解决方案,项目可以利用像 Unity Catalog 这样的框架,该框架提供了一种集中且稳健的方法来管理与安全相关的函数,例如用于 AI 工作流的 `pre_safety_check`。通过在此类框架中注册和管理可重用函数,您可以跨所有 AI 工作流执行一致且可靠的行为,从而增强安全性和可扩展性。
此外,该节点利用 LLM 分析输入是否存在冒犯性或不当内容。如果检测到不安全的查询或不当内容,系统会使用错误标志更新状态,并提供透明的反馈,从而尽早保护工作流免受恶意或破坏性因素的侵害。
- 示例
-
禁止的操作
-
**输入:** “DROP TABLE customers;”
-
**响应:** “您的查询包含禁止的 SQL 操作,无法处理。”
-
**输入:** _“SELECT _ FROM orders;”*
-
**响应:** “查询允许。”
-
-
不当内容
- **输入:** “Show me orders where customers have names like 'John the Idiot';”
- **响应:** “您的查询包含不当内容,无法处理。”
- **输入:** “Find total sales by region.”
- **响应:** “输入安全,可以处理。”
- 代码
def pre_safety_check(state: GraphState) -> GraphState:
"""
Perform safety checks on the user input to ensure that no dangerous SQL operations
or inappropriate content is present. The function checks for SQL operations like
DELETE, DROP, and others, and also evaluates the input for toxic or unsafe content.
Args:
state (GraphState): The current graph state containing the translated user input.
Returns:
GraphState: The updated state with error status and messages if any issues are found.
"""
_logger.info("Performing safety check.")
translated_input = state["translated_input"]
messages = state["messages"]
error = "no"
# List of disallowed SQL operations (e.g., DELETE, DROP)
disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)
# Check if the input contains disallowed SQL operations
if pattern.search(translated_input):
_logger.warning("Input contains disallowed SQL operations. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your query contains disallowed SQL operations and cannot be processed.")]
else:
# Check if the input contains inappropriate content
safety_prompt = f"""
Analyze the following input for any toxic or inappropriate content.
Respond with only "safe" or "unsafe", and nothing else.
Input:
{translated_input}
"""
safety_invoke = llm.invoke(safety_prompt)
safety_response = safety_invoke.content.strip().lower() # Access the 'content' attribute and strip any extra spaces
if safety_response == "safe":
_logger.info("Input is safe to process.")
else:
_logger.warning("Input contains inappropriate content. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your query contains inappropriate content and cannot be processed.")]
# Update state with error status and messages
state["error"] = error
state["messages"] = messages
return state
3. 模式提取 (Schema Extract)
`schema_extract` 节点通过查询元数据动态检索数据库模式,包括表名和列详细信息。格式化的模式存储在状态中,允许验证用户查询,同时适应当前的数据库结构。
- 示例
- 输入:请求模式提取。
模式输出- Customers(CustomerID (INTEGER), CustomerName (TEXT), ContactName (TEXT), Address (TEXT), City (TEXT), PostalCode (TEXT), Country (TEXT))
- Orders(OrderID (INTEGER), CustomerID (INTEGER), OrderDate (TEXT))
- OrderDetails(OrderDetailID (INTEGER), OrderID (INTEGER), ProductID (INTEGER), Quantity (INTEGER))
- Products(ProductID (INTEGER), ProductName (TEXT), Price (REAL))
- 输入:请求模式提取。
- 代码
def schema_extract(state: GraphState) -> GraphState:
"""
Extracts the database schema, including all tables and their respective columns,
from the connected SQLite database. This function retrieves the list of tables and
iterates through each table to gather column definitions (name and data type).
Args:
state (GraphState): The current graph state, which will be updated with the database schema.
Returns:
GraphState: The updated state with the extracted database schema.
"""
_logger.info("Extracting database schema.")
# Extract the schema from the database
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
schema_details = []
# Loop through each table and retrieve column information
for table_name_tuple in tables:
table_name = table_name_tuple[0]
cursor.execute(f"PRAGMA table_info({table_name});")
columns = cursor.fetchall()
# Format column definitions
column_defs = ', '.join([f"{col[1]} ({col[2]})" for col in columns])
schema_details.append(f"- {table_name}({column_defs})")
# Save the schema in the state
database_schema = '\n'.join(schema_details)
state["database_schema"] = database_schema
_logger.info(f"Database schema extracted:\n{database_schema}")
return state
4. 上下文检查 (Context Check)
`context_check` 节点通过将用户查询与提取的数据库模式进行比较,验证用户查询的相关性,以确保一致性和相关性。与模式不对应的查询将被标记为不相关,从而防止资源浪费并使用户能够通过用户反馈重新表述查询。
- 示例
- 输入:“What is the average order value?”
模式匹配:输入与数据库模式相关。 - 输入:“Show me data from the inventory table.”
响应:“您的问题与数据库无关,无法处理。”
- 输入:“What is the average order value?”
- 代码
def context_check(state: GraphState) -> GraphState:
"""
Checks whether the user's input is relevant to the database schema by comparing
the user's question with the database schema. Uses a language model to determine if
the question can be answered using the provided schema.
Args:
state (GraphState): The current graph state, which contains the translated input
and the database schema.
Returns:
GraphState: The updated state with error status and messages if the input is irrelevant.
"""
_logger.info("Performing context check.")
# Extract relevant data from the state
translated_input = state["translated_input"]
messages = state["messages"]
error = "no"
database_schema = state["database_schema"] # Get the schema from the state
# Use the LLM to determine if the input is relevant to the database schema
context_prompt = f"""
Determine whether the following user input is a question that can be answered using the database schema provided below.
Respond with only "relevant" if the input is relevant to the database schema, or "irrelevant" if it is not.
User Input:
{translated_input}
Database Schema:
{database_schema}
"""
# Call the LLM for context check
llm_invoke = llm.invoke(context_prompt)
llm_response = llm_invoke.content.strip().lower() # Access the 'content' attribute and strip any extra spaces and lower case
# Process the response from the LLM
if llm_response == "relevant":
_logger.info("Input is relevant to the database schema.")
else:
_logger.info("Input is not relevant. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your question is not related to the database and cannot be processed.")]
# Update the state with error and messages
state["error"] = error
state["messages"] = messages
return state
5. 生成 (Generate)
`generate` 节点通过从向量存储中检索相关文档并利用预定义的 SQL 生成链,从自然语言输入构建 SQL 查询。它将查询与用户的意图和模式上下文对齐,并用生成的 SQL 及其描述更新状态。
- 示例
- 输入:“Find total sales.”
生成的 SQL:“SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;” - 输入:“List all customers in New York.”
生成的 SQL:“SELECT name FROM customers WHERE location = 'New York';”
- 输入:“Find total sales.”
- 代码
def generate(state: GraphState) -> GraphState:
"""
Generates an SQL query based on the user's input. The node retrieves relevant documents from
the vector store and uses a generation chain to produce an SQL query.
Args:
state (GraphState): The current graph state, which contains the translated input and
other relevant data such as messages and iteration count.
Returns:
GraphState: The updated state with the generated SQL query and related messages.
"""
_logger.info("Generating SQL query.")
# Extract relevant data from the state
messages = state["messages"]
iterations = state["iterations"]
translated_input = state["translated_input"]
database_schema = state["database_schema"]
# Retrieve relevant documents from the vector store based on the translated user input
docs = vector_store.similarity_search(translated_input, k=4)
retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
# Generate the SQL query using the SQL generation chain
sql_solution = sql_gen_chain.invoke(
{
"retrieved_docs": retrieved_docs,
"database_schema": database_schema,
"messages": [("user", translated_input)],
}
)
# Save the generated SQL query in the state
messages += [
(
"assistant",
f"{sql_solution.description}\nSQL Query:\n{sql_solution.sql_code}",
)
]
iterations += 1
# Log the generated SQL query
_logger.info("Generated SQL query:\n%s", sql_solution.sql_code)
# Update the state with the generated SQL query and updated message list
state["generation"] = sql_solution
state["messages"] = messages
state["iterations"] = iterations
return state
6. 后安全检查 (Post-safety Check)
`post_safety_check` 节点通过对有害 SQL 命令进行最终验证,确保生成的 SQL 查询是安全的。虽然早期的预安全检查会识别用户输入中禁止的操作,但此后安全检查会验证生成后产生的 SQL 查询是否符合安全准则。这种两步方法确保即使在查询生成过程中无意中引入了禁止的操作,也可以捕获并标记它们。如果检测到不安全的查询,该节点将停止工作流,用错误标志更新状态,并向用户提供反馈。
- 示例
- 禁止的操作
- **生成的查询:** “DROP TABLE orders;”
- **响应:** “生成的 SQL 查询包含禁止的 SQL 操作:DROP,无法处理。”
- **生成的查询:** “SELECT name FROM customers;”
- **响应:** “查询有效。”
- 代码
def post_safety_check(state: GraphState) -> GraphState:
"""
Perform safety checks on the generated SQL query to ensure that it doesn't contain disallowed operations
such as CREATE, DELETE, DROP, etc. This node checks the SQL query generated earlier in the workflow.
Args:
state (GraphState): The current graph state containing the generated SQL query.
Returns:
GraphState: The updated state with error status and messages if any issues are found.
"""
_logger.info("Performing post-safety check on the generated SQL query.")
# Retrieve the generated SQL query from the state
sql_solution = state.get("generation", {})
sql_query = sql_solution.get("sql_code", "").strip()
messages = state["messages"]
error = "no"
# List of disallowed SQL operations
disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)
# Check if the generated SQL query contains disallowed SQL operations
found_operations = pattern.findall(sql_query)
if found_operations:
_logger.warning(
"Generated SQL query contains disallowed SQL operations: %s. Halting the workflow.",
", ".join(set(found_operations))
)
error = "yes"
messages += [("assistant", f"The generated SQL query contains disallowed SQL operations: {', '.join(set(found_operations))} and cannot be processed.")]
else:
_logger.info("Generated SQL query passed the safety check.")
# Update state with error status and messages
state["error"] = error
state["messages"] = messages
return state
7. SQL 检查 (SQL Check)
`sql_check` 节点通过在事务性保存点内执行生成的 SQL 查询,确保其安全和语法有效。验证后,所有更改都会回滚,并标记错误并提供详细反馈,以维护查询的完整性。
- 示例
- 输入 SQL: “SELECT name FROM customers WHERE city = 'New York';”
验证:查询有效。 - 输入 SQL: “SELECT MONTH(date) AS month, SUM(total) AS total_sales FROM orders GROUP BY MONTH(date);”
响应:“您的 SQL 查询执行失败:不存在函数 MONTH。”
- 输入 SQL: “SELECT name FROM customers WHERE city = 'New York';”
- 代码
def sql_check(state: GraphState) -> GraphState:
"""
Validates the generated SQL query by attempting to execute it on the database.
If the query is valid, the changes are rolled back to ensure no data is modified.
If there is an error during execution, the error is logged and the state is updated accordingly.
Args:
state (GraphState): The current graph state, which contains the generated SQL query
and the messages to communicate with the user.
Returns:
GraphState: The updated state with error status and messages if the query is invalid.
"""
_logger.info("Validating SQL query.")
# Extract relevant data from the state
messages = state["messages"]
sql_solution = state["generation"]
error = "no"
sql_code = sql_solution.sql_code.strip()
try:
# Start a savepoint for the transaction to allow rollback
conn.execute('SAVEPOINT sql_check;')
# Attempt to execute the SQL query
cursor.execute(sql_code)
# Roll back to the savepoint to undo any changes
conn.execute('ROLLBACK TO sql_check;')
_logger.info("SQL query validation: success.")
except Exception as e:
# Roll back in case of error
conn.execute('ROLLBACK TO sql_check;')
_logger.error("SQL query validation failed. Error: %s", e)
messages += [("user", f"Your SQL query failed to execute: {e}")]
error = "yes"
# Update the state with the error status
state["error"] = error
state["messages"] = messages
return state
8. 运行查询 (Run Query)
`run_query` 节点执行经过验证的 SQL 查询,连接到数据库以检索结果。它使用查询输出更新状态,确保数据格式化以供进一步分析或报告,同时实施稳健的错误处理。
- 示例
- 输入 SQL: “SELECT COUNT(*) FROM Customers WHERE City = 'New York';”
查询结果: “(0,)” - 输入 SQL: _“SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;”*
查询结果: _“(6925.0,)”_
- 输入 SQL: “SELECT COUNT(*) FROM Customers WHERE City = 'New York';”
- 代码
def run_query(state: GraphState) -> GraphState:
"""
Executes the generated SQL query on the database and retrieves the results if it is a SELECT query.
For non-SELECT queries, commits the changes to the database. If no records are found for a SELECT query,
the `no_records_found` flag is set to True.
Args:
state (GraphState): The current graph state, which contains the generated SQL query and other relevant data.
Returns:
GraphState: The updated state with the query results, or a flag indicating if no records were found.
"""
_logger.info("Running SQL query.")
# Extract the SQL query from the state
sql_solution = state["generation"]
sql_code = sql_solution.sql_code.strip()
results = None
no_records_found = False # Flag to indicate no records found
try:
# Execute the SQL query
cursor.execute(sql_code)
# For SELECT queries, fetch and store results
if sql_code.upper().startswith("SELECT"):
results = cursor.fetchall()
if not results:
no_records_found = True
_logger.info("SQL query execution: success. No records found.")
else:
_logger.info("SQL query execution: success.")
else:
# For non-SELECT queries, commit the changes
conn.commit()
_logger.info("SQL query execution: success. Changes committed.")
except Exception as e:
_logger.error("SQL query execution failed. Error: %s", e)
# Update the state with results and flag for no records found
state["results"] = results
state["no_records_found"] = no_records_found
return state
决策步骤:确定下一步操作
`decide_next_step` 函数充当工作流中的控制点,根据当前状态决定下一步应采取的操作。它会评估 `error` 状态和迄今执行的迭代次数,以确定是应该运行查询、完成工作流,还是系统应该重试生成 SQL 查询。
-
过程
- 如果没有错误(`error == "no"`),系统将继续运行 SQL 查询。
- 如果已达到最大迭代次数(`max_iterations`),工作流结束。
- 如果发生错误且尚未达到最大迭代次数,系统将重试查询生成。
-
工作流决策示例
- **无错误,继续查询**:如果在前几个步骤中未发现错误,工作流将继续运行查询。
- **达到最大迭代次数,结束工作流**:如果工作流已经尝试了一定次数(`max_iterations`),它将终止。
- **检测到错误,重试 SQL 生成**:如果发生错误且系统尚未达到重试限制,它将尝试重新生成 SQL 查询。
-
代码
def decide_next_step(state: GraphState) -> str:
"""
Determines the next step in the workflow based on the current state, including whether the query
should be run, the workflow should be finished, or if the query generation needs to be retried.
Args:
state (GraphState): The current graph state, which contains error status and iteration count.
Returns:
str: The next step in the workflow, which can be "run_query", "generate", or END.
"""
_logger.info("Deciding next step based on current state.")
error = state["error"]
iterations = state["iterations"]
if error == "no":
_logger.info("Error status: no. Proceeding with running the query.")
return "run_query"
elif iterations >= max_iterations:
_logger.info("Maximum iterations reached. Ending the workflow.")
return END
else:
_logger.info("Error detected. Retrying SQL query generation.")
return "generate"
工作流编排和条件逻辑
为了定义系统中任务的编排,我们使用 `StateGraph` 类构建一个工作流图。每个任务表示为一个节点,任务之间的转换定义为边。使用条件边根据工作流的状态来控制流程。
def get_workflow(conn, cursor, vector_store):
"""Define and compile the LangGraph workflow."""
# Max iterations: defines how many times the workflow should retry in case of errors
max_iterations = 3
# SQL generation chain: this is a chain that will generate SQL based on retrieved docs
sql_gen_chain = get_sql_gen_chain()
# Initialize OpenAI LLM for translation and safety checks
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
# Define the individual nodes of the workflow
... # Insert nodes code defined above here
# Build the workflow graph
workflow = StateGraph(GraphState)
# Define workflow nodes
workflow.add_node("translate_input", translate_input) # Translate user input to structured format
workflow.add_node("pre_safety_check", pre_safety_check) # Perform a pre-safety check on input
workflow.add_node("schema_extract", schema_extract) # Extract the database schema
workflow.add_node("context_check", context_check) # Validate input relevance to context
workflow.add_node("generate", generate) # Generate SQL query
workflow.add_node("post_safety_check", post_safety_check) # Perform a post-safety check on generated SQL query
workflow.add_node("sql_check", sql_check) # Validate the generated SQL query
workflow.add_node("run_query", run_query) # Execute the SQL query
# Define workflow edges
workflow.add_edge(START, "translate_input") # Start at the translation step
workflow.add_edge("translate_input", "pre_safety_check") # Move to safety checks
# Conditional edge after safety check
workflow.add_conditional_edges(
"pre_safety_check", # Start at the pre_safety_check node
lambda state: "schema_extract" if state["error"] == "no" else END, # Decide next step
{"schema_extract": "schema_extract", END: END} # Map states to nodes
)
workflow.add_edge("schema_extract", "context_check") # Proceed to context validation
# Conditional edge after context check
workflow.add_conditional_edges(
"context_check", # Start at the context_check node
lambda state: "generate" if state["error"] == "no" else END, # Decide next step
{"generate": "generate", END: END}
)
workflow.add_edge("generate", "post_safety_check") # Proceed to post-safety check
# Conditional edge after post-safety check
workflow.add_conditional_edges(
"post_safety_check", # Start at the post_safety_check node
lambda state: "sql_check" if state["error"] == "no" else END, # If no error, proceed to sql_check, else END
{"sql_check": "sql_check", END: END},
)
# Conditional edge after SQL validation
workflow.add_conditional_edges(
"sql_check", # Start at the sql_check node
decide_next_step, # Function to determine the next step
{
"run_query": "run_query", # If SQL is valid, execute the query
"generate": "generate", # If retry is needed, go back to generation
END: END # Otherwise, terminate the workflow
}
)
workflow.add_edge("run_query", END) # Final step is to end the workflow
# Compile and return the workflow application
app = workflow.compile()
return app
- 从开始到 `translate_input`:
工作流首先将用户输入翻译成结构化格式。 - `translate_input` 到 `pre_safety_check`:
翻译后,工作流继续检查输入的安全性。 - `pre_safety_check` 条件规则:
- 如果输入通过了预安全检查(`state["error"] == "no"`),工作流将移至 `schema_extract`。
- 如果输入未通过预安全检查,工作流终止(`END`)。
- `schema_extract` 到 `context_check`:
提取模式后,工作流验证输入与数据库上下文的相关性。 - `context_check` 条件规则:
- 如果输入相关(`state["error"] == "no"`),工作流将移至 `generate`。
- 如果不相关,工作流终止(`END`)。
- `generate` 到 `post_safety_check`:
工作流生成 SQL 查询并将其发送进行验证。 - **`post_safety_check` 条件规则**:- 如果输入通过了后安全检查(`state["error"] == "no"`),工作流将移至 `sql_check`。- 如果输入未通过后安全检查,工作流终止(`END`)。
- `sql_check` 条件规则:
- 如果查询有效,工作流将继续执行 `run_query`。
- 如果查询需要调整,并且迭代次数低于 3 次,工作流将循环回到 `generate`。
- 如果验证失败且迭代次数超过 3 次,工作流终止(`END`)。
- `run_query` 到 `END`:
查询执行后,工作流结束。
上述模式提供了 LangGraph 节点和边的表示
在 MLflow 中记录模型
现在我们已经使用 LangGraph 构建了一个多语言查询引擎,我们已经准备好使用 MLflow 的 代码中的模型 (Models from Code) 来记录模型。将模型记录到 MLflow 中允许我们将多语言查询引擎视为传统 ML 模型,从而可以在各种服务基础设施中无缝跟踪、版本控制和打包以进行部署。MLflow 的代码中的模型策略(我们记录代表模型的代码)与 MLflow 基于对象的日志记录(创建、序列化模型对象并将其记录为 pickle 或 JSON 对象)形成对比。
步骤 1:创建我们的代码中的模型文件
到目前为止,我们在名为 `workflow.py` 的文件中定义了 `get_workflow` 函数。在此步骤中,我们将创建一个新文件 `sql_model.py`,其中引入了 `SQLGenerator` 类。此脚本将
- 从 `workflow.py` 导入 `get_workflow` 函数。
- 将 `SQLGenerator` 类定义为 `PythonModel`,包括一个利用 `get_workflow` 函数来初始化 LangGraph 工作流的 predict 方法。
- 使用 `mlflow.models.set_model` 将 `SQLGenerator PythonModel` 类指定为 MLflow 关注的模型。
import mlflow
from definitions import REMOTE_SERVER_URI
from workflow import get_workflow
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
class SQLGenerator(mlflow.pyfunc.PythonModel):
def predict(self, context, model_input):
return get_workflow(
model_input["conn"], model_input["cursor"], model_input["vector_store"]
)
mlflow.models.set_model(SQLGenerator())
步骤 2:使用 MLflow 代码中的模型功能进行记录
定义了 SQLGenerator 自定义 Python 模型后,下一步是使用 MLflow 的“从代码生成模型”(Models from Code)功能将其记录下来。这涉及到使用标准的 log model API,指定 `sql_model.py` 脚本的路径,并使用 `code_paths` 参数将 `workflow.py` 包含为依赖项。这种方法确保在不同环境或另一台机器上加载模型时,所有必需的代码文件都可用。
import mlflow
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from mlflow import MlflowClient
client = MlflowClient(tracking_uri=REMOTE_SERVER_URI)
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
with mlflow.start_run():
logged_model_info = mlflow.pyfunc.log_model(
python_model="sql_model.py",
artifact_path="sql_generator",
registered_model_name=REGISTERED_MODEL_NAME,
code_paths=["workflow.py"],
)
client.set_registered_model_alias(
REGISTERED_MODEL_NAME, MODEL_ALIAS, logged_model_info.registered_model_version
)
在 MLflow UI 中,存储的模型将 `sql_model.py` 和 `workflow.py` 脚本作为工件包含在运行中。这种从代码记录模型的功能不仅记录了模型的参数和指标,还捕获了定义其功能的代码。这确保了可观测性、无缝跟踪以及直接通过 UI 进行的简单调试。但是,务必确保敏感元素(如 API 密钥或凭证)绝不硬编码到这些脚本中。由于代码按原样存储,包含任何敏感信息都可能导致令牌泄露并带来安全风险。相反,应使用环境变量、密钥管理系统或其他安全方法来安全地管理敏感数据。

在 main.py 中使用已记录的多语言查询引擎
在记录模型后,可以使用模型 URI 和标准的 mlflow.pyfunc.load_model API 将其从 MLflow 跟踪服务器加载回来。加载模型时,将执行 `workflow.py` 脚本以及 `sql_model.py` 脚本,确保在调用 `predict` 方法时 `get_workflow` 函数可用。
通过执行下面的代码,我们证明了我们的多语言查询引擎能够执行自然语言到 SQL 的生成和查询执行。
import os
import logging
import mlflow
from database import setup_database
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from dotenv import load_dotenv
from vector_store import setup_vector_store
mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.langchain.autolog()
# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)
def main():
# Load environment variables from .env file
load_dotenv()
# Access secrets using os.getenv
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
# Setup database and vector store
conn = setup_database()
cursor = conn.cursor()
vector_store = setup_vector_store()
# Load the model
model_uri = f"models:/{REGISTERED_MODEL_NAME}@{MODEL_ALIAS}"
model = mlflow.pyfunc.load_model(model_uri)
model_input = {"conn": conn, "cursor": cursor, "vector_store": vector_store}
app = model.predict(model_input)
# save image
app.get_graph().draw_mermaid_png(
output_file_path="sql_agent_with_safety_checks.png"
)
# Example user interaction
_logger.info("Welcome to the SQL Assistant!")
while True:
question = input("\nEnter your SQL question (or type 'exit' to quit): ")
if question.lower() == "exit":
break
# Initialize the state with all required keys
initial_state = {
"messages": [("user", question)],
"iterations": 0,
"error": "",
"results": None,
"generation": None,
"no_records_found": False,
"translated_input": "", # Initialize translated_input
}
solution = app.invoke(initial_state)
# Check if an error was set during the safety check
if solution["error"] == "yes":
_logger.info("\nAssistant Message:\n")
_logger.info(solution["messages"][-1][1]) # Display the assistant's message
continue # Skip to the next iteration
# Extract the generated SQL query from solution["generation"]
sql_query = solution["generation"].sql_code
_logger.info("\nGenerated SQL Query:\n")
_logger.info(sql_query)
# Extract and display the query results
if solution.get("no_records_found"):
_logger.info("\nNo records found matching your query.")
elif "results" in solution and solution["results"] is not None:
_logger.info("\nQuery Results:\n")
for row in solution["results"]:
_logger.info(row)
else:
_logger.info("\nNo results returned or query did not execute successfully.")
_logger.info("Goodbye!")
if __name__ == "__main__":
main()
项目文件结构
该项目遵循简单的文件结构。
MLflow 跟踪
MLflow 自动跟踪为 LangChain、OpenAI、LlamaIndex、DSPy 和 AutoGen 等各种 GenAI 库提供了完全自动化的集成。由于我们的 AI 工作流是使用 LangGraph 构建的,我们可以通过启用 mlflow.langchain.autolog() 来激活自动 LangChain 跟踪。
通过 LangChain 自动日志记录,当对链调用调用 API 时,跟踪会自动记录到当前的 MLflow 实验中。这种无缝集成可确保捕获每一次交互以供监视和分析。
在 MLflow 中查看跟踪
通过导航到感兴趣的 MLflow 实验并单击“跟踪”(Tracing)选项卡,可以轻松访问跟踪。进入后,选择特定跟踪可提供详细的执行信息。
每个跟踪包括:
- 执行图:工作流步骤的可视化。
- 输入和输出:每个步骤处理的数据的详细日志。
通过利用 MLflow 跟踪,我们可以获得对整个图执行的精细可见性。AI 工作流图通常感觉像一个黑盒子,使得调试和理解每个步骤发生的情况变得困难。但是,只需一行代码即可启用跟踪,MLflow 即可提供对工作流的清晰详细的见解,使开发人员能够有效地调试、监视和优化图的每个节点,确保我们的多语言查询引擎保持透明、可审计和可扩展。

结论
在本文中,我们探讨了使用 LangGraph 和 MLflow 构建和管理多语言查询引擎的过程。通过将 LangGraph 的动态 AI 工作流与 MLflow 强大的生命周期管理和跟踪功能集成,我们创建了一个系统,该系统不仅能够生成和执行准确高效的自然语言到 SQL 的转换,而且还具有透明度和可扩展性。


