跳到主要内容

从自然语言到 SQL:构建和追踪多语言查询引擎

·阅读 40 分钟
Hugo Carvalho
adidas 机器学习分析师
Joana Ferreira
adidas 机器学习工程师
Rahul Pandey
Adidas 高级解决方案架构师

如果您希望构建一个多语言查询引擎,将自然语言转换为 SQL 生成并执行查询,同时充分利用 MLflow 的功能,那么这篇博文将是您的指南。我们将探讨如何利用代码中的 MLflow 模型来实现 AI 工作流程的无缝追踪和版本控制。此外,我们将深入研究 MLflow 的追踪功能,该功能通过追踪每个中间步骤的输入、输出和元数据,将可观测性引入到 AI 工作流程的许多不同组件中。

简介

SQL 是管理和访问关系数据库中的数据的基本技能。然而,构建复杂的 SQL 查询来回答复杂的数据问题可能具有挑战性且耗时。这种复杂性可能会使充分有效地利用数据变得困难。自然语言到 SQL (NL2SQL) 系统通过提供从自然语言到 SQL 命令的转换来帮助解决这个问题,从而允许非技术人员与数据交互:用户只需用他们舒适的自然语言提问,这些系统将协助他们返回适当的信息。

然而,在创建 NL2SQL 系统时,仍然存在许多问题,例如语义歧义、模式映射或错误处理和用户反馈。因此,在构建此类系统时,我们必须设置一些保护措施,而不是完全依赖 LLM。

在这篇博文中,我们将引导您完成构建多语言查询引擎的过程。该引擎支持多种语言的自然语言输入,根据翻译后的用户输入生成 SQL 查询,并执行查询。让我们来看一个例子:使用包含公司客户、产品和订单信息的数据库,用户可以用任何语言提问,例如“Quantos clientes temos por país?”(葡萄牙语,意为“每个国家/地区有多少客户?”)。AI 工作流程将输入翻译成英语,输出“每个国家/地区有多少客户?”。然后,它验证输入的安全性,检查问题是否可以使用数据库模式回答,生成适当的 SQL 查询(例如,SELECT COUNT(CustomerID) AS NumberOfCustomers, Country FROM Customers GROUP BY Country;),并验证查询以确保不存在有害命令(例如,DROP)。最后,它针对数据库执行查询以检索结果。

我们将首先演示如何利用 LangGraph 的功能来构建动态 AI 工作流程。此工作流程集成了 OpenAI 和外部数据源(例如向量存储和 SQLite 数据库),以处理用户输入、执行安全检查、查询数据库并生成有意义的响应。

在本文中,我们将利用 MLflow 的代码中的模型功能来实现 AI 工作流程的无缝追踪和版本控制。此外,我们将深入研究 MLflow 的追踪功能,该功能旨在通过追踪与每个中间步骤关联的输入、输出和元数据来增强 AI 工作流程的许多不同组件的可观测性。这使得可以轻松识别错误和意外行为,从而提高工作流程的透明度。

先决条件

要设置和运行此项目,请确保安装了以下 Python 包

  • faiss-cpu
  • langchain
  • langchain-core
  • langchain-openai
  • langgraph
  • langchain-community
  • pydantic >=2
  • typing_extensions
  • python-dotenv

此外,需要一个 MLflow 追踪服务器来有效地记录和管理实验、模型和追踪。对于本地设置,请参阅官方 MLflow 文档以获取有关配置简单 MLflow 追踪服务器的说明。

最后,请确保您的 OpenAI API 密钥保存在项目目录中的 .env 文件中。这允许应用程序安全地访问构建 AI 工作流程所需的 OpenAI 服务。.env 文件应包含如下行

OPENAI_API_KEY=your_openai_api_key

使用 LangGraph 的多语言查询引擎

多语言查询引擎利用 LangGraph 库,这是一种 AI 编排工具,旨在为由 LLM 提供支持的应用程序创建有状态的、多代理的和循环的图架构。

与其他 AI 编排器相比,LangGraph 提供了三个核心优势:循环、可控性和持久性。它允许定义带有循环的 AI 工作流程,这对于实现重试机制(例如多语言查询引擎中的 SQL 查询生成重试(如果验证失败,查询将循环返回以重新生成))至关重要。这使得 LangGraph 成为构建我们的多语言查询引擎的理想工具。

LangGraph 的主要功能:

  1. 有状态架构:引擎维护图执行状态的动态快照。此快照充当节点之间的共享资源,从而可以在每次节点执行时进行高效的决策和实时更新。

  2. 多代理设计:AI 工作流程包括在整个工作流程中与 OpenAI 和其他外部工具的多次交互。

  3. 循环图结构:图的循环性质引入了强大的重试机制。此机制通过在需要时循环返回到先前阶段来动态解决故障,从而确保连续的图执行。(此机制的详细信息将在稍后讨论。)

AI 工作流程概述

多语言查询引擎的高级 AI 工作流程由互连的节点和边组成,每个节点和边都代表一个关键阶段

  1. 翻译节点:将用户的输入转换为英语。

  2. 预安全检查:确保用户输入不包含有害或不适当的内容,并且不包含有害的 SQL 命令(例如,DELETEDROP)。

  3. 数据库模式提取:检索目标数据库的模式以了解其结构和可用数据。

  4. 相关性验证:针对数据库模式验证用户的输入,以确保与数据库的上下文对齐。

  5. SQL 查询生成:基于用户的输入和当前数据库模式生成 SQL 查询。

  6. 后安全检查:确保生成的 SQL 查询不包含有害的 SQL 命令(例如,DELETEDROP)。

  7. SQL 查询验证:在回滚安全环境中执行 SQL 查询,以确保其在运行前的有效性。

  8. 动态状态评估:根据当前状态确定下一步。如果 SQL 查询验证失败,它将循环返回到第 5 阶段以重新生成查询。

  9. 查询执行和结果检索:执行 SQL 查询并返回结果(如果是 SELECT 语句)。

重试机制在第 8 阶段引入,系统在其中动态评估当前图状态。具体来说,当 SQL 查询验证节点(第 7 阶段)检测到问题时,状态会触发循环返回到 SQL 生成节点(第 5 阶段)以进行新的 SQL 生成尝试(最多 3 次尝试)。

组件

