使用 pipelines 为 Web 服务器提供推理服务¶
创建一个推理引擎是一个复杂的话题,"最佳"解决方案通常取决于你的具体需求。你是在使用 CPU 还是 GPU?你追求的是最低延迟、最高吞吐量、支持多个模型,还是只优化一个特定模型?虽然有多种方法可以解决这个问题,但我们提供的是一种通用的入门方案,可能不是最优化的解决方案,但适合初学者使用。
关键概念¶
我们可以使用一个迭代器来处理请求,就像处理数据集一样,因为 Web 服务器本质上是一个等待请求并逐个处理它们的系统。通常,Web 服务器是多路复用的(多线程、异步等),以并发处理多个请求。然而,推理 pipelines(尤其是底层模型)不太适合并行处理,它们会占用大量内存,所以最好在运行时给它们提供所有可用资源,尤其是计算密集型任务。
方案设计¶
我们将通过让 Web 服务器处理接收和发送请求的轻量级任务,同时使用单线程来处理实际的推理任务。这里我们使用 starlette 框架作为示例,但如果你使用其他框架,可能需要调整代码以实现相同的效果。
创建 server.py¶
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline
import asyncio
# 定义主页路由处理函数
async def homepage(request):
# 获取请求体
payload = await request.body()
string = payload.decode("utf-8")
# 创建一个队列用于存储推理结果
response_q = asyncio.Queue()
# 将请求字符串和队列放入模型队列中
await request.app.model_queue.put((string, response_q))
# 从队列中获取推理结果并返回
output = await response_q.get()
return JSONResponse(output)
# 定义服务器循环处理函数
async def server_loop(q):
# 初始化推理 pipelines
pipe = pipeline(model="google-bert/bert-base-uncased")
while True:
# 从队列中获取请求字符串和结果队列
(string, response_q) = await q.get()
# 执行推理并返回结果
out = pipe(string)
await response_q.put(out)
# 创建 Starlette 应用实例
app = Starlette(
routes=[
Route("/", homepage, methods=["POST"]),
],
)
# 定义启动事件
@app.on_event("startup")
async def startup_event():
q = asyncio.Queue() # 创建队列
app.model_queue = q # 将队列绑定到应用实例
asyncio.create_task(server_loop(q)) # 创建推理循环任务
启动服务器¶
uvicorn server:app
测试服务器¶
curl -X POST -d "test [MASK]" http://localhost:8000/
响应示例:
(string, rq) = await q.get()
strings = []
queues = []
while True:
try:
# 尝试在1毫秒内获取队列中的请求
(string, rq) = await asyncio.wait_for(q.get(), timeout=0.001) # 1ms
except asyncio.exceptions.TimeoutError:
break
strings.append(string)
queues.append(rq)
# 批量推理
outs = pipe(strings, batch_size=len(strings))
# 将结果放入各自的队列中
for rq, out in zip(queues, outs):
await rq.put(out)
需要考虑的其他问题¶
错误检查¶
生产环境中可能会遇到各种问题,例如内存不足、磁盘空间不足、模型加载失败、请求格式错误、模型配置错误等。通常,服务器应该向用户返回错误信息,因此在代码中添加 try..except 语句来捕获并显示错误是一个好主意。但请记住,暴露所有错误信息可能会带来安全风险,具体取决于你的安全要求。
断路器¶
Web 服务器在过载时返回适当的错误(例如 503 或 504)而不是无休止地等待请求,通常会让服务器表现得更好。可以通过检查队列大小来提前返回错误,以防止服务器在高负载下崩溃。
阻塞主线程¶
目前,PyTorch 不支持异步处理,计算任务会阻塞主线程。因此,最好将 PyTorch 运行在独立的线程或进程中。这虽然会使代码复杂化,但在单个推理任务耗时较长时非常重要。
动态批处理¶
批处理并不总是比逐个处理更有效(详细信息见 批处理详情)。但在某些情况下,例如处理非常大的模型(如 BLOOM)时,动态批处理是提高推理效率的关键。