AI Gateway Integration

Learn how to integrate the MLflow AI Gateway with applications, frameworks, and production systems.

Application Integration

FastAPI Integration

Build REST APIs that proxy requests to the AI Gateway while adding your own business logic, authentication, and data processing:

from fastapi import FastAPI, HTTPException
from mlflow.deployments import get_deploy_client

app = FastAPI()
client = get_deploy_client("http://localhost:5000")


@app.post("/chat")
async def chat_endpoint(message: str):
    try:
        response = client.predict(
            endpoint="chat", inputs={"messages": [{"role": "user", "content": message}]}
        )
        return {"response": response["choices"][0]["message"]["content"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/embed")
async def embed_endpoint(text: str):
    try:
        response = client.predict(endpoint="embeddings", inputs={"input": text})
        return {"embedding": response["data"][0]["embedding"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
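
To try the proxy locally, you can serve it with, for example, `uvicorn main:app --port 8000` (the module name and port are illustrative) and call it with any HTTP client:

import requests

# The /chat handler above declares `message: str` with no request body model,
# so FastAPI treats it as a query parameter even on POST.
resp = requests.post("http://localhost:8000/chat", params={"message": "Hello!"})
print(resp.json()["response"])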

Flask Integration

Create Flask applications that integrate AI capabilities using familiar request/response patterns:

from flask import Flask, request, jsonify
from mlflow.deployments import get_deploy_client

app = Flask(__name__)
client = get_deploy_client("http://localhost:5000")


@app.route("/chat", methods=["POST"])
def chat():
    try:
        data = request.get_json()
        response = client.predict(
            endpoint="chat",
            inputs={"messages": [{"role": "user", "content": data["message"]}]},
        )
        return jsonify({"response": response["choices"][0]["message"]["content"]})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    app.run(debug=True)

Async/Await Support

Handle multiple concurrent requests efficiently with asyncio for high-throughput applications:

import asyncio
import aiohttp
import json


async def async_query_gateway(endpoint, data):
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"http://localhost:5000/gateway/{endpoint}/invocations",
            headers={"Content-Type": "application/json"},
            data=json.dumps(data),
        ) as response:
            return await response.json()


async def main():
    # Concurrent requests
    tasks = [
        async_query_gateway(
            "chat", {"messages": [{"role": "user", "content": f"Question {i}"}]}
        )
        for i in range(5)
    ]

    responses = await asyncio.gather(*tasks)
    for i, response in enumerate(responses):
        print(f"Response {i}: {response['choices'][0]['message']['content']}")


# Run async example
asyncio.run(main())
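
Each call above opens its own `aiohttp.ClientSession`, so every request negotiates a fresh connection. For sustained load it is usually better to share one session (and its connection pool) across tasks; a minimal variation under the same localhost assumption (`async_query_shared` and `main_pooled` are illustrative names):

async def async_query_shared(session, endpoint, data):
    # Reuse the caller's session so connections are pooled across requests
    async with session.post(
        f"http://localhost:5000/gateway/{endpoint}/invocations", json=data
    ) as response:
        return await response.json()


async def main_pooled():
    async with aiohttp.ClientSession() as session:
        tasks = [
            async_query_shared(
                session,
                "chat",
                {"messages": [{"role": "user", "content": f"Question {i}"}]},
            )
            for i in range(5)
        ]
        return await asyncio.gather(*tasks)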

LangChain Integration

Setup

LangChain provides pre-built components that work directly with the AI Gateway, making it easy to integrate with LangChain's ecosystem of tools and chains:

from langchain_community.llms import MLflowAIGateway
from langchain_community.embeddings import MlflowAIGatewayEmbeddings
from langchain_community.chat_models import ChatMLflowAIGateway

# Configure LangChain to use your gateway
gateway_uri = "http://localhost:5000"

Chat Models

Create a LangChain chat model that routes through your gateway, letting you switch providers without changing application code:

# Chat model
chat = ChatMLflowAIGateway(
    gateway_uri=gateway_uri,
    route="chat",
    params={
        "temperature": 0.7,
        "top_p": 0.95,
    },
)

# Generate response
from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is LangChain?"),
]

response = chat(messages)
print(response.content)

Embeddings

Use gateway-powered embeddings for vector search, semantic similarity, and RAG applications:

# Embeddings
embeddings = MlflowAIGatewayEmbeddings(gateway_uri=gateway_uri, route="embeddings")

# Generate embeddings
text_embeddings = embeddings.embed_documents(
    ["This is a document", "This is another document"]
)

query_embedding = embeddings.embed_query("This is a query")

Complete RAG Example

Build a complete Retrieval-Augmented Generation (RAG) system that uses the gateway for both embeddings and chat completions:

from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA

# Load documents
loader = TextLoader("path/to/document.txt")
documents = loader.load()

# Split documents
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Create vector store
vectorstore = FAISS.from_documents(docs, embeddings)

# Create QA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=chat, chain_type="stuff", retriever=vectorstore.as_retriever()
)

# Query the system
question = "What is the main topic of the document?"
result = qa_chain.run(question)
print(result)

OpenAI Compatibility

The AI Gateway provides OpenAI-compatible endpoints, so you can migrate existing OpenAI applications with minimal code changes:

import openai

# Configure OpenAI client to use the gateway
openai.api_base = "http://localhost:5000/gateway/chat"
openai.api_key = "not-needed"  # Gateway handles authentication

# Use standard OpenAI client
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # Endpoint name in your gateway config
    messages=[{"role": "user", "content": "Hello, AI Gateway!"}],
)

print(response.choices[0].message.content)
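
The snippet above uses the pre-1.0 `openai` module-level configuration. On `openai>=1.0` the equivalent pattern is to point a client instance at the gateway via `base_url`; a sketch under the same gateway URL assumption:

from openai import OpenAI

# openai>=1.0 style: configure the base URL on a client instance
client = OpenAI(base_url="http://localhost:5000/gateway/chat", api_key="not-needed")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # endpoint name in your gateway config
    messages=[{"role": "user", "content": "Hello, AI Gateway!"}],
)
print(response.choices[0].message.content)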

MLflow Model Integration

Deploy your own custom models alongside external providers for a unified interface across proprietary and third-party models.

Registering Models

Train and register your models using MLflow's standard workflow, then expose them through the gateway:

import mlflow
import mlflow.pyfunc

# Log and register a model
with mlflow.start_run():
    # Your model training code here
    mlflow.pyfunc.log_model(
        name="my_model",
        python_model=MyCustomModel(),
        registered_model_name="custom-chat-model",
    )

# Deploy the model
# Then configure it in your gateway config.yaml:
endpoints:
  - name: custom-model
    endpoint_type: llm/v1/chat
    model:
      provider: mlflow-model-serving
      name: custom-chat-model
      config:
        model_server_url: http://localhost:5001
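
Once the registered model is actually being served at the `model_server_url` above and the gateway has been restarted with this configuration, the custom endpoint is queried exactly like any hosted provider; a minimal sketch, assuming the gateway still runs on localhost:5000:

from mlflow.deployments import get_deploy_client

# The custom endpoint behaves like any other llm/v1/chat endpoint
client = get_deploy_client("http://localhost:5000")
response = client.predict(
    endpoint="custom-model",
    inputs={"messages": [{"role": "user", "content": "Hello, custom model!"}]},
)
print(response["choices"][0]["message"]["content"])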

Production Best Practices

Performance Optimization

  1. Connection pooling: use persistent HTTP connections for high-throughput applications
  2. Batch requests: group multiple requests together where possible
  3. Async operations: use async/await for concurrent requests
  4. Caching: implement response caching for repeated queries (see the sketch after this list)
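
As an illustration of point 4, a minimal in-process cache that reuses responses for identical chat inputs (the helper names and cache size are assumptions, and exact-match caching only pays off for repeated queries):

import json
from functools import lru_cache
from mlflow.deployments import get_deploy_client

client = get_deploy_client("http://localhost:5000")


@lru_cache(maxsize=256)
def _cached_predict(endpoint, inputs_json):
    # lru_cache requires hashable arguments, so inputs are passed as a JSON string
    return client.predict(endpoint=endpoint, inputs=json.loads(inputs_json))


def cached_chat(message):
    inputs = {"messages": [{"role": "user", "content": message}]}
    return _cached_predict("chat", json.dumps(inputs, sort_keys=True))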

Error Handling

import time
from mlflow.deployments import get_deploy_client
from mlflow.exceptions import MlflowException


def robust_query(client, endpoint, inputs, max_retries=3):
    for attempt in range(max_retries):
        try:
            return client.predict(endpoint=endpoint, inputs=inputs)
        except MlflowException as e:
            if attempt < max_retries - 1:
                time.sleep(2**attempt)  # Exponential backoff
                continue
            raise e


# Usage
client = get_deploy_client("http://localhost:5000")
response = robust_query(
    client, "chat", {"messages": [{"role": "user", "content": "Hello"}]}
)

Security

  1. Use HTTPS in production
  2. Implement authentication at the application level
  3. Validate inputs before sending them to the gateway (see the sketch after this list)
  4. Monitor usage and implement rate limiting
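
As an illustration of point 3, a small helper that rejects empty or oversized messages before they ever reach the gateway (the limit and function name are assumptions, not gateway requirements):

MAX_MESSAGE_CHARS = 4000  # illustrative limit; tune for your endpoint and provider


def validate_chat_message(message):
    # Fail fast on bad input instead of spending a gateway/provider call on it
    if not isinstance(message, str) or not message.strip():
        raise ValueError("message must be a non-empty string")
    if len(message) > MAX_MESSAGE_CHARS:
        raise ValueError(f"message exceeds {MAX_MESSAGE_CHARS} characters")
    return message.strip()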

Monitoring and Logging

import time
import logging
from mlflow.deployments import get_deploy_client

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def monitored_query(client, endpoint, inputs):
    start_time = time.time()
    try:
        logger.info(f"Querying endpoint: {endpoint}")
        response = client.predict(endpoint=endpoint, inputs=inputs)
        duration = time.time() - start_time
        logger.info(f"Query completed in {duration:.2f}s")
        return response
    except Exception as e:
        duration = time.time() - start_time
        logger.error(f"Query failed after {duration:.2f}s: {e}")
        raise
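
Usage mirrors a plain client call, so the wrapper can be dropped into existing code:

client = get_deploy_client("http://localhost:5000")
response = monitored_query(
    client, "chat", {"messages": [{"role": "user", "content": "Hello"}]}
)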

Load Balancing

For high-availability setups, consider running multiple gateway instances:

import random
from mlflow.deployments import get_deploy_client

# Multiple gateway instances
gateway_urls = ["http://gateway1:5000", "http://gateway2:5000", "http://gateway3:5000"]


def get_client():
    url = random.choice(gateway_urls)
    return get_deploy_client(url)


# Use with automatic failover
def resilient_query(endpoint, inputs, max_retries=3):
    for attempt in range(max_retries):
        try:
            client = get_client()
            return client.predict(endpoint=endpoint, inputs=inputs)
        except Exception as e:
            if attempt < max_retries - 1:
                continue
            raise e

Health and Monitoring

# Check gateway health via HTTP
import requests


def check_gateway_health(gateway_url):
    try:
        response = requests.get(f"{gateway_url}/health")
        return {
            "status": response.status_code,
            "healthy": response.status_code == 200,
            "response": response.json() if response.status_code == 200 else None,
        }
    except requests.RequestException as e:
        return {"status": "error", "healthy": False, "error": str(e)}


# Example usage
health = check_gateway_health("http://localhost:5000")
print(f"Gateway Health: {health}")

Next Steps