多语言查询引擎与多个外部组件交互,以将自然语言用户输入转换为 SQL 查询,并以安全可靠的方式执行它们。在本节中,我们将详细了解关键的 AI 工作流程组件:OpenAI、向量存储、SQLite 数据库和 SQL 生成链。

OpenAI

OpenAI,更具体地说,gpt-4o-mini 语言模型,在工作流程的多个阶段中起着至关重要的作用。它为以下方面提供了所需的智能

  1. 翻译:将用户输入翻译成英语。如果文本已经是英语,它只会重复输入。

  2. 安全检查:分析用户输入以确保其不包含有害或不适当的内容。

  3. 相关性检查:评估用户的提问是否与数据库模式相关。

  4. SQL 生成:基于用户输入、SQL 生成文档和数据库模式生成有效且可执行的 SQL 查询。

有关 OpenAI 实现的详细信息将在后面的节点说明部分中提供。

FAISS 向量存储

为了构建能够生成准确且可执行的 SQL 查询的有效的自然语言到 SQL 引擎,我们利用了 Langchain 的 FAISS 向量存储功能。此设置允许系统从先前存储在向量数据库中的 W3Schools SQL 文档中搜索和提取 SQL 查询生成指南,从而提高 SQL 查询生成的成功率。

出于演示目的,我们使用 FAISS,这是一种内存向量存储,向量直接存储在 RAM 中。这提供了快速访问,但也意味着数据不会在运行之间持久保存。为了获得更具可扩展性的解决方案,该解决方案能够跨多个项目存储和共享嵌入,我们建议使用 AWS OpenSearchVertex AI 向量搜索Azure 向量搜索Mosaic AI 向量搜索 等替代方案。这些基于云的解决方案提供持久存储、自动扩展以及与其他云服务的无缝集成,使其非常适合大规模应用程序。

步骤 1:加载 SQL 文档

创建带有 SQL 查询生成指南的 FAISS 向量存储的第一步是使用 LangChain 的 RecursiveUrlLoaderW3Schools SQL 页面加载 SQL 文档。此工具检索文档,允许我们将其用作我们引擎的知识库。

步骤 2:将文本拆分为可管理的块

加载的 SQL 文档是一段很长的文本,因此很难被 LLM 有效地提取。为了解决这个问题,下一步涉及使用 Langchain 的 RecursiveCharacterTextSplitter 将文本拆分为更小、更可管理的块。通过将文本拆分为 500 个字符的块,并具有 50 个字符的重叠,我们确保语言模型具有足够的上下文,同时最大限度地减少丢失跨块的重要信息的风险。split_text 方法应用此拆分过程,将结果片段存储在名为“documents”的列表中。

步骤 3:生成嵌入模型

第三步是创建一个模型,该模型将这些块转换为嵌入(每个文本块的向量化数值表示)。嵌入使系统能够比较块和用户输入之间的相似性,从而有助于检索 SQL 查询生成的最相关匹配项。

步骤 4:在 FAISS 向量存储中创建和存储嵌入

最后,我们使用 FAISS 创建和存储嵌入。FAISS.from_texts 方法获取所有块,计算它们的嵌入并将它们存储在高速可搜索向量数据库中。此可搜索数据库允许引擎有效地检索相关的 SQL 指南,从而显著提高可执行 SQL 查询生成的成功率。

import logging
import os

from bs4 import BeautifulSoup as Soup

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings


def setup_vector_store(logger: logging.Logger):
"""Setup or load the vector store."""
if not os.path.exists("data"):
os.makedirs("data")

vector_store_dir = "data/vector_store"

if os.path.exists(vector_store_dir):
# Load the vector store from disk
logger.info("Loading vector store from disk...")
vector_store = FAISS.load_local(
vector_store_dir,
OpenAIEmbeddings(),
allow_dangerous_deserialization=True,
)
else:
logger.info("Creating new vector store...")
# Load SQL documentation
url = "https://w3schools.org.cn/sql/"
loader = RecursiveUrlLoader(
url=url, max_depth=2, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()

# Split documents into chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
)

documents = []
for doc in docs:
splits = text_splitter.split_text(doc.page_content)
for i, split in enumerate(splits):
documents.append(
{
"content": split,
"metadata": {"source": doc.metadata["source"], "chunk": i},
}
)

# Compute embeddings and create vector store
embedding_model = OpenAIEmbeddings()
vector_store = FAISS.from_texts(
[doc["content"] for doc in documents],
embedding_model,
metadatas=[doc["metadata"] for doc in documents],
)

# Save the vector store to disk
vector_store.save_local(vector_store_dir)
logger.info("Vector store created and saved to disk.")

return vector_store

SQLite 数据库

SQLite 数据库是多语言查询引擎的关键组件,充当结构化数据存储库。SQLite 提供了一个轻量级、快速且独立的关联数据库引擎,无需服务器设置或安装。它的小巧尺寸(小于 500KB)和零配置特性使其非常易于使用,而其平台无关的数据库格式确保了跨不同系统的无缝可移植性。作为本地磁盘数据库,SQLite 是避免设置 MySQL 或 PostgreSQL 的复杂性的理想选择,同时仍然提供可靠、功能齐全且性能出色的 SQL 引擎。

SQLite 数据库通过以下方式支持高效的 SQL 查询生成、验证和执行

  1. 模式提取:为用户的输入上下文验证(第 4 阶段)和可执行 SQL 查询生成(第 5 阶段)提供模式信息。

  2. 查询执行:在验证阶段(第 7 阶段)和查询执行阶段(第 9 阶段)的回滚安全环境中执行 SQL 查询,获取 SELECT 语句的结果并提交其他查询类型的更改。

SQLite 数据库初始化

当 AI 工作流程初始化时,数据库使用 setup_database 函数进行初始化。此过程涉及

  1. 设置 SQLite 数据库连接:建立与 SQLite 数据库的连接,从而实现数据交互。

  2. 表创建:定义并创建 AI 工作流程所需的数据库表。

  3. 数据填充:使用示例数据填充表以支持查询执行和验证阶段。

import logging
import os

import sqlite3


def create_connection(db_file="data/database.db"):
"""Create a database connection to the SQLite database."""
conn = sqlite3.connect(db_file)
return conn


