# AI Gateway Integration

Learn how to integrate the MLflow AI Gateway with applications, frameworks, and production systems.

## Application Integration

### FastAPI Integration

Build a REST API that proxies requests to the AI Gateway while adding your own business logic, authentication, and data processing:
```python
from fastapi import FastAPI, HTTPException
from mlflow.deployments import get_deploy_client

app = FastAPI()
client = get_deploy_client("http://localhost:5000")


@app.post("/chat")
async def chat_endpoint(message: str):
    try:
        response = client.predict(
            endpoint="chat", inputs={"messages": [{"role": "user", "content": message}]}
        )
        return {"response": response["choices"][0]["message"]["content"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/embed")
async def embed_endpoint(text: str):
    try:
        response = client.predict(endpoint="embeddings", inputs={"input": text})
        return {"embedding": response["data"][0]["embedding"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
```
### Flask Integration

Create Flask applications that add AI capabilities using familiar request/response patterns:
```python
from flask import Flask, request, jsonify
from mlflow.deployments import get_deploy_client

app = Flask(__name__)
client = get_deploy_client("http://localhost:5000")


@app.route("/chat", methods=["POST"])
def chat():
    try:
        data = request.get_json()
        response = client.predict(
            endpoint="chat",
            inputs={"messages": [{"role": "user", "content": data["message"]}]},
        )
        return jsonify({"response": response["choices"][0]["message"]["content"]})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    app.run(debug=True)
```
### Async/Await Support

Handle many concurrent requests efficiently with asyncio for high-throughput applications:
```python
import asyncio
import aiohttp
import json


async def async_query_gateway(endpoint, data):
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"http://localhost:5000/gateway/{endpoint}/invocations",
            headers={"Content-Type": "application/json"},
            data=json.dumps(data),
        ) as response:
            return await response.json()


async def main():
    # Concurrent requests
    tasks = [
        async_query_gateway(
            "chat", {"messages": [{"role": "user", "content": f"Question {i}"}]}
        )
        for i in range(5)
    ]
    responses = await asyncio.gather(*tasks)

    for i, response in enumerate(responses):
        print(f"Response {i}: {response['choices'][0]['message']['content']}")


# Run async example
asyncio.run(main())
```
## LangChain Integration

### Setup

LangChain ships prebuilt components that work directly with the AI Gateway, making it easy to plug into LangChain's ecosystem of tools and chains:
```python
from langchain_community.llms import MlflowAIGateway
from langchain_community.embeddings import MlflowAIGatewayEmbeddings
from langchain_community.chat_models import ChatMLflowAIGateway

# Configure LangChain to use your gateway
gateway_uri = "http://localhost:5000"
```
### Chat Models

Create LangChain chat models that route through your gateway, so you can switch providers without changing application code:
```python
# Chat model routed through the gateway
chat = ChatMLflowAIGateway(
    gateway_uri=gateway_uri,
    route="chat",
    params={
        "temperature": 0.7,
        "top_p": 0.95,
    },
)

# Generate a response
from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is LangChain?"),
]

response = chat.invoke(messages)
print(response.content)
```
### Embeddings

Use gateway-backed embeddings for vector search, semantic similarity, and RAG applications:
```python
# Embeddings routed through the gateway
embeddings = MlflowAIGatewayEmbeddings(gateway_uri=gateway_uri, route="embeddings")

# Generate embeddings
text_embeddings = embeddings.embed_documents(
    ["This is a document", "This is another document"]
)
query_embedding = embeddings.embed_query("This is a query")
```
### Complete RAG Example

Build a complete retrieval-augmented generation (RAG) system that uses the gateway for both embeddings and chat completions:
```python
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA

# Load documents
loader = TextLoader("path/to/document.txt")
documents = loader.load()

# Split documents
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Create a vector store using the gateway-backed embeddings from above
vectorstore = FAISS.from_documents(docs, embeddings)

# Create a QA chain using the gateway-backed chat model from above
qa_chain = RetrievalQA.from_chain_type(
    llm=chat, chain_type="stuff", retriever=vectorstore.as_retriever()
)

# Query the system
question = "What is the main topic of the document?"
result = qa_chain.run(question)
print(result)
```
## OpenAI Compatibility

The AI Gateway exposes OpenAI-compatible endpoints, so existing OpenAI applications can be migrated with minimal code changes. Note that this example uses the legacy pre-1.0 `openai` SDK interface:
```python
import openai

# Point the OpenAI client at the gateway (requires openai<1.0)
openai.api_base = "http://localhost:5000/gateway/chat"
openai.api_key = "not-needed"  # Gateway handles authentication

# Use the standard OpenAI client
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # Endpoint name in your gateway config
    messages=[{"role": "user", "content": "Hello, AI Gateway!"}],
)
print(response.choices[0].message.content)
```
## MLflow Model Integration

Deploy your own custom models alongside external providers for a unified interface across proprietary and third-party models.

### Registering Models

Train and register your models using MLflow's standard workflow, then expose them through the gateway:
```python
import mlflow
import mlflow.pyfunc

# Log and register a model
with mlflow.start_run():
    # Your model training code here
    mlflow.pyfunc.log_model(
        name="my_model",
        python_model=MyCustomModel(),
        registered_model_name="custom-chat-model",
    )
```
After deploying the model with MLflow model serving, configure it in your gateway `config.yaml`:

```yaml
endpoints:
  - name: custom-model
    endpoint_type: llm/v1/chat
    model:
      provider: mlflow-model-serving
      name: custom-chat-model
      config:
        model_server_url: http://localhost:5001
```
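
Once configured, the custom model can be queried like any other gateway endpoint. A minimal sketch, assuming the gateway from the examples above is running at `http://localhost:5000` with this config loaded:

```python
from mlflow.deployments import get_deploy_client

client = get_deploy_client("http://localhost:5000")  # assumed gateway address

# The custom model answers through the same unified chat interface
response = client.predict(
    endpoint="custom-model",
    inputs={"messages": [{"role": "user", "content": "Hello, custom model!"}]},
)
print(response["choices"][0]["message"]["content"])
```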
## Production Best Practices

### Performance Optimization

- Connection pooling: use persistent HTTP connections for high-throughput applications
- Request batching: group multiple requests together where possible
- Async operations: use async/await for concurrent requests
- Caching: implement response caching for repeated queries (see the sketch below)
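
As one way to implement response caching, here is a minimal sketch that memoizes gateway responses by request payload; the hashing scheme and TTL are illustrative assumptions, not MLflow features:

```python
import hashlib
import json
import time

from mlflow.deployments import get_deploy_client

client = get_deploy_client("http://localhost:5000")  # assumed gateway address

_cache = {}  # cache key -> (timestamp, response)
CACHE_TTL_SECONDS = 300  # illustrative TTL; tune for your workload


def cached_predict(endpoint, inputs):
    # Derive a stable cache key from the endpoint name and request payload
    key = hashlib.sha256(
        json.dumps({"endpoint": endpoint, "inputs": inputs}, sort_keys=True).encode()
    ).hexdigest()

    hit = _cache.get(key)
    if hit is not None and time.time() - hit[0] < CACHE_TTL_SECONDS:
        return hit[1]  # repeated query served from the cache

    response = client.predict(endpoint=endpoint, inputs=inputs)
    _cache[key] = (time.time(), response)
    return response
```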
### Error Handling
```python
import time

from mlflow.deployments import get_deploy_client
from mlflow.exceptions import MlflowException


def robust_query(client, endpoint, inputs, max_retries=3):
    for attempt in range(max_retries):
        try:
            return client.predict(endpoint=endpoint, inputs=inputs)
        except MlflowException as e:
            if attempt < max_retries - 1:
                time.sleep(2**attempt)  # Exponential backoff
                continue
            raise e


# Usage
client = get_deploy_client("http://localhost:5000")
response = robust_query(
    client, "chat", {"messages": [{"role": "user", "content": "Hello"}]}
)
```
### Security

- Use HTTPS in production
- Implement authentication at the application level
- Validate inputs before sending them to the gateway (see the sketch below)
- Monitor usage and implement rate limiting
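
As a sketch of input validation at the application boundary, the checks and size limit below are illustrative assumptions rather than gateway requirements:

```python
MAX_MESSAGE_CHARS = 4000  # illustrative limit; tune to your model's context window


def validate_chat_message(message):
    # Reject empty or oversized payloads before they reach the gateway
    if not isinstance(message, str) or not message.strip():
        raise ValueError("message must be a non-empty string")
    if len(message) > MAX_MESSAGE_CHARS:
        raise ValueError(f"message exceeds {MAX_MESSAGE_CHARS} characters")
    return message.strip()
```

A handler such as the FastAPI `/chat` endpoint above can call this before forwarding the message.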
### Monitoring and Logging
```python
import logging
import time

from mlflow.deployments import get_deploy_client

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def monitored_query(client, endpoint, inputs):
    start_time = time.time()
    try:
        logger.info(f"Querying endpoint: {endpoint}")
        response = client.predict(endpoint=endpoint, inputs=inputs)

        duration = time.time() - start_time
        logger.info(f"Query completed in {duration:.2f}s")
        return response
    except Exception as e:
        duration = time.time() - start_time
        logger.error(f"Query failed after {duration:.2f}s: {e}")
        raise
```
### Load Balancing

For high-availability setups, consider running multiple gateway instances:
```python
import random

from mlflow.deployments import get_deploy_client

# Multiple gateway instances
gateway_urls = ["http://gateway1:5000", "http://gateway2:5000", "http://gateway3:5000"]


def get_client():
    url = random.choice(gateway_urls)
    return get_deploy_client(url)


# Use with automatic failover
def resilient_query(endpoint, inputs, max_retries=3):
    for attempt in range(max_retries):
        try:
            client = get_client()
            return client.predict(endpoint=endpoint, inputs=inputs)
        except Exception as e:
            if attempt < max_retries - 1:
                continue
            raise e
```
### Health and Monitoring
```python
# Check gateway health via HTTP
import requests


def check_gateway_health(gateway_url):
    try:
        # Timeout avoids hanging on an unresponsive gateway
        response = requests.get(f"{gateway_url}/health", timeout=5)
        return {
            "status": response.status_code,
            "healthy": response.status_code == 200,
            "response": response.json() if response.status_code == 200 else None,
        }
    except requests.RequestException as e:
        return {"status": "error", "healthy": False, "error": str(e)}


# Example usage
health = check_gateway_health("http://localhost:5000")
print(f"Gateway Health: {health}")
```