
From Natural Language to SQL: Building and Tracking a Multilingual Query Engine

· 40 min read
Hugo Carvalho
Machine Learning Analyst, Adidas
Joana Ferreira
Machine Learning Engineer, Adidas
Rahul Pandey
Senior Solutions Architect, Adidas

If you are building a Multilingual Query Engine that combines natural-language-to-SQL generation with query execution while taking full advantage of MLflow's features, this blog post is your guide. We will explore how to use MLflow Models from Code to achieve seamless tracking and versioning of AI workflows. We will also take a deep dive into MLflow's Tracing feature, which brings observability to these components by tracking the inputs, outputs, and metadata of every intermediate step in the AI workflow.

Introduction

SQL is a foundational skill for managing relational databases and accessing data. However, building complex SQL queries to answer tricky data questions can be difficult and time-consuming, and this complexity makes it hard to use data effectively. Natural-language-to-SQL (NL2SQL) systems address this problem by translating natural language into SQL commands, enabling non-technical users to interact with data: users can simply ask questions in the natural language they are comfortable with, and these systems help them retrieve the information they need.

However, some challenges remain when creating NL2SQL systems, such as semantic ambiguity, schema mapping, and error handling and user feedback. Therefore, when building such systems, we must put guardrails in place rather than relying entirely on the LLM.

In this blog post, we will walk you through the process of building a Multilingual Query Engine. The engine supports natural language input in multiple languages, generates a SQL query based on the translated user input, and executes that query. Let's look at an example: using a database containing information about a company's customers, products, and orders, a user can ask a question in any language, such as "Quantos clientes temos por país?" (Portuguese for "How many customers do we have per country?"). The AI workflow translates the input into English, producing "How many customers do we have per country?". It then validates that the input is safe, checks whether the question can be answered using the database schema, generates an appropriate SQL query (e.g., SELECT COUNT(CustomerID) AS NumberOfCustomers, Country FROM Customers GROUP BY Country;), and validates the query to ensure it contains no harmful commands (e.g., DROP). Finally, it executes the query against the database to retrieve the results.

We will start by demonstrating how to leverage LangGraph's capabilities to build a dynamic AI workflow. This workflow integrates OpenAI and external data sources, such as a Vector Store and a SQLite database, to process user input, perform safety checks, query the database, and generate meaningful responses.

Throughout the post, we will use MLflow's Models from Code feature to achieve seamless tracking and versioning of the AI workflow. We will also take a deep dive into MLflow's Tracing feature, which enhances the observability of the AI workflow by tracking the inputs, outputs, and associated metadata of each intermediate step. This makes it easy to identify errors and unexpected behavior, improving the transparency of the workflow.

Prerequisites

To set up and run this project, make sure the following Python packages are installed:

  • faiss-cpu
  • langchain
  • langchain-core
  • langchain-openai
  • langgraph
  • langchain-community
  • pydantic >=2
  • typing_extensions
  • python-dotenv

In addition, an MLflow Tracking Server is required to log and manage experiments, models, and traces effectively. For a local setup, refer to the official MLflow documentation for instructions on configuring a simple MLflow Tracking Server.
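
For a quick local setup, running the command below starts a server at http://127.0.0.1:5000; the flags are standard MLflow CLI options, and the backend store path is illustrative:

mlflow server --host 127.0.0.1 --port 5000 --backend-store-uri ./mlruns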

Finally, make sure your OpenAI API key is saved in a .env file in the project directory. This allows the application to securely access the OpenAI services needed to build the AI workflow. The .env file should contain a line like:

OPENAI_API_KEY=your_openai_api_key
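
The application can then load the key at startup with python-dotenv (already listed in the prerequisites). A minimal sketch:

import os

from dotenv import load_dotenv

load_dotenv()  # reads the .env file from the project directory
api_key = os.getenv("OPENAI_API_KEY")  # available to OpenAI clients via the environment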

Multilingual Query Engine Built with LangGraph

The Multilingual Query Engine leverages the LangGraph library, an AI orchestration tool designed to create stateful, multi-agent, cyclical graph architectures for applications powered by LLMs.

Compared to other AI orchestrators, LangGraph offers three core benefits: cycles, controllability, and persistence. It allows AI workflows to be defined with cycles, which are essential for implementing retry mechanisms such as the SQL query generation retry in the Multilingual Query Engine (when validation fails, the query loops back for regeneration). This makes LangGraph an ideal tool for building our Multilingual Query Engine.

Key LangGraph features:

  1. Stateful architecture: The engine maintains a dynamic snapshot of the graph's execution state. This snapshot acts as a shared resource across nodes, enabling efficient decision-making and real-time updates as each node executes.

  2. Multi-agent design: The AI workflow includes multiple interactions with OpenAI and other external tools throughout its execution.

  3. Cyclical graph structure: The graph's cyclical nature introduces a robust retry mechanism. When needed, this mechanism dynamically handles failures by looping back to previous stages, ensuring continuous graph execution. (Details of this mechanism are discussed later.)

AI Workflow Overview

The Multilingual Query Engine's advanced AI workflow is composed of interconnected nodes and edges, each representing a crucial stage:

  1. Translation Node: Translates user input into English.

  2. Pre-safety Check: Ensures the user input is free of toxic or inappropriate content and contains no harmful SQL commands (e.g., DELETE, DROP).

  3. Database Schema Extraction: Retrieves the schema of the target database to understand its structure and available data.

  4. Relevance Validation: Compares the user input against the database schema to ensure alignment with the database's context.

  5. SQL Query Generation: Generates a SQL query based on the user input and the current database schema.

  6. Post-safety Check: Ensures the generated SQL query contains no harmful SQL commands (e.g., DELETE, DROP).

  7. SQL Query Validation: Executes the SQL query in a rollback-safe environment to confirm its validity before running it.

  8. Dynamic State Evaluation: Determines the next step based on the current state. If SQL query validation fails, it loops back to stage 5 to regenerate the query.

  9. Query Execution and Result Retrieval: Executes the SQL query and returns the results if it is a SELECT statement.

The retry mechanism is introduced at stage 8, where the system dynamically evaluates the current graph state. Specifically, when the SQL Query Validation node (stage 7) detects a problem, the state triggers a loop back to the SQL Generation node (stage 5) for a new SQL generation attempt (up to 3 attempts).

Components

The Multilingual Query Engine interacts with several external components to transform natural language user inputs into SQL queries and execute them safely and reliably. In this section, we cover the key AI workflow components in detail: OpenAI, the FAISS Vector Store, the SQLite database, and the SQL Generation Chain.

OpenAI

OpenAI, more specifically the gpt-4o-mini language model, plays a crucial role at multiple stages of the workflow. It provides the intelligence needed for:

  1. Translation: Translates user input into English. If the text is already in English, it simply repeats the input.

  2. Safety checks: Analyzes user input to ensure it is free of toxic or inappropriate content.

  3. Relevance checks: Evaluates whether the user's question is relevant to the database schema.

  4. SQL generation: Generates valid and executable SQL queries based on the user input, the SQL generation documentation, and the database schema.

Details on the OpenAI implementation are provided in the node descriptions later in this post.

FAISS Vector Store

To build an effective natural-language-to-SQL engine capable of generating accurate and executable SQL queries, we leverage LangChain's FAISS Vector Store functionality. This setup allows the system to search and extract SQL query generation guidelines from W3Schools SQL documents previously stored in a vector database, improving the success rate of SQL query generation.

For demonstration purposes, we use FAISS, an in-memory vector store where vectors are stored directly in RAM. This provides fast access but means data does not persist between runs. For a more scalable solution that allows embeddings to be stored and shared across multiple projects, we recommend alternatives such as AWS OpenSearch, Vertex AI Vector Search, Azure Vector Search, or Mosaic AI Vector Search. These cloud-based solutions offer persistent storage, automatic scaling, and seamless integration with other cloud services, making them well suited for large-scale applications.

Step 1: Load the SQL documentation

The first step in creating a FAISS Vector Store of SQL query generation guidelines is to load the SQL documentation from the W3Schools SQL page using LangChain's RecursiveUrlLoader. This tool retrieves the documentation, allowing us to use it as the knowledge base for our engine.

Step 2: Split the text into manageable chunks

The loaded SQL documentation is one long text, which is difficult for an LLM to ingest effectively. To address this, the next step is to split the text into smaller, manageable chunks using LangChain's RecursiveCharacterTextSplitter. By splitting the text into chunks of 500 characters with a 50-character overlap, we ensure the language model has sufficient context while minimizing the risk of losing important information that spans chunk boundaries. The split_text method applies this splitting, storing the results in a list named "documents".

Step 3: Create the embedding model

The third step is to create a model that converts these chunks into embeddings (vectorized numerical representations of each text chunk). Embeddings enable the system to compare the similarity between chunks and user input, retrieving the most relevant matches for SQL query generation.
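
To make the similarity idea concrete, here is a minimal sketch (not part of the project code, and assuming an OPENAI_API_KEY is configured) that embeds a question and two documentation chunks and compares them with cosine similarity:

import numpy as np
from langchain_openai import OpenAIEmbeddings

embedding_model = OpenAIEmbeddings()

# Embed a user question and two candidate documentation chunks
query_vec = np.array(embedding_model.embed_query("How do I count rows per group?"))
chunks = [
    "The GROUP BY statement groups rows that have the same values.",
    "The DROP TABLE statement deletes a table.",
]

# Cosine similarity; the GROUP BY chunk should typically score higher for this question
for chunk in chunks:
    vec = np.array(embedding_model.embed_query(chunk))
    score = float(vec @ query_vec / (np.linalg.norm(vec) * np.linalg.norm(query_vec)))
    print(f"{score:.3f}  {chunk}")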

Step 4: Create the embeddings and store them in the FAISS Vector Store

Finally, we create and store the embeddings using FAISS. The FAISS.from_texts method takes all the chunks, computes their embeddings, and stores them in a high-speed, searchable vector database. This searchable database allows the engine to efficiently retrieve relevant SQL guidelines, significantly improving the success rate of executable SQL query generation.

import logging
import os

from bs4 import BeautifulSoup as Soup

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings


def setup_vector_store(logger: logging.Logger):
    """Setup or load the vector store."""
    if not os.path.exists("data"):
        os.makedirs("data")

    vector_store_dir = "data/vector_store"

    if os.path.exists(vector_store_dir):
        # Load the vector store from disk
        logger.info("Loading vector store from disk...")
        vector_store = FAISS.load_local(
            vector_store_dir,
            OpenAIEmbeddings(),
            allow_dangerous_deserialization=True,
        )
    else:
        logger.info("Creating new vector store...")
        # Load SQL documentation
        url = "https://w3schools.org.cn/sql/"
        loader = RecursiveUrlLoader(
            url=url, max_depth=2, extractor=lambda x: Soup(x, "html.parser").text
        )
        docs = loader.load()

        # Split documents into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        )

        documents = []
        for doc in docs:
            splits = text_splitter.split_text(doc.page_content)
            for i, split in enumerate(splits):
                documents.append(
                    {
                        "content": split,
                        "metadata": {"source": doc.metadata["source"], "chunk": i},
                    }
                )

        # Compute embeddings and create vector store
        embedding_model = OpenAIEmbeddings()
        vector_store = FAISS.from_texts(
            [doc["content"] for doc in documents],
            embedding_model,
            metadatas=[doc["metadata"] for doc in documents],
        )

        # Save the vector store to disk
        vector_store.save_local(vector_store_dir)
        logger.info("Vector store created and saved to disk.")

    return vector_store
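
A quick usage sketch (assuming an OPENAI_API_KEY is configured): the store is created once, persisted under data/vector_store, and can then be queried for guidelines:

import logging

logger = logging.getLogger(__name__)
vector_store = setup_vector_store(logger)

# Retrieve the two chunks most similar to a SQL-related question
for doc in vector_store.similarity_search("How do I join two tables?", k=2):
    print(doc.page_content[:80])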

SQLite Database

The SQLite database is a key component of the Multilingual Query Engine, serving as its structured data repository. SQLite offers a lightweight, fast, and self-contained relational database engine that requires no server setup or installation. Its compact size (under 500 KB) and zero-configuration nature make it extremely easy to use, while its platform-agnostic database format ensures seamless portability across systems. As a local, on-disk database, SQLite is ideal for avoiding the complexity of setting up MySQL or PostgreSQL while still providing a reliable, full-featured SQL engine with excellent performance.

The SQLite database supports efficient SQL query generation, validation, and execution by enabling:

  1. Schema extraction: Provides schema information for user input context validation (stage 4) and executable SQL query generation (stage 5).

  2. Query execution: Executes SQL queries in a rollback-safe environment during the validation stage (stage 7) and the query execution stage (stage 9), fetching results for SELECT statements and committing changes for other query types.

SQLite Database Initialization

When the AI workflow is initialized, the setup_database function is used to initialize the database. This process includes:

  1. Setting up the SQLite database connection: Establishes a connection to the SQLite database to enable data interaction.

  2. Table creation: Defines and creates the database tables required by the AI workflow.

  3. Data population: Populates the tables with sample data to support the query execution and validation stages.

import logging
import os

import sqlite3


def create_connection(db_file="data/database.db"):
    """Create a database connection to the SQLite database."""
    conn = sqlite3.connect(db_file)
    return conn


def create_tables(conn):
    """Create tables in the database."""
    cursor = conn.cursor()
    # Create Customers table
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS Customers (
            CustomerID INTEGER PRIMARY KEY,
            CustomerName TEXT,
            ContactName TEXT,
            Address TEXT,
            City TEXT,
            PostalCode TEXT,
            Country TEXT
        )
        """
    )

    # Create Orders table
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS Orders (
            OrderID INTEGER PRIMARY KEY,
            CustomerID INTEGER,
            OrderDate TEXT,
            FOREIGN KEY (CustomerID) REFERENCES Customers (CustomerID)
        )
        """
    )

    # Create OrderDetails table
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS OrderDetails (
            OrderDetailID INTEGER PRIMARY KEY,
            OrderID INTEGER,
            ProductID INTEGER,
            Quantity INTEGER,
            FOREIGN KEY (OrderID) REFERENCES Orders (OrderID),
            FOREIGN KEY (ProductID) REFERENCES Products (ProductID)
        )
        """
    )

    # Create Products table
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS Products (
            ProductID INTEGER PRIMARY KEY,
            ProductName TEXT,
            Price REAL
        )
        """
    )

    conn.commit()


def populate_tables(conn):
    """Populate tables with sample data if they are empty."""
    cursor = conn.cursor()

    # Populate Customers table if empty
    cursor.execute("SELECT COUNT(*) FROM Customers")
    if cursor.fetchone()[0] == 0:
        customers = []
        for i in range(1, 51):
            customers.append(
                (
                    i,
                    f"Customer {i}",
                    f"Contact {i}",
                    f"Address {i}",
                    f"City {i % 10}",
                    f"{10000 + i}",
                    f"Country {i % 5}",
                )
            )
        cursor.executemany(
            """
            INSERT INTO Customers (CustomerID, CustomerName, ContactName, Address, City, PostalCode, Country)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            customers,
        )

    # Populate Products table if empty
    cursor.execute("SELECT COUNT(*) FROM Products")
    if cursor.fetchone()[0] == 0:
        products = []
        for i in range(1, 51):
            products.append((i, f"Product {i}", round(10 + i * 0.5, 2)))
        cursor.executemany(
            """
            INSERT INTO Products (ProductID, ProductName, Price)
            VALUES (?, ?, ?)
            """,
            products,
        )

    # Populate Orders table if empty
    cursor.execute("SELECT COUNT(*) FROM Orders")
    if cursor.fetchone()[0] == 0:
        orders = []
        from datetime import datetime, timedelta

        base_date = datetime(2023, 1, 1)
        for i in range(1, 51):
            order_date = base_date + timedelta(days=i)
            orders.append(
                (
                    i,
                    i % 50 + 1,  # CustomerID between 1 and 50
                    order_date.strftime("%Y-%m-%d"),
                )
            )
        cursor.executemany(
            """
            INSERT INTO Orders (OrderID, CustomerID, OrderDate)
            VALUES (?, ?, ?)
            """,
            orders,
        )

    # Populate OrderDetails table if empty
    cursor.execute("SELECT COUNT(*) FROM OrderDetails")
    if cursor.fetchone()[0] == 0:
        order_details = []
        for i in range(1, 51):
            order_details.append(
                (
                    i,
                    i % 50 + 1,  # OrderID between 1 and 50
                    i % 50 + 1,  # ProductID between 1 and 50
                    (i % 5 + 1) * 2,  # Quantity between 2 and 10
                )
            )
        cursor.executemany(
            """
            INSERT INTO OrderDetails (OrderDetailID, OrderID, ProductID, Quantity)
            VALUES (?, ?, ?, ?)
            """,
            order_details,
        )

    conn.commit()


def setup_database(logger: logging.Logger):
    """Setup the database and return the connection."""
    db_file = "data/database.db"
    if not os.path.exists("data"):
        os.makedirs("data")

    db_exists = os.path.exists(db_file)

    conn = create_connection(db_file)

    if not db_exists:
        logger.info("Setting up the database...")
        create_tables(conn)
        populate_tables(conn)
    else:
        logger.info("Database already exists. Skipping setup.")

    return conn
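
A quick usage sketch: setup_database is idempotent, creating and seeding the tables only when the database file does not yet exist. Given the sample data above, the Customers table should contain 50 rows:

import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

conn = setup_database(_logger)
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM Customers")
print(cursor.fetchone())  # expected: (50,)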

SQL Generation Chain

The SQL Generation Chain (sql_gen_chain) is the backbone of automated SQL query generation in our workflow. Leveraging LangChain's modular capabilities and OpenAI's advanced natural language processing, this chain transforms user questions into precise, executable SQL queries.

Core capabilities:

  • Prompt-driven generation: Starts with a carefully designed prompt that incorporates the database schema and documentation snippets, ensuring queries are contextually accurate.

  • Structured responses: Delivers output in a predefined format, including:

    • A description of the query's purpose

    • The corresponding, executable SQL code

  • Adaptable and reliable: Uses gpt-4o-mini for robust, consistent query generation, minimizing manual effort and errors.

This chain is a critical component of our workflow, enabling seamless integration of SQL query generation with downstream processes, ensuring accuracy, and significantly improving efficiency.

from pydantic import BaseModel, Field

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

class SQLQuery(BaseModel):
    """Schema for SQL query solutions to questions."""

    description: str = Field(description="Description of the SQL query")
    sql_code: str = Field(description="The SQL code block")


def get_sql_gen_chain():
    """Set up the SQL generation chain."""
    sql_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a SQL assistant with expertise in SQL query generation. \n
Answer the user's question based on the provided documentation snippets and the database schema provided below. Ensure any SQL query you provide is valid and executable. \n
Structure your answer with a description of the query, followed by the SQL code block. Here are the documentation snippets:\n{retrieved_docs}\n\nDatabase Schema:\n{database_schema}""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    # Initialize the OpenAI LLM
    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")

    # Create the code generation chain
    sql_gen_chain = sql_gen_prompt | llm.with_structured_output(SQLQuery)

    return sql_gen_chain
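
A minimal invocation sketch (the toy schema and documentation snippet below are illustrative, and an OPENAI_API_KEY must be configured) shows the structured output the chain returns:

chain = get_sql_gen_chain()
result = chain.invoke(
    {
        "retrieved_docs": "The COUNT() function returns the number of rows.",
        "database_schema": "- Customers(CustomerID (INTEGER), Country (TEXT))",
        "messages": [("user", "How many customers do we have per country?")],
    }
)
print(result.description)  # natural-language explanation of the query
print(result.sql_code)     # executable SQL string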

Workflow Setup and Initialization

Before diving into the workflow nodes, it is essential to set up the required components and define the workflow's structure. This section covers the initialization of the essential libraries, logging, and the custom GraphState class, along with the main workflow compilation function.

Defining the GraphState

The GraphState class is a custom TypedDict that maintains state information as the workflow progresses. It acts as a shared data structure across nodes, ensuring continuity and consistency. Key fields include:

  • error: Tracks whether an error has occurred.
  • messages: Stores the list of user and system messages.
  • generation: Holds the generated SQL query.
  • iterations: Tracks the number of retry attempts when errors occur.
  • results: Stores the SQL execution results, if any.
  • no_records_found: Flags whether the query returned no records.
  • translated_input: Contains the user's translated input.
  • database_schema: Maintains the database schema for context validation.
import logging
import re
from typing import List, Optional

from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from sql_generation import get_sql_gen_chain
from typing_extensions import TypedDict

# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)

class GraphState(TypedDict):
    error: str  # Tracks if an error has occurred
    messages: List  # List of messages (user input and assistant messages)
    generation: Optional[str]  # Holds the generated SQL query
    iterations: int  # Keeps track of how many times the workflow has retried
    results: Optional[List]  # Holds the results of SQL execution
    no_records_found: bool  # Flag for whether any records were found in the SQL result
    translated_input: str  # Holds the translated user input
    database_schema: str  # Holds the extracted database schema for context checking
Workflow Compilation Function

The main function, get_workflow, is responsible for defining and compiling the workflow. Key components include:

  • conn and cursor: Used for the database connection and query execution.
  • vector_store: The vector database used for context retrieval.
  • max_iterations: Sets a limit on retry attempts to prevent infinite loops.
  • sql_gen_chain: Retrieves the SQL generation chain from sql_generation to generate SQL queries based on contextual input.
  • ChatOpenAI: Initializes the OpenAI gpt-4o-mini model for tasks such as safety checks and query translation.
def get_workflow(conn, cursor, vector_store):
    """Define and compile the LangGraph workflow."""

    # Max iterations: defines how many times the workflow should retry in case of errors
    max_iterations = 3

    # SQL generation chain: this is a chain that will generate SQL based on retrieved docs
    sql_gen_chain = get_sql_gen_chain()

    # Initialize OpenAI LLM for translation and safety checks
    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")

    # Define the individual nodes of the workflow

This function serves as the entry point for building the complete workflow with StateGraph. The individual nodes within the workflow are defined and connected in the sections that follow.

Node Descriptions

1. Translate Input

The translate_input node translates user queries into English to standardize processing and ensure compatibility with downstream nodes. Translating the user input as the first step of the AI workflow enforces task separation and improves observability. Task separation simplifies the workflow by isolating translation from other downstream tasks such as user input safety validation and SQL generation. Improved observability provides clear traces in MLflow, making the process easier to debug and monitor.

  • Examples
    • Input: "Quantos pedidos foram realizados em Novembro?"
    • Translated: "How many orders were made in November?"
    • Input: "Combien de ventes avons-nous enregistrées en France ?"
    • Translated: "How many sales did we record in France?"
  • Code
def translate_input(state: GraphState) -> GraphState:
    """
    Translates user input to English using an LLM. If the input is already in English,
    it is returned as is. This ensures consistent input for downstream processing.

    Args:
        state (GraphState): The current graph state containing user messages.

    Returns:
        GraphState: The updated state with the translated input.
    """
    _logger.info("Starting translation of user input to English.")
    messages = state["messages"]
    user_input = messages[-1][1]  # Get the latest user input

    # Translation prompt for the model
    translation_prompt = f"""
Translate the following text to English. If the text is already in English, repeat it exactly without any additional explanation.

Text:
{user_input}
"""

    # Call the OpenAI LLM to translate the text
    translated_response = llm.invoke(translation_prompt)
    translated_text = translated_response.content.strip()  # Access the 'content' attribute and strip any extra spaces

    # Update state with the translated input
    state["translated_input"] = translated_text
    _logger.info("Translation completed successfully. Translated input: %s", translated_text)

    return state
2. Pre-safety Check

The pre_safety_check node ensures early detection of disallowed SQL operations and inappropriate content in the user input. Although the check for harmful SQL commands (e.g., CREATE, DELETE, DROP, INSERT, UPDATE) is performed again later in the workflow, after the SQL query has been generated, this pre-safety check is crucial for catching potential issues at the input stage. Doing so avoids unnecessary computation and gives the user immediate feedback.

While a deny list of harmful SQL operations offers a quick way to block destructive queries, maintaining a comprehensive deny list can become unmanageable with complex SQL backends such as T-SQL. An alternative is to adopt an allow list, restricting queries to safe operations only (e.g., SELECT, JOIN). This yields a more robust solution by narrowing the set of permitted actions instead of trying to block every risky command, as sketched below.
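
As an illustrative sketch of the allow-list idea (a hypothetical helper, not part of the project code), a validator can accept a query only when every SQL keyword it recognizes belongs to a safe, read-only set:

import re

# Keywords a read-only query may use (illustrative, not exhaustive)
ALLOWED = {
    "SELECT", "FROM", "WHERE", "JOIN", "LEFT", "INNER", "ON", "GROUP",
    "BY", "ORDER", "HAVING", "LIMIT", "AS", "COUNT", "SUM", "AND", "OR",
}
# Keywords the validator classifies; bare identifiers are ignored
KNOWN = ALLOWED | {
    "CREATE", "DELETE", "DROP", "INSERT", "UPDATE", "ALTER", "TRUNCATE", "EXEC", "EXECUTE",
}

def is_query_allowed(sql: str) -> bool:
    words = {w.upper() for w in re.findall(r"[A-Za-z_]+", sql)}
    return words & KNOWN <= ALLOWED  # every recognized keyword must be allowed

print(is_query_allowed("SELECT Country FROM Customers GROUP BY Country"))  # True
print(is_query_allowed("DROP TABLE Customers;"))  # False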

For an enterprise-grade solution, this project could leverage frameworks such as Unity Catalog, which provide a centralized and robust way to manage security-related functions like the AI workflow's pre_safety_check. Registering and managing reusable functions within such a framework enforces consistent and reliable behavior across all AI workflows, improving both security and scalability.

In addition, this node uses an LLM to analyze the input for offensive or inappropriate content. If an unsafe query or inappropriate content is detected, the state is updated with an error flag and transparent feedback is provided, protecting the workflow from malicious or destructive elements early on.

  • Examples
  1. Disallowed operations

    • Input: "DROP TABLE customers;"

    • Response: "Your query contains disallowed SQL operations and cannot be processed."

    • Input: "SELECT * FROM orders;"

    • Response: "Query allowed."

  2. Inappropriate content

    • Input: "Show me orders where customers have names like 'John the Idiot';"
    • Response: "Your query contains inappropriate content and cannot be processed."
    • Input: "Find total sales by region."
    • Response: "Input is safe to process."
  • Code
def pre_safety_check(state: GraphState) -> GraphState:
    """
    Perform safety checks on the user input to ensure that no dangerous SQL operations
    or inappropriate content is present. The function checks for SQL operations like
    DELETE, DROP, and others, and also evaluates the input for toxic or unsafe content.

    Args:
        state (GraphState): The current graph state containing the translated user input.

    Returns:
        GraphState: The updated state with error status and messages if any issues are found.
    """
    _logger.info("Performing safety check.")
    translated_input = state["translated_input"]
    messages = state["messages"]
    error = "no"

    # List of disallowed SQL operations (e.g., DELETE, DROP)
    disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
    pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)

    # Check if the input contains disallowed SQL operations
    if pattern.search(translated_input):
        _logger.warning("Input contains disallowed SQL operations. Halting the workflow.")
        error = "yes"
        messages += [("assistant", "Your query contains disallowed SQL operations and cannot be processed.")]
    else:
        # Check if the input contains inappropriate content
        safety_prompt = f"""
Analyze the following input for any toxic or inappropriate content.

Respond with only "safe" or "unsafe", and nothing else.

Input:
{translated_input}
"""
        safety_invoke = llm.invoke(safety_prompt)
        safety_response = safety_invoke.content.strip().lower()  # Access the 'content' attribute and strip any extra spaces

        if safety_response == "safe":
            _logger.info("Input is safe to process.")
        else:
            _logger.warning("Input contains inappropriate content. Halting the workflow.")
            error = "yes"
            messages += [("assistant", "Your query contains inappropriate content and cannot be processed.")]

    # Update state with error status and messages
    state["error"] = error
    state["messages"] = messages

    return state
3. Schema Extraction

The schema_extract node dynamically retrieves the database schema by querying metadata, including table names and column details. The formatted schema is stored in the state, enabling user query validation while adapting to the current database structure.

  • Example
    • Input: a schema extraction request.
      Schema output:
      • Customers(CustomerID (INTEGER), CustomerName (TEXT), ContactName (TEXT), Address (TEXT), City (TEXT), PostalCode (TEXT), Country (TEXT))
      • Orders(OrderID (INTEGER), CustomerID (INTEGER), OrderDate (TEXT))
      • OrderDetails(OrderDetailID (INTEGER), OrderID (INTEGER), ProductID (INTEGER), Quantity (INTEGER))
      • Products(ProductID (INTEGER), ProductName (TEXT), Price (REAL))
  • Code
def schema_extract(state: GraphState) -> GraphState:
    """
    Extracts the database schema, including all tables and their respective columns,
    from the connected SQLite database. This function retrieves the list of tables and
    iterates through each table to gather column definitions (name and data type).

    Args:
        state (GraphState): The current graph state, which will be updated with the database schema.

    Returns:
        GraphState: The updated state with the extracted database schema.
    """
    _logger.info("Extracting database schema.")

    # Extract the schema from the database
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    schema_details = []

    # Loop through each table and retrieve column information
    for table_name_tuple in tables:
        table_name = table_name_tuple[0]
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()

        # Format column definitions
        column_defs = ', '.join([f"{col[1]} ({col[2]})" for col in columns])
        schema_details.append(f"- {table_name}({column_defs})")

    # Save the schema in the state
    database_schema = '\n'.join(schema_details)
    state["database_schema"] = database_schema
    _logger.info(f"Database schema extracted:\n{database_schema}")

    return state
4. Context Check

The context_check node validates user queries by comparing them against the extracted database schema to ensure alignment and relevance. Queries that do not correspond to the schema are flagged as irrelevant, preventing wasted resources and allowing the user to rephrase the question.

  • Examples
    • Input: "What is the average order value?"
      Schema match: the input is relevant to the database schema.
    • Input: "Show me data from the inventory table."
      Response: "Your question is not related to the database and cannot be processed."
  • Code
def context_check(state: GraphState) -> GraphState:
    """
    Checks whether the user's input is relevant to the database schema by comparing
    the user's question with the database schema. Uses a language model to determine if
    the question can be answered using the provided schema.

    Args:
        state (GraphState): The current graph state, which contains the translated input
            and the database schema.

    Returns:
        GraphState: The updated state with error status and messages if the input is irrelevant.
    """
    _logger.info("Performing context check.")

    # Extract relevant data from the state
    translated_input = state["translated_input"]
    messages = state["messages"]
    error = "no"
    database_schema = state["database_schema"]  # Get the schema from the state

    # Use the LLM to determine if the input is relevant to the database schema
    context_prompt = f"""
Determine whether the following user input is a question that can be answered using the database schema provided below.

Respond with only "relevant" if the input is relevant to the database schema, or "irrelevant" if it is not.

User Input:
{translated_input}

Database Schema:
{database_schema}
"""

    # Call the LLM for context check
    llm_invoke = llm.invoke(context_prompt)
    llm_response = llm_invoke.content.strip().lower()  # Access the 'content' attribute, strip extra spaces, and lower-case

    # Process the response from the LLM
    if llm_response == "relevant":
        _logger.info("Input is relevant to the database schema.")
    else:
        _logger.info("Input is not relevant. Halting the workflow.")
        error = "yes"
        messages += [("assistant", "Your question is not related to the database and cannot be processed.")]

    # Update the state with error and messages
    state["error"] = error
    state["messages"] = messages

    return state
5. Generation

The generate node constructs SQL queries from natural language input by retrieving relevant documents from the vector store and leveraging the predefined SQL generation chain. It aligns the query with the user's intent and the schema context, updating the state with the generated SQL and its description.

  • Examples
    • Input: "Find total sales."
      Generated SQL: "SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;"
    • Input: "List all customers in New York."
      Generated SQL: "SELECT CustomerName FROM Customers WHERE City = 'New York';"
  • Code
def generate(state: GraphState) -> GraphState:
    """
    Generates an SQL query based on the user's input. The node retrieves relevant documents from
    the vector store and uses a generation chain to produce an SQL query.

    Args:
        state (GraphState): The current graph state, which contains the translated input and
            other relevant data such as messages and iteration count.

    Returns:
        GraphState: The updated state with the generated SQL query and related messages.
    """
    _logger.info("Generating SQL query.")

    # Extract relevant data from the state
    messages = state["messages"]
    iterations = state["iterations"]
    translated_input = state["translated_input"]
    database_schema = state["database_schema"]

    # Retrieve relevant documents from the vector store based on the translated user input
    docs = vector_store.similarity_search(translated_input, k=4)
    retrieved_docs = "\n\n".join([doc.page_content for doc in docs])

    # Generate the SQL query using the SQL generation chain
    sql_solution = sql_gen_chain.invoke(
        {
            "retrieved_docs": retrieved_docs,
            "database_schema": database_schema,
            "messages": [("user", translated_input)],
        }
    )

    # Save the generated SQL query in the state
    messages += [
        (
            "assistant",
            f"{sql_solution.description}\nSQL Query:\n{sql_solution.sql_code}",
        )
    ]
    iterations += 1

    # Log the generated SQL query
    _logger.info("Generated SQL query:\n%s", sql_solution.sql_code)

    # Update the state with the generated SQL query and updated message list
    state["generation"] = sql_solution
    state["messages"] = messages
    state["iterations"] = iterations

    return state
6. Post-safety Check

The post_safety_check node ensures the generated SQL query is safe by performing a final validation for harmful SQL commands. While the earlier pre-safety check catches disallowed operations in the user input, this post-safety check verifies that the SQL query produced during generation also complies with the safety guidelines. This two-step approach ensures that even if disallowed operations are inadvertently introduced during query generation, they are caught and flagged. If an unsafe query is detected, the node halts the workflow, updates the state with an error flag, and provides feedback to the user.

  • Examples
  1. Disallowed operations
    • Generated query: "DROP TABLE orders;"
    • Response: "The generated SQL query contains disallowed SQL operations: DROP and cannot be processed."
    • Generated query: "SELECT name FROM customers;"
    • Response: "Query valid."
  • Code
def post_safety_check(state: GraphState) -> GraphState:
    """
    Perform safety checks on the generated SQL query to ensure that it doesn't contain disallowed operations
    such as CREATE, DELETE, DROP, etc. This node checks the SQL query generated earlier in the workflow.

    Args:
        state (GraphState): The current graph state containing the generated SQL query.

    Returns:
        GraphState: The updated state with error status and messages if any issues are found.
    """
    _logger.info("Performing post-safety check on the generated SQL query.")

    # Retrieve the generated SQL query from the state
    # (state["generation"] holds a SQLQuery object, so we access its attributes directly)
    sql_solution = state["generation"]
    sql_query = sql_solution.sql_code.strip()
    messages = state["messages"]
    error = "no"

    # List of disallowed SQL operations
    disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
    pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)

    # Check if the generated SQL query contains disallowed SQL operations
    found_operations = pattern.findall(sql_query)
    if found_operations:
        _logger.warning(
            "Generated SQL query contains disallowed SQL operations: %s. Halting the workflow.",
            ", ".join(set(found_operations))
        )
        error = "yes"
        messages += [("assistant", f"The generated SQL query contains disallowed SQL operations: {', '.join(set(found_operations))} and cannot be processed.")]
    else:
        _logger.info("Generated SQL query passed the safety check.")

    # Update state with error status and messages
    state["error"] = error
    state["messages"] = messages

    return state
7. SQL Check

The sql_check node ensures the generated SQL query is safe and syntactically valid by executing it within a transactional savepoint. Any changes are rolled back after validation, and errors are flagged with detailed feedback to preserve query integrity.

  • Examples
    • Input SQL: "SELECT CustomerName FROM Customers WHERE City = 'New York';"
      Validation: the query is valid.
    • Input SQL: "SELECT MONTH(date) AS month, SUM(total) AS total_sales FROM orders GROUP BY MONTH(date);"
      Response: "Your SQL query failed to execute: no such function: MONTH."
  • Code
def sql_check(state: GraphState) -> GraphState:
    """
    Validates the generated SQL query by attempting to execute it on the database.
    If the query is valid, the changes are rolled back to ensure no data is modified.
    If there is an error during execution, the error is logged and the state is updated accordingly.

    Args:
        state (GraphState): The current graph state, which contains the generated SQL query
            and the messages to communicate with the user.

    Returns:
        GraphState: The updated state with error status and messages if the query is invalid.
    """
    _logger.info("Validating SQL query.")

    # Extract relevant data from the state
    messages = state["messages"]
    sql_solution = state["generation"]
    error = "no"

    sql_code = sql_solution.sql_code.strip()

    try:
        # Start a savepoint for the transaction to allow rollback
        conn.execute('SAVEPOINT sql_check;')
        # Attempt to execute the SQL query
        cursor.execute(sql_code)
        # Roll back to the savepoint to undo any changes
        conn.execute('ROLLBACK TO sql_check;')
        _logger.info("SQL query validation: success.")
    except Exception as e:
        # Roll back in case of error
        conn.execute('ROLLBACK TO sql_check;')
        _logger.error("SQL query validation failed. Error: %s", e)
        messages += [("user", f"Your SQL query failed to execute: {e}")]
        error = "yes"

    # Update the state with the error status
    state["error"] = error
    state["messages"] = messages

    return state
8. Run Query

The run_query node executes the validated SQL query, connecting to the database to retrieve the results. It updates the state with the query output, ensuring the data is ready for further analysis or reporting, while implementing robust error handling.

  • Examples
    • Input SQL: "SELECT COUNT(*) FROM Customers WHERE City = 'New York';"
      Query result: "(0,)"
    • Input SQL: "SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;"
      Query result: "(6925.0,)"
  • Code
def run_query(state: GraphState) -> GraphState:
    """
    Executes the generated SQL query on the database and retrieves the results if it is a SELECT query.
    For non-SELECT queries, commits the changes to the database. If no records are found for a SELECT query,
    the `no_records_found` flag is set to True.

    Args:
        state (GraphState): The current graph state, which contains the generated SQL query and other relevant data.

    Returns:
        GraphState: The updated state with the query results, or a flag indicating if no records were found.
    """
    _logger.info("Running SQL query.")

    # Extract the SQL query from the state
    sql_solution = state["generation"]
    sql_code = sql_solution.sql_code.strip()
    results = None
    no_records_found = False  # Flag to indicate no records found

    try:
        # Execute the SQL query
        cursor.execute(sql_code)

        # For SELECT queries, fetch and store results
        if sql_code.upper().startswith("SELECT"):
            results = cursor.fetchall()
            if not results:
                no_records_found = True
                _logger.info("SQL query execution: success. No records found.")
            else:
                _logger.info("SQL query execution: success.")
        else:
            # For non-SELECT queries, commit the changes
            conn.commit()
            _logger.info("SQL query execution: success. Changes committed.")
    except Exception as e:
        _logger.error("SQL query execution failed. Error: %s", e)

    # Update the state with results and flag for no records found
    state["results"] = results
    state["no_records_found"] = no_records_found

    return state
Decision Step: Determining the Next Action

The decide_next_step function acts as a control point in the workflow, deciding what action to take next based on the current state. It evaluates the error status and the number of iterations performed so far to determine whether the query should be run, the workflow should finish, or the system should retry SQL query generation.

  • Process

    • If there is no error (error == "no"), the system proceeds to run the SQL query.
    • If the maximum number of iterations (max_iterations) has been reached, the workflow ends.
    • If an error occurred and the maximum number of iterations has not been reached, the system retries query generation.
  • Example workflow decisions

    • No error, proceed with query execution: if no errors were found in the previous steps, the workflow continues to query execution.
    • Maximum iterations reached, end the workflow: if the workflow has already made the configured number of attempts (max_iterations), it terminates.
    • Error detected, retry SQL generation: if an error occurred and the retry limit has not been reached, the system attempts to regenerate the SQL query.
  • Code

def decide_next_step(state: GraphState) -> str:
    """
    Determines the next step in the workflow based on the current state, including whether the query
    should be run, the workflow should be finished, or if the query generation needs to be retried.

    Args:
        state (GraphState): The current graph state, which contains error status and iteration count.

    Returns:
        str: The next step in the workflow, which can be "run_query", "generate", or END.
    """
    _logger.info("Deciding next step based on current state.")

    error = state["error"]
    iterations = state["iterations"]

    if error == "no":
        _logger.info("Error status: no. Proceeding with running the query.")
        return "run_query"
    elif iterations >= max_iterations:
        _logger.info("Maximum iterations reached. Ending the workflow.")
        return END
    else:
        _logger.info("Error detected. Retrying SQL query generation.")
        return "generate"

Workflow Orchestration and Conditional Logic

To define how tasks are orchestrated in the system, we build a workflow graph using the StateGraph class. Each task is represented as a node, and transitions between tasks are defined as edges. Conditional edges are used to control the flow based on the workflow's state.

def get_workflow(conn, cursor, vector_store):
    """Define and compile the LangGraph workflow."""

    # Max iterations: defines how many times the workflow should retry in case of errors
    max_iterations = 3

    # SQL generation chain: this is a chain that will generate SQL based on retrieved docs
    sql_gen_chain = get_sql_gen_chain()

    # Initialize OpenAI LLM for translation and safety checks
    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")

    # Define the individual nodes of the workflow
    ...  # Insert nodes code defined above here

    # Build the workflow graph
    workflow = StateGraph(GraphState)

    # Define workflow nodes
    workflow.add_node("translate_input", translate_input)  # Translate user input to structured format
    workflow.add_node("pre_safety_check", pre_safety_check)  # Perform a pre-safety check on input
    workflow.add_node("schema_extract", schema_extract)  # Extract the database schema
    workflow.add_node("context_check", context_check)  # Validate input relevance to context
    workflow.add_node("generate", generate)  # Generate SQL query
    workflow.add_node("post_safety_check", post_safety_check)  # Perform a post-safety check on generated SQL query
    workflow.add_node("sql_check", sql_check)  # Validate the generated SQL query
    workflow.add_node("run_query", run_query)  # Execute the SQL query

    # Define workflow edges
    workflow.add_edge(START, "translate_input")  # Start at the translation step
    workflow.add_edge("translate_input", "pre_safety_check")  # Move to safety checks

    # Conditional edge after safety check
    workflow.add_conditional_edges(
        "pre_safety_check",  # Start at the pre_safety_check node
        lambda state: "schema_extract" if state["error"] == "no" else END,  # Decide next step
        {"schema_extract": "schema_extract", END: END},  # Map states to nodes
    )

    workflow.add_edge("schema_extract", "context_check")  # Proceed to context validation

    # Conditional edge after context check
    workflow.add_conditional_edges(
        "context_check",  # Start at the context_check node
        lambda state: "generate" if state["error"] == "no" else END,  # Decide next step
        {"generate": "generate", END: END},
    )

    workflow.add_edge("generate", "post_safety_check")  # Proceed to post-safety check

    # Conditional edge after post-safety check
    workflow.add_conditional_edges(
        "post_safety_check",  # Start at the post_safety_check node
        lambda state: "sql_check" if state["error"] == "no" else END,  # If no error, proceed to sql_check, else END
        {"sql_check": "sql_check", END: END},
    )

    # Conditional edge after SQL validation
    workflow.add_conditional_edges(
        "sql_check",  # Start at the sql_check node
        decide_next_step,  # Function to determine the next step
        {
            "run_query": "run_query",  # If SQL is valid, execute the query
            "generate": "generate",  # If retry is needed, go back to generation
            END: END,  # Otherwise, terminate the workflow
        },
    )

    workflow.add_edge("run_query", END)  # Final step is to end the workflow

    # Compile and return the workflow application
    app = workflow.compile()

    return app
  1. From START to translate_input:
    The workflow begins by translating the user input into English.
  2. translate_input to pre_safety_check:
    After translation, the workflow moves on to checking the input's safety.
  3. Conditional rules at pre_safety_check:
    • If the input passes the pre-safety check (state["error"] == "no"), the workflow proceeds to schema_extract.
    • If the input fails the pre-safety check, the workflow terminates (END).
  4. schema_extract to context_check:
    After extracting the schema, the workflow validates the input's relevance to the database context.
  5. Conditional rules at context_check:
    • If the input is relevant (state["error"] == "no"), the workflow proceeds to generate.
    • If it is irrelevant, the workflow terminates (END).
  6. generate to post_safety_check:
    The workflow generates the SQL query and sends it for validation.
  7. Conditional rules at post_safety_check:
    • If the generated query passes the post-safety check (state["error"] == "no"), the workflow proceeds to sql_check.
    • If it fails the post-safety check, the workflow terminates (END).
  8. Conditional rules at sql_check:
    • If the query is valid, the workflow proceeds to run_query.
    • If the query needs adjustments and fewer than 3 attempts have been made, the workflow loops back to generate.
    • If validation still fails once the maximum of 3 attempts is reached, the workflow terminates (END).
  9. run_query to END:
    After the query is executed, the workflow ends.

The diagram below provides a representation of the LangGraph nodes and edges.

[Figure: graph representation of our agent]

Logging the Model in MLflow

Now that we have built the Multilingual Query Engine with LangGraph, we can log the model using MLflow's Models from Code feature. Logging the model in MLflow allows us to treat the Multilingual Query Engine like a traditional ML model, enabling seamless tracking, versioning, and packaging for deployment across various serving infrastructures. MLflow's Models from Code strategy, where we log the code that represents the model, contrasts with MLflow's object-based logging, where a model object is created and serialized as a pickle or JSON object for logging.

Step 1: Create our Model-from-Code file

So far, we have defined the get_workflow function in a file named workflow.py. In this step, we create a new file, sql_model.py, which introduces the SQLGenerator class. This script:

  1. Imports the get_workflow function from workflow.py.
  2. Defines the SQLGenerator class as a PythonModel, including a predict method that initializes the LangGraph workflow using the get_workflow function.
  3. Uses mlflow.models.set_model to designate the SQLGenerator PythonModel class as the model of interest to MLflow.
import mlflow
from definitions import REMOTE_SERVER_URI
from workflow import get_workflow

mlflow.set_tracking_uri(REMOTE_SERVER_URI)


class SQLGenerator(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input):
        return get_workflow(
            model_input["conn"], model_input["cursor"], model_input["vector_store"]
        )


mlflow.models.set_model(SQLGenerator())
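
The definitions module imported above is not shown in the post; a minimal sketch of what it might contain (the values below are illustrative placeholders, to be adjusted to your environment) would be:

# definitions.py -- illustrative placeholder values
REMOTE_SERVER_URI = "http://127.0.0.1:5000"  # URI of the MLflow Tracking Server
EXPERIMENT_NAME = "sql-generator"  # MLflow experiment to log runs under
REGISTERED_MODEL_NAME = "sql_generator"  # name in the MLflow Model Registry
MODEL_ALIAS = "champion"  # alias used to load the current model version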

Step 2: Log the model with MLflow's Models from Code feature

After defining the SQLGenerator custom Python model, the next step is to log it with MLflow's Models from Code feature. This involves using the standard log_model API, specifying the path to the sql_model.py script, and including workflow.py as a dependency via the code_paths parameter. This approach ensures that all required code files are available when the model is loaded in a different environment or on another machine.

import mlflow
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from mlflow import MlflowClient

client = MlflowClient(tracking_uri=REMOTE_SERVER_URI)

mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

with mlflow.start_run():
    logged_model_info = mlflow.pyfunc.log_model(
        python_model="sql_model.py",
        artifact_path="sql_generator",
        registered_model_name=REGISTERED_MODEL_NAME,
        code_paths=["workflow.py"],
    )

client.set_registered_model_alias(
    REGISTERED_MODEL_NAME, MODEL_ALIAS, logged_model_info.registered_model_version
)

In the MLflow UI, the stored model includes the sql_model.py and workflow.py scripts as artifacts within the run. The Models-from-Code feature not only logs the model's parameters and metrics but also captures the code that defines its functionality, ensuring observability, seamless tracking, and straightforward debugging directly from the UI. However, it is important to never hardcode sensitive elements such as API keys or credentials into these scripts. Since the code is stored as-is, any sensitive information it contains could lead to token leaks and pose security risks. Instead, use environment variables, secret management systems, or other secure methods to handle sensitive data safely.

[Figure: the model-from-code scripts stored as run artifacts in the MLflow UI]

main.py中使用记录的多语言查询引擎

记录模型后,可以使用模型URI和标准的mlflow.pyfunc.load_model API从MLflow Tracking Server中将其加载回来。加载模型时,将执行workflow.py脚本以及sql_model.py脚本,确保在调用predict方法时get_workflow函数可用。

通过执行以下代码,我们演示了我们的多语言查询引擎能够执行自然语言到SQL的生成和查询执行。

import os
import logging

import mlflow
from database import setup_database
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from dotenv import load_dotenv
from vector_store import setup_vector_store

mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.langchain.autolog()

# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)


def main():
    # Load environment variables from .env file
    load_dotenv()

    # Access secrets using os.getenv
    os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

    # Setup database and vector store (both expect a logger instance)
    conn = setup_database(_logger)
    cursor = conn.cursor()
    vector_store = setup_vector_store(_logger)

    # Load the model
    model_uri = f"models:/{REGISTERED_MODEL_NAME}@{MODEL_ALIAS}"
    model = mlflow.pyfunc.load_model(model_uri)
    model_input = {"conn": conn, "cursor": cursor, "vector_store": vector_store}
    app = model.predict(model_input)

    # Save a visualization of the compiled graph
    app.get_graph().draw_mermaid_png(
        output_file_path="sql_agent_with_safety_checks.png"
    )

    # Example user interaction
    _logger.info("Welcome to the SQL Assistant!")
    while True:
        question = input("\nEnter your SQL question (or type 'exit' to quit): ")
        if question.lower() == "exit":
            break

        # Initialize the state with all required keys
        initial_state = {
            "messages": [("user", question)],
            "iterations": 0,
            "error": "",
            "results": None,
            "generation": None,
            "no_records_found": False,
            "translated_input": "",  # Initialize translated_input
        }

        solution = app.invoke(initial_state)

        # Check if an error was set during the safety check
        if solution["error"] == "yes":
            _logger.info("\nAssistant Message:\n")
            _logger.info(solution["messages"][-1][1])  # Display the assistant's message
            continue  # Skip to the next iteration

        # Extract the generated SQL query from solution["generation"]
        sql_query = solution["generation"].sql_code
        _logger.info("\nGenerated SQL Query:\n")
        _logger.info(sql_query)

        # Extract and display the query results
        if solution.get("no_records_found"):
            _logger.info("\nNo records found matching your query.")
        elif "results" in solution and solution["results"] is not None:
            _logger.info("\nQuery Results:\n")
            for row in solution["results"]:
                _logger.info(row)
        else:
            _logger.info("\nNo results returned or query did not execute successfully.")

    _logger.info("Goodbye!")


if __name__ == "__main__":
    main()

Project File Structure

The project follows a simple file structure:

[Figure: project file structure]

MLflow Tracing

MLflow Automated Tracing provides fully automated integration with various GenAI libraries such as LangChain, OpenAI, LlamaIndex, DSPy, and AutoGen. Since our AI workflow is built with LangGraph, we can activate automatic LangChain tracing by calling mlflow.langchain.autolog().

With LangChain autologging enabled, traces are automatically logged to the active MLflow experiment whenever a chain's invocation API is called. This seamless integration ensures that every interaction is captured for monitoring and analysis.
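
As seen in main.py above, enabling this is a single call made before the workflow is invoked; a minimal standalone sketch (with an illustrative tracking URI and experiment name):

import mlflow

mlflow.set_tracking_uri("http://127.0.0.1:5000")  # illustrative local server
mlflow.set_experiment("sql-generator")  # illustrative experiment name
mlflow.langchain.autolog()  # from here on, LangChain/LangGraph calls are traced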

Viewing Traces in MLflow

Traces can be accessed by navigating to the MLflow experiment of interest and clicking the "Tracing" tab. Once inside, selecting a specific trace provides detailed execution information.

Each trace includes:

  1. Execution graph: a visualization of the workflow's steps.
  2. Inputs and outputs: detailed logs of the data processed at each step.

By leveraging MLflow Tracing, we gain fine-grained visibility into the entire graph execution. AI workflow graphs can often feel like a black box, making it hard to debug and understand what happens at each step. With tracing enabled by a single line of code, MLflow provides clear, detailed insight into the workflow, allowing developers to debug, monitor, and optimize every node of the graph and keeping our Multilingual Query Engine transparent, auditable, and scalable.
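
Newer MLflow releases also expose traces programmatically; a hedged sketch (API names per recent MLflow versions, experiment name illustrative):

import mlflow

exp = mlflow.get_experiment_by_name("sql-generator")
traces = mlflow.search_traces(experiment_ids=[exp.experiment_id], max_results=5)
print(traces.head())  # one row per trace, including request/response previews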

[Animation: exploring a trace in the MLflow UI]

Conclusion

In this blog post, we explored the process of building and managing a Multilingual Query Engine using LangGraph and MLflow. By integrating LangGraph's dynamic AI workflows with MLflow's robust lifecycle management and tracing capabilities, we created a system that not only delivers accurate, efficient natural-language-to-SQL generation and execution but is also transparent and scalable.