def create_tables(conn):
"""Create tables in the database."""
cursor = conn.cursor()
# Create Customers table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Customers (
CustomerID INTEGER PRIMARY KEY,
CustomerName TEXT,
ContactName TEXT,
Address TEXT,
City TEXT,
PostalCode TEXT,
Country TEXT
)
"""
)

# Create Orders table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Orders (
OrderID INTEGER PRIMARY KEY,
CustomerID INTEGER,
OrderDate TEXT,
FOREIGN KEY (CustomerID) REFERENCES Customers (CustomerID)
)
"""
)

# Create OrderDetails table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS OrderDetails (
OrderDetailID INTEGER PRIMARY KEY,
OrderID INTEGER,
ProductID INTEGER,
Quantity INTEGER,
FOREIGN KEY (OrderID) REFERENCES Orders (OrderID),
FOREIGN KEY (ProductID) REFERENCES Products (ProductID)
)
"""
)

# Create Products table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS Products (
ProductID INTEGER PRIMARY KEY,
ProductName TEXT,
Price REAL
)
"""
)

conn.commit()


def populate_tables(conn):
"""Populate tables with sample data if they are empty."""
cursor = conn.cursor()

# Populate Customers table if empty
cursor.execute("SELECT COUNT(*) FROM Customers")
if cursor.fetchone()[0] == 0:
customers = []
for i in range(1, 51):
customers.append(
(
i,
f"Customer {i}",
f"Contact {i}",
f"Address {i}",
f"City {i % 10}",
f"{10000 + i}",
f"Country {i % 5}",
)
)
cursor.executemany(
"""
INSERT INTO Customers (CustomerID, CustomerName, ContactName, Address, City, PostalCode, Country)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
customers,
)

# Populate Products table if empty
cursor.execute("SELECT COUNT(*) FROM Products")
if cursor.fetchone()[0] == 0:
products = []
for i in range(1, 51):
products.append((i, f"Product {i}", round(10 + i * 0.5, 2)))
cursor.executemany(
"""
INSERT INTO Products (ProductID, ProductName, Price)
VALUES (?, ?, ?)
""",
products,
)

# Populate Orders table if empty
cursor.execute("SELECT COUNT(*) FROM Orders")
if cursor.fetchone()[0] == 0:
orders = []
from datetime import datetime, timedelta

base_date = datetime(2023, 1, 1)
for i in range(1, 51):
order_date = base_date + timedelta(days=i)
orders.append(
(
i,
i % 50 + 1, # CustomerID between 1 and 50
order_date.strftime("%Y-%m-%d"),
)
)
cursor.executemany(
"""
INSERT INTO Orders (OrderID, CustomerID, OrderDate)
VALUES (?, ?, ?)
""",
orders,
)

# Populate OrderDetails table if empty
cursor.execute("SELECT COUNT(*) FROM OrderDetails")
if cursor.fetchone()[0] == 0:
order_details = []
for i in range(1, 51):
order_details.append(
(
i,
i % 50 + 1, # OrderID between 1 and 50
i % 50 + 1, # ProductID between 1 and 50
(i % 5 + 1) * 2, # Quantity between 2 and 10
)
)
cursor.executemany(
"""
INSERT INTO OrderDetails (OrderDetailID, OrderID, ProductID, Quantity)
VALUES (?, ?, ?, ?)
""",
order_details,
)

conn.commit()


def setup_database(logger: logging.Logger):
"""Setup the database and return the connection."""
db_file = "data/database.db"
if not os.path.exists("data"):
os.makedirs("data")

db_exists = os.path.exists(db_file)

conn = create_connection(db_file)

if not db_exists:
logger.info("Setting up the database...")
create_tables(conn)
populate_tables(conn)
else:
logger.info("Database already exists. Skipping setup.")

return conn

SQL 生成链

SQL 生成链 (sql_gen_chain) 是我们工作流程中自动 SQL 查询生成的支柱。此链利用 LangChain 的模块化功能和 OpenAI 的高级自然语言处理将用户问题转换为精确且可执行的 SQL 查询。

核心功能:

  • 提示驱动生成:从一个经过深思熟虑的提示开始,该提示集成了数据库模式和文档片段,确保查询在上下文中是准确的。

  • 结构化响应:以预定义的格式传递输出,包括

    • 查询目的的描述

    • 准备执行的相应 SQL 代码

  • 适应性和可靠性:使用 gpt-4o-mini 进行稳健、一致的查询生成,从而最大限度地减少手动工作和错误。

此链是我们工作流程中的关键组件,可以实现 SQL 查询生成与下游流程的无缝集成,确保准确性并显著提高效率。

from pydantic import BaseModel, Field

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

class SQLQuery(BaseModel):
"""Schema for SQL query solutions to questions."""
description: str = Field(description="Description of the SQL query")
sql_code: str = Field(description="The SQL code block")

def get_sql_gen_chain():
"""Set up the SQL generation chain."""
sql_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL assistant with expertise in SQL query generation. \n
Answer the user's question based on the provided documentation snippets and the database schema provided below. Ensure any SQL query you provide is valid and executable. \n
Structure your answer with a description of the query, followed by the SQL code block. Here are the documentation snippets:\n{retrieved_docs}\n\nDatabase Schema:\n{database_schema}""",
),
("placeholder", "{messages}"),
]
)

# Initialize the OpenAI LLM
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")

# Create the code generation chain
sql_gen_chain = sql_gen_prompt | llm.with_structured_output(SQLQuery)

return sql_gen_chain

工作流程设置和初始化

在深入研究工作流程节点之前,至关重要的是设置必要的组件并定义工作流程的结构。本节将解释基本库、日志记录和自定义 GraphState 类的初始化,以及主工作流程编译函数。

定义 GraphState

GraphState 类是一个自定义 TypedDict,用于在工作流程进行时维护状态信息。它充当节点之间的共享数据结构,从而确保连续性和一致性。关键字段包括

  • error:跟踪是否发生错误。
  • messages:存储用户和系统消息的列表。
  • generation:保存生成的 SQL 查询。
  • iterations:跟踪出现错误时的重试次数。
  • results:存储 SQL 执行结果(如果有)。
  • no_records_found:标记查询是否未返回任何记录。
  • translated_input:包含用户翻译后的输入。
  • database_schema:维护数据库模式以进行上下文验证。
import logging
import re
from typing import List, Optional

from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from sql_generation import get_sql_gen_chain
from typing_extensions import TypedDict

# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)

class GraphState(TypedDict):
error: str # Tracks if an error has occurred
messages: List # List of messages (user input and assistant messages)
generation: Optional[str] # Holds the generated SQL query
iterations: int # Keeps track of how many times the workflow has retried
results: Optional[List] # Holds the results of SQL execution
no_records_found: bool # Flag for whether any records were found in the SQL result
translated_input: str # Holds the translated user input
database_schema: str # Holds the extracted database schema for context checking
工作流程编译函数

主函数 get_workflow 负责定义和编译工作流程。关键组件包括

  • conncursor:用于数据库连接和查询执行。
  • vector_store:用于上下文检索的向量数据库。
  • max_iterations:设置重试次数限制以防止无限循环。
  • sql_gen_chain:从 sql_generation 检索 SQL 生成链,以基于上下文输入生成 SQL 查询。
  • ChatOpenAI:初始化 OpenAI gpt-4o-mini 模型以执行安全检查和查询翻译等任务。
def get_workflow(conn, cursor, vector_store):
"""Define and compile the LangGraph workflow."""

# Max iterations: defines how many times the workflow should retry in case of errors
max_iterations = 3

# SQL generation chain: this is a chain that will generate SQL based on retrieved docs
sql_gen_chain = get_sql_gen_chain()

# Initialize OpenAI LLM for translation and safety checks
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")

# Define the individual nodes of the workflow

此函数充当使用 StateGraph 创建完整工作流程的入口点。工作流程中的各个节点将在后续部分中定义和连接。

节点说明

1. 翻译输入

translate_input 节点将用户查询翻译成英语,以标准化处理并确保与下游节点的兼容性。在 AI 工作流程中,将用户输入翻译为第一步可以确保任务分离并提高可观测性。任务分离通过将翻译与其他下游任务(如用户输入安全验证和 SQL 生成)隔离来简化工作流程。提高的可观测性在 MLflow 中提供了清晰的追踪,从而更容易调试和监控该过程。

  • 示例
    • 输入:“Quantos pedidos foram realizados em Novembro?”
    • 已翻译:“11 月份下了多少订单?”
    • 输入:“Combien de ventes avons-nous enregistrées en France?”
    • 已翻译:“我们在法国记录了多少销量?”
  • 代码
def translate_input(state: GraphState) -> GraphState:
"""
Translates user input to English using an LLM. If the input is already in English,
it is returned as is. This ensures consistent input for downstream processing.

Args:
state (GraphState): The current graph state containing user messages.

Returns:
GraphState: The updated state with the translated input.
"""
_logger.info("Starting translation of user input to English.")
messages = state["messages"]
user_input = messages[-1][1] # Get the latest user input

# Translation prompt for the model
translation_prompt = f"""
Translate the following text to English. If the text is already in English, repeat it exactly without any additional explanation.

Text:
{user_input}
"""

# Call the OpenAI LLM to translate the text
translated_response = llm.invoke(translation_prompt)
translated_text = translated_response.content.strip() # Access the 'content' attribute and strip any extra spaces

# Update state with the translated input
state["translated_input"] = translated_text
_logger.info("Translation completed successfully. Translated input: %s", translated_text)

return state
2. 预安全检查

pre_safety_check 节点确保在用户输入中尽早检测到不允许的 SQL 操作和不适当的内容。虽然稍后会在工作流程中再次检查有害 SQL 命令(例如,CREATEDELETEDROPINSERTUPDATE)(特别是在生成 SQL 查询之后),但此预安全检查对于在输入阶段识别潜在问题至关重要。通过这样做,它可以防止不必要的计算,并向用户提供即时反馈。

虽然使用不允许列表来防范有害 SQL 操作提供了一种快速保护免受破坏性查询的方法,但当处理像 T-SQL 这样的复杂 SQL 后端时,维护全面的不允许列表可能很难管理。另一种方法是采用允许列表,将查询限制为仅安全操作(例如,SELECTJOIN)。这种方法通过缩小允许的操作范围而不是尝试阻止每个有风险的命令来确保更强大的解决方案。

为了实现企业级解决方案,该项目可以利用像 Unity Catalog 这样的框架,该框架为管理安全相关功能(例如 AI 工作流程的 pre_safety_check)提供了一种集中且强大的方法。通过在此类框架中注册和管理可重用功能,您可以跨所有 AI 工作流程强制执行一致且可靠的行为,从而增强安全性和可扩展性。

此外,该节点还利用 LLM 来分析输入中是否存在冒犯性或不适当的内容。如果检测到不安全查询或不适当的内容,则状态会更新为错误标志,并提供透明的反馈,从而尽早保护工作流程免受恶意或破坏性元素的影响。

  • 示例
  1. 不允许的操作

    • 输入: “DROP TABLE customers;”

    • 响应: “您的查询包含不允许的 SQL 操作,无法处理。”

    • 输入: _“SELECT _ FROM orders;”*_

    • 响应: “查询已允许。”

  2. 不适当的内容

    • 输入: “显示我客户姓名像‘白痴约翰’的订单;”
    • 响应: “您的查询包含不适当的内容,无法处理。”
    • 输入: “按地区查找总销售额。”
    • 响应: “输入可以安全处理。”
  • 代码
def pre_safety_check(state: GraphState) -> GraphState:
"""
Perform safety checks on the user input to ensure that no dangerous SQL operations
or inappropriate content is present. The function checks for SQL operations like
DELETE, DROP, and others, and also evaluates the input for toxic or unsafe content.

Args:
state (GraphState): The current graph state containing the translated user input.

Returns:
GraphState: The updated state with error status and messages if any issues are found.
"""
_logger.info("Performing safety check.")
translated_input = state["translated_input"]
messages = state["messages"]
error = "no"

# List of disallowed SQL operations (e.g., DELETE, DROP)
disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)

# Check if the input contains disallowed SQL operations
if pattern.search(translated_input):
_logger.warning("Input contains disallowed SQL operations. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your query contains disallowed SQL operations and cannot be processed.")]
else:
# Check if the input contains inappropriate content
safety_prompt = f"""
Analyze the following input for any toxic or inappropriate content.

Respond with only "safe" or "unsafe", and nothing else.

Input:
{translated_input}
"""
safety_invoke = llm.invoke(safety_prompt)
safety_response = safety_invoke.content.strip().lower() # Access the 'content' attribute and strip any extra spaces

if safety_response == "safe":
_logger.info("Input is safe to process.")
else:
_logger.warning("Input contains inappropriate content. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your query contains inappropriate content and cannot be processed.")]

# Update state with error status and messages
state["error"] = error
state["messages"] = messages

return state
3. 模式提取

schema_extract 节点通过查询元数据来动态检索数据库模式,包括表名和列详细信息。格式化的模式存储在状态中,从而可以在验证用户查询的同时适应当前数据库结构。

  • 示例
    • 输入:模式提取请求。
      模式输出
      • Customers(CustomerID (INTEGER), CustomerName (TEXT), ContactName (TEXT), Address (TEXT), City (TEXT), PostalCode (TEXT), Country (TEXT))
      • Orders(OrderID (INTEGER), CustomerID (INTEGER), OrderDate (TEXT))
      • OrderDetails(OrderDetailID (INTEGER), OrderID (INTEGER), ProductID (INTEGER), Quantity (INTEGER))
      • Products(ProductID (INTEGER), ProductName (TEXT), Price (REAL))
  • 代码
def schema_extract(state: GraphState) -> GraphState:
"""
Extracts the database schema, including all tables and their respective columns,
from the connected SQLite database. This function retrieves the list of tables and
iterates through each table to gather column definitions (name and data type).

Args:
state (GraphState): The current graph state, which will be updated with the database schema.

Returns:
GraphState: The updated state with the extracted database schema.
"""
_logger.info("Extracting database schema.")

# Extract the schema from the database
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
schema_details = []

# Loop through each table and retrieve column information
for table_name_tuple in tables:
table_name = table_name_tuple[0]
cursor.execute(f"PRAGMA table_info({table_name});")
columns = cursor.fetchall()

# Format column definitions
column_defs = ', '.join([f"{col[1]} ({col[2]})" for col in columns])
schema_details.append(f"- {table_name}({column_defs})")

# Save the schema in the state
database_schema = '\n'.join(schema_details)
state["database_schema"] = database_schema
_logger.info(f"Database schema extracted:\n{database_schema}")

return state
4. 上下文检查

context_check 节点通过将用户查询与提取的数据库模式进行比较来验证用户查询,以确保对齐和相关性。与模式不对应的查询被标记为不相关,从而防止资源浪费并使用户能够提供查询重构的反馈。

  • 示例
    • 输入:“平均订单价值是多少?”
      模式匹配:输入与数据库模式相关。
    • 输入:“显示我库存表中的数据。”
      响应:“您的提问与数据库无关,无法处理。”
  • 代码
def context_check(state: GraphState) -> GraphState:
"""
Checks whether the user's input is relevant to the database schema by comparing
the user's question with the database schema. Uses a language model to determine if
the question can be answered using the provided schema.

Args:
state (GraphState): The current graph state, which contains the translated input
and the database schema.

Returns:
GraphState: The updated state with error status and messages if the input is irrelevant.
"""
_logger.info("Performing context check.")

# Extract relevant data from the state
translated_input = state["translated_input"]
messages = state["messages"]
error = "no"
database_schema = state["database_schema"] # Get the schema from the state

# Use the LLM to determine if the input is relevant to the database schema
context_prompt = f"""
Determine whether the following user input is a question that can be answered using the database schema provided below.

Respond with only "relevant" if the input is relevant to the database schema, or "irrelevant" if it is not.

User Input:
{translated_input}

Database Schema:
{database_schema}
"""

# Call the LLM for context check
llm_invoke = llm.invoke(context_prompt)
llm_response = llm_invoke.content.strip().lower() # Access the 'content' attribute and strip any extra spaces and lower case

# Process the response from the LLM
if llm_response == "relevant":
_logger.info("Input is relevant to the database schema.")
else:
_logger.info("Input is not relevant. Halting the workflow.")
error = "yes"
messages += [("assistant", "Your question is not related to the database and cannot be processed.")]

# Update the state with error and messages
state["error"] = error
state["messages"] = messages

return state
5. 生成

generate 节点通过从向量存储中检索相关文档并利用预定义的 SQL 生成链,从自然语言输入构造 SQL 查询。它使查询与用户的意图和模式上下文对齐,并使用生成的 SQL 及其描述更新状态。

  • 示例
    • 输入:“查找总销售额。”
      生成的 SQL:“SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;”
    • 输入:“列出纽约的所有客户。”
      生成的 SQL:“SELECT name FROM customers WHERE location = 'New York';”
  • 代码
def generate(state: GraphState) -> GraphState:
"""
Generates an SQL query based on the user's input. The node retrieves relevant documents from
the vector store and uses a generation chain to produce an SQL query.

Args:
state (GraphState): The current graph state, which contains the translated input and
other relevant data such as messages and iteration count.

Returns:
GraphState: The updated state with the generated SQL query and related messages.
"""
_logger.info("Generating SQL query.")

# Extract relevant data from the state
messages = state["messages"]
iterations = state["iterations"]
translated_input = state["translated_input"]
database_schema = state["database_schema"]

# Retrieve relevant documents from the vector store based on the translated user input
docs = vector_store.similarity_search(translated_input, k=4)
retrieved_docs = "\n\n".join([doc.page_content for doc in docs])

# Generate the SQL query using the SQL generation chain
sql_solution = sql_gen_chain.invoke(
{
"retrieved_docs": retrieved_docs,
"database_schema": database_schema,
"messages": [("user", translated_input)],
}
)

# Save the generated SQL query in the state
messages += [
(
"assistant",
f"{sql_solution.description}\nSQL Query:\n{sql_solution.sql_code}",
)
]
iterations += 1

# Log the generated SQL query
_logger.info("Generated SQL query:\n%s", sql_solution.sql_code)

# Update the state with the generated SQL query and updated message list
state["generation"] = sql_solution
state["messages"] = messages
state["iterations"] = iterations

return state
6. 后安全检查

post_safety_check 节点通过对有害 SQL 命令执行最终验证来确保生成的 SQL 查询是安全的。虽然早期的预安全检查识别用户输入中不允许的操作,但此后安全检查验证生成后生成的 SQL 查询是否符合安全指南。这种两步法确保即使在查询生成过程中意外引入了不允许的操作,也可以捕获并标记它们。如果检测到不安全查询,则节点会停止工作流程,使用错误标志更新状态,并向用户提供反馈。

  • 示例
  1. 不允许的操作
    • 生成的查询: “DROP TABLE orders;”
    • 响应: “生成的 SQL 查询包含不允许的 SQL 操作:DROP,无法处理。”
    • 生成的查询: “SELECT name FROM customers;”
    • 响应: “查询有效。”
  • 代码
def post_safety_check(state: GraphState) -> GraphState:
"""
Perform safety checks on the generated SQL query to ensure that it doesn't contain disallowed operations
such as CREATE, DELETE, DROP, etc. This node checks the SQL query generated earlier in the workflow.

Args:
state (GraphState): The current graph state containing the generated SQL query.

Returns:
GraphState: The updated state with error status and messages if any issues are found.
"""
_logger.info("Performing post-safety check on the generated SQL query.")

# Retrieve the generated SQL query from the state
sql_solution = state.get("generation", {})
sql_query = sql_solution.get("sql_code", "").strip()
messages = state["messages"]
error = "no"

# List of disallowed SQL operations
disallowed_operations = ['CREATE', 'DELETE', 'DROP', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
pattern = re.compile(r'\b(' + '|'.join(disallowed_operations) + r')\b', re.IGNORECASE)

# Check if the generated SQL query contains disallowed SQL operations
found_operations = pattern.findall(sql_query)
if found_operations:
_logger.warning(
"Generated SQL query contains disallowed SQL operations: %s. Halting the workflow.",
", ".join(set(found_operations))
)
error = "yes"
messages += [("assistant", f"The generated SQL query contains disallowed SQL operations: {', '.join(set(found_operations))} and cannot be processed.")]
else:
_logger.info("Generated SQL query passed the safety check.")

# Update state with error status and messages
state["error"] = error
state["messages"] = messages

return state
7. SQL 检查

sql_check 节点通过在事务保存点中执行生成的 SQL 查询来确保其安全且语法有效。验证后会回滚任何更改,并标记错误并提供详细的反馈以维护查询完整性。

  • 示例
    • 输入 SQL:“SELECT name FROM customers WHERE city = 'New York';”
      验证:查询有效。
    • 输入 SQL:“SELECT MONTH(date) AS month, SUM(total) AS total_sales FROM orders GROUP BY MONTH(date);”
      响应:“您的 SQL 查询执行失败:没有这样的函数:MONTH。”
  • 代码
def sql_check(state: GraphState) -> GraphState:
"""
Validates the generated SQL query by attempting to execute it on the database.
If the query is valid, the changes are rolled back to ensure no data is modified.
If there is an error during execution, the error is logged and the state is updated accordingly.

Args:
state (GraphState): The current graph state, which contains the generated SQL query
and the messages to communicate with the user.

Returns:
GraphState: The updated state with error status and messages if the query is invalid.
"""
_logger.info("Validating SQL query.")

# Extract relevant data from the state
messages = state["messages"]
sql_solution = state["generation"]
error = "no"

sql_code = sql_solution.sql_code.strip()

try:
# Start a savepoint for the transaction to allow rollback
conn.execute('SAVEPOINT sql_check;')
# Attempt to execute the SQL query
cursor.execute(sql_code)
# Roll back to the savepoint to undo any changes
conn.execute('ROLLBACK TO sql_check;')
_logger.info("SQL query validation: success.")
except Exception as e:
# Roll back in case of error
conn.execute('ROLLBACK TO sql_check;')
_logger.error("SQL query validation failed. Error: %s", e)
messages += [("user", f"Your SQL query failed to execute: {e}")]
error = "yes"

# Update the state with the error status
state["error"] = error
state["messages"] = messages

return state
8. 运行查询

run_query 节点执行验证的 SQL 查询,连接到数据库以检索结果。它使用查询输出来更新状态,确保数据的格式适合进一步分析或报告,同时实现强大的错误处理。

  • 示例
    • 输入 SQL:“SELECT COUNT(*) FROM Customers WHERE City = 'New York';”
      查询结果:“(0,)”
    • 输入 SQL:_“SELECT SUM(Products.Price * OrderDetails.Quantity) AS TotalSales FROM OrderDetails LEFT JOIN Products ON OrderDetails.ProductID = Products.ProductID;”*_
      查询结果:_“(6925.0,)”_
  • 代码
def run_query(state: GraphState) -> GraphState:
"""
Executes the generated SQL query on the database and retrieves the results if it is a SELECT query.
For non-SELECT queries, commits the changes to the database. If no records are found for a SELECT query,
the `no_records_found` flag is set to True.

Args:
state (GraphState): The current graph state, which contains the generated SQL query and other relevant data.

Returns:
GraphState: The updated state with the query results, or a flag indicating if no records were found.
"""
_logger.info("Running SQL query.")

# Extract the SQL query from the state
sql_solution = state["generation"]
sql_code = sql_solution.sql_code.strip()
results = None
no_records_found = False # Flag to indicate no records found

try:
# Execute the SQL query
cursor.execute(sql_code)

# For SELECT queries, fetch and store results
if sql_code.upper().startswith("SELECT"):
results = cursor.fetchall()
if not results:
no_records_found = True
_logger.info("SQL query execution: success. No records found.")
else:
_logger.info("SQL query execution: success.")
else:
# For non-SELECT queries, commit the changes
conn.commit()
_logger.info("SQL query execution: success. Changes committed.")
except Exception as e:
_logger.error("SQL query execution failed. Error: %s", e)

# Update the state with results and flag for no records found
state["results"] = results
state["no_records_found"] = no_records_found

return state
决策步骤:确定下一步操作

decide_next_step 函数充当工作流程中的控制点,根据当前状态决定接下来应采取的操作。它评估 error 状态和目前执行的迭代次数,以确定是否应运行查询、是否应完成工作流程,或者系统是否应重试生成 SQL 查询。

  • 过程

    • 如果没有错误 (error == "no"),系统将继续运行 SQL 查询。
    • 如果已达到最大迭代次数 (max_iterations),则工作流程结束。
    • 如果发生错误且尚未达到最大迭代次数,系统将重试查询生成。
  • 工作流程决策示例

    • 没有错误,继续查询:如果在之前的步骤中没有发现错误,则工作流程会前进到运行查询。
    • 已达到最大迭代次数,结束工作流程:如果工作流程已经尝试了设置的次数 (max_iterations),则会终止。
    • 检测到错误,重试 SQL 生成:如果发生错误且系统尚未达到重试限制,则会尝试重新生成 SQL 查询。
  • 代码

def decide_next_step(state: GraphState) -> str:
"""
Determines the next step in the workflow based on the current state, including whether the query
should be run, the workflow should be finished, or if the query generation needs to be retried.

Args:
state (GraphState): The current graph state, which contains error status and iteration count.

Returns:
str: The next step in the workflow, which can be "run_query", "generate", or END.
"""
_logger.info("Deciding next step based on current state.")

error = state["error"]
iterations = state["iterations"]

if error == "no":
_logger.info("Error status: no. Proceeding with running the query.")
return "run_query"
elif iterations >= max_iterations:
_logger.info("Maximum iterations reached. Ending the workflow.")
return END
else:
_logger.info("Error detected. Retrying SQL query generation.")
return "generate"

工作流程编排和条件逻辑

为了定义系统中任务的编排,我们使用 StateGraph 类构建工作流程图。每个任务都表示为一个节点,任务之间的转换定义为边。条件边用于根据工作流程的状态控制流程。

def get_workflow(conn, cursor, vector_store):
"""Define and compile the LangGraph workflow."""

# Max iterations: defines how many times the workflow should retry in case of errors
max_iterations = 3

# SQL generation chain: this is a chain that will generate SQL based on retrieved docs
sql_gen_chain = get_sql_gen_chain()

# Initialize OpenAI LLM for translation and safety checks
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")

# Define the individual nodes of the workflow
... # Insert nodes code defined above here

# Build the workflow graph
workflow = StateGraph(GraphState)

# Define workflow nodes
workflow.add_node("translate_input", translate_input) # Translate user input to structured format
workflow.add_node("pre_safety_check", pre_safety_check) # Perform a pre-safety check on input
workflow.add_node("schema_extract", schema_extract) # Extract the database schema
workflow.add_node("context_check", context_check) # Validate input relevance to context
workflow.add_node("generate", generate) # Generate SQL query
workflow.add_node("post_safety_check", post_safety_check) # Perform a post-safety check on generated SQL query
workflow.add_node("sql_check", sql_check) # Validate the generated SQL query
workflow.add_node("run_query", run_query) # Execute the SQL query

# Define workflow edges
workflow.add_edge(START, "translate_input") # Start at the translation step
workflow.add_edge("translate_input", "pre_safety_check") # Move to safety checks

# Conditional edge after safety check
workflow.add_conditional_edges(
"pre_safety_check", # Start at the pre_safety_check node
lambda state: "schema_extract" if state["error"] == "no" else END, # Decide next step
{"schema_extract": "schema_extract", END: END} # Map states to nodes
)

workflow.add_edge("schema_extract", "context_check") # Proceed to context validation

# Conditional edge after context check
workflow.add_conditional_edges(
"context_check", # Start at the context_check node
lambda state: "generate" if state["error"] == "no" else END, # Decide next step
{"generate": "generate", END: END}
)

workflow.add_edge("generate", "post_safety_check") # Proceed to post-safety check

# Conditional edge after post-safety check
workflow.add_conditional_edges(
"post_safety_check", # Start at the post_safety_check node
lambda state: "sql_check" if state["error"] == "no" else END, # If no error, proceed to sql_check, else END
{"sql_check": "sql_check", END: END},
)

# Conditional edge after SQL validation
workflow.add_conditional_edges(
"sql_check", # Start at the sql_check node
decide_next_step, # Function to determine the next step
{
"run_query": "run_query", # If SQL is valid, execute the query
"generate": "generate", # If retry is needed, go back to generation
END: END # Otherwise, terminate the workflow
}
)

workflow.add_edge("run_query", END) # Final step is to end the workflow

# Compile and return the workflow application
app = workflow.compile()

return app
  1. 从“开始”到 translate_input:
    工作流程首先将用户输入翻译成结构化格式。
  2. translate_inputpre_safety_check:
    翻译完成后,工作流程将继续检查输入的安全性。
  3. pre_safety_check 条件规则:
    • 如果输入通过预安全检查 (state["error"] == "no"),则工作流程将移动到 schema_extract
    • 如果输入未能通过预安全检查,则工作流程终止 (END)。
  4. schema_extractcontext_check:
    提取模式,然后工作流程验证输入与数据库上下文的相关性。
  5. context_check 条件规则:
    • 如果输入是相关的 (state["error"] == "no"),则工作流程将移动到 generate
    • 如果不是,则工作流程终止 (END)。
  6. generatepost_safety_check:
    工作流程生成 SQL 查询并将其发送以进行验证。\
  7. post_safety_check 条件规则:- 如果输入通过后安全检查 (state["error"] == "no"),则工作流程将移动到 sql_check。- 如果输入未能通过后安全检查,则工作流程终止 (END)。
  8. sql_check 条件规则:
    • 如果查询有效,则工作流程将继续执行 run_query
    • 如果查询需要调整且迭代次数小于 3,则工作流程会循环回到 generate
    • 如果验证失败且迭代次数大于 3,则工作流程终止 (END)。
  9. run_queryEND:
    执行查询后,工作流程结束。

上面的模式提供了 LangGraph 节点和边的表示

Graph Representation of our Agent

在 MLflow 中记录模型

现在我们已经使用 LangGraph 构建了多语言查询引擎,我们准备使用 MLflow 的 代码中的模型记录该模型。将模型记录到 MLflow 中允许我们将多语言查询引擎视为传统的 ML 模型,从而可以跨各种服务基础设施进行部署的无缝追踪、版本控制和打包。MLflow 的代码中的模型策略(我们在其中记录表示模型的代码)与基于 MLflow 对象的日志记录(在其中创建模型对象,序列化并记录为 pickle 或 JSON 对象)形成对比。

步骤 1:从代码文件创建我们的模型

到目前为止,我们已经在名为 workflow.py 的文件中定义了 get_workflow 函数。在此步骤中,我们将创建一个新文件 sql_model.py,该文件引入了 SQLGenerator 类。此脚本将

  1. workflow.py 导入 get_workflow 函数。
  2. SQLGenerator 类定义为 PythonModel,包括使用 get_workflow 函数初始化 LangGraph 工作流程的 predict 方法。
  3. 使用 mlflow.models.set_modelSQLGenerator PythonModel 类指定为 MLflow 感兴趣的模型。
import mlflow
from definitions import REMOTE_SERVER_URI
from workflow import get_workflow

mlflow.set_tracking_uri(REMOTE_SERVER_URI)


class SQLGenerator(mlflow.pyfunc.PythonModel):
def predict(self, context, model_input):
return get_workflow(
model_input["conn"], model_input["cursor"], model_input["vector_store"]
)


mlflow.models.set_model(SQLGenerator())

步骤 2:使用 MLflow 代码中的模型功能进行日志记录

定义了 SQLGenerator 自定义 Python 模型后,下一步是使用 MLflow 的代码中的模型功能对其进行日志记录。这涉及使用日志模型标准 API 指定到 sql_model.py 脚本的路径,并使用 code_paths 参数将 workflow.py 作为依赖项包含在内。此方法确保在不同环境或另一台计算机上加载模型时,所有必要的代码文件都可用。

import mlflow
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from mlflow import MlflowClient

client = MlflowClient(tracking_uri=REMOTE_SERVER_URI)

mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

with mlflow.start_run():
logged_model_info = mlflow.pyfunc.log_model(
python_model="sql_model.py",
artifact_path="sql_generator",
registered_model_name=REGISTERED_MODEL_NAME,
code_paths=["workflow.py"],
)

client.set_registered_model_alias(
REGISTERED_MODEL_NAME, MODEL_ALIAS, logged_model_info.registered_model_version
)

在 MLflow UI 中,存储的模型包括 sql_model.pyworkflow.py 脚本,它们作为运行中的工件存在。这种代码日志记录功能不仅记录模型的参数和指标,还捕获定义其功能的代码。这确保了可观察性、无缝跟踪以及直接通过 UI 进行的简单调试。但是,至关重要的是要确保永远不要将 API 密钥或凭据等敏感元素硬编码到这些脚本中。由于代码按原样存储,因此包含的任何敏感信息都可能导致令牌泄漏并构成安全风险。相反,应使用环境变量、秘密管理系统或其他安全方法安全地管理敏感数据。

model_as_code_artifact

main.py 中使用已记录的多语言查询引擎

记录模型后,可以使用模型 URI 和标准的 mlflow.pyfunc.load_model API 从 MLflow Tracking Server 加载它。加载模型时,将执行 workflow.py 脚本以及 sql_model.py 脚本,以确保在调用 predict 方法时 get_workflow 函数可用。

执行以下代码,我们演示了我们的多语言查询引擎能够执行自然语言到 SQL 的生成和查询执行。

import os
import logging

import mlflow
from database import setup_database
from definitions import (
EXPERIMENT_NAME,
MODEL_ALIAS,
REGISTERED_MODEL_NAME,
REMOTE_SERVER_URI,
)
from dotenv import load_dotenv
from vector_store import setup_vector_store

mlflow.set_tracking_uri(REMOTE_SERVER_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.langchain.autolog()

# Initialize the logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
_logger.addHandler(handler)


def main():
# Load environment variables from .env file
load_dotenv()

# Access secrets using os.getenv
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

# Setup database and vector store
conn = setup_database()
cursor = conn.cursor()
vector_store = setup_vector_store()

# Load the model
model_uri = f"models:/{REGISTERED_MODEL_NAME}@{MODEL_ALIAS}"
model = mlflow.pyfunc.load_model(model_uri)
model_input = {"conn": conn, "cursor": cursor, "vector_store": vector_store}
app = model.predict(model_input)

# save image
app.get_graph().draw_mermaid_png(
output_file_path="sql_agent_with_safety_checks.png"
)

# Example user interaction
_logger.info("Welcome to the SQL Assistant!")
while True:
question = input("\nEnter your SQL question (or type 'exit' to quit): ")
if question.lower() == "exit":
break

# Initialize the state with all required keys
initial_state = {
"messages": [("user", question)],
"iterations": 0,
"error": "",
"results": None,
"generation": None,
"no_records_found": False,
"translated_input": "", # Initialize translated_input
}

solution = app.invoke(initial_state)

# Check if an error was set during the safety check
if solution["error"] == "yes":
_logger.info("\nAssistant Message:\n")
_logger.info(solution["messages"][-1][1]) # Display the assistant's message
continue # Skip to the next iteration

# Extract the generated SQL query from solution["generation"]
sql_query = solution["generation"].sql_code
_logger.info("\nGenerated SQL Query:\n")
_logger.info(sql_query)

# Extract and display the query results
if solution.get("no_records_found"):
_logger.info("\nNo records found matching your query.")
elif "results" in solution and solution["results"] is not None:
_logger.info("\nQuery Results:\n")
for row in solution["results"]:
_logger.info(row)
else:
_logger.info("\nNo results returned or query did not execute successfully.")

_logger.info("Goodbye!")


if __name__ == "__main__":
main()

项目文件结构

该项目遵循简单的文件结构

file structure

MLflow 跟踪

MLflow 自动跟踪提供与各种 GenAI 库(如 LangChain、OpenAI、LlamaIndex、DSPy 和 AutoGen)的完全自动化集成。由于我们的 AI 工作流程是使用 LangGraph 构建的,因此我们可以通过启用 mlflow.langchain.autolog() 来激活自动 LangChain 跟踪。

借助 LangChain 自动日志记录,只要在链上调用调用 API,跟踪就会自动记录到活动的 MLflow 实验中。这种无缝集成确保捕获每次交互以进行监控和分析。

在 MLflow 中查看跟踪

通过导航到感兴趣的 MLflow 实验并单击“跟踪”选项卡,可以轻松访问跟踪。进入后,选择特定的跟踪可以提供详细的执行信息。

每个跟踪包括

  1. 执行图:工作流程步骤的可视化。
  2. 输入和输出:每个步骤处理的数据的详细日志。

通过利用 MLflow 跟踪,我们可以全面了解整个图形执行情况。AI 工作流程图通常感觉像一个黑匣子,使得调试和理解每个步骤中发生的事情具有挑战性。但是,只需一行代码即可启用跟踪,MLflow 就可以清晰、详细地了解工作流程,从而使开发人员可以有效地调试、监控和优化图的每个节点,从而确保我们的多语言查询引擎保持透明、可审计和可扩展。

mlflow_tracing_gif

结论

在整篇文章中,我们探讨了使用 LangGraph 和 MLflow 构建和管理多语言查询引擎的过程。通过将 LangGraph 的动态 AI 工作流程与 MLflow 的强大生命周期管理和跟踪功能集成,我们创建了一个系统,该系统不仅可以提供准确、高效的自然语言到 SQL 的生成和执行,而且还具有透明性和可扩展性。