FastAPI Performance Optimization in Practice: From Beginner to Expert

As a developer who has run FastAPI in production for more than two years, I want to share some performance optimization lessons learned from real projects. These techniques helped us cut average API response time from 300ms to under 50ms and increase concurrent throughput fivefold.

Async Programming Optimizations

1. Use Async Functions Correctly

# ❌ Wrong: calling a synchronous HTTP client inside an async function
import requests

@app.get("/users/{user_id}")
async def get_user(user_id: int):
    # This blocks the event loop
    user = requests.get(f"https://api.example.com/users/{user_id}")
    return user.json()

# ✅ Right: use an async HTTP client
import httpx

@app.get("/users/{user_id}")
async def get_user(user_id: int):
    async with httpx.AsyncClient() as client:
        response = await client.get(f"https://api.example.com/users/{user_id}")
        return response.json()
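
Opening a new AsyncClient on every request still pays connection setup costs each time. A further refinement is to share one client for the whole application lifetime, for example via FastAPI's lifespan handler. This is a minimal sketch, not code from the original post; the `app.state.http` attribute name is an arbitrary choice.

from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request

@asynccontextmanager
async def lifespan(app: FastAPI):
    # One shared client with connection pooling for the whole app lifetime
    app.state.http = httpx.AsyncClient(timeout=10.0)
    yield
    await app.state.http.aclose()

app = FastAPI(lifespan=lifespan)

@app.get("/users/{user_id}")
async def get_user(user_id: int, request: Request):
    client: httpx.AsyncClient = request.app.state.http
    response = await client.get(f"https://api.example.com/users/{user_id}")
    return response.json()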

2. Run Multiple Async Operations Concurrently

import asyncio

@app.get("/dashboard")
async def get_dashboard_data(user_id: int):
    # ❌ Sequential execution: total time = sum of all operations
    # user_info = await get_user_info(user_id)
    # user_orders = await get_user_orders(user_id)
    # user_notifications = await get_user_notifications(user_id)

    # ✅ Concurrent execution: total time ≈ the slowest single operation
    user_info, user_orders, user_notifications = await asyncio.gather(
        get_user_info(user_id),
        get_user_orders(user_id),
        get_user_notifications(user_id)
    )

    return {
        "user": user_info,
        "orders": user_orders,
        "notifications": user_notifications
    }

# More advanced concurrency control
async def fetch_user_data_with_timeout(user_id: int):
    try:
        # Set a timeout so slow queries cannot drag down overall latency
        async with asyncio.timeout(5.0):  # Python 3.11+
            return await asyncio.gather(
                get_user_info(user_id),
                get_user_orders(user_id),
                get_user_notifications(user_id),
                return_exceptions=True  # one failure does not break the others
            )
    except asyncio.TimeoutError:
        # Fall back to partial or cached data
        return await get_cached_user_data(user_id)
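
When fanning out to many coroutines at once (say, one call per item in a large list), an unbounded gather can overwhelm downstream services. A common complement to the pattern above is to cap concurrency with asyncio.Semaphore. This is a generic sketch; fetch_item and item_ids are hypothetical names, not part of the original article.

import asyncio

async def fetch_all_limited(item_ids: list[int], max_concurrency: int = 10):
    semaphore = asyncio.Semaphore(max_concurrency)

    async def fetch_one(item_id: int):
        # At most `max_concurrency` fetches run at the same time
        async with semaphore:
            return await fetch_item(item_id)  # hypothetical downstream call

    return await asyncio.gather(*(fetch_one(i) for i in item_ids))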

3. Async Generators for Large Datasets

from fastapi.responses import StreamingResponse
import json

@app.get("/export/users")
async def export_users():
    """Stream the user export to avoid running out of memory."""

    async def generate_user_data():
        yield '{"users": ['

        first = True
        async for user in get_users_stream():  # async generator
            if not first:
                yield ','
            else:
                first = False

            yield json.dumps({
                "id": user.id,
                "name": user.name,
                "email": user.email
            })

        yield ']}'

    return StreamingResponse(
        generate_user_data(),
        media_type="application/json",
        headers={"Content-Disposition": "attachment; filename=users.json"}
    )

# Async generator over the database
async def get_users_stream():
    """Fetch users in batches instead of loading everything at once."""
    offset = 0
    batch_size = 1000

    while True:
        users = await db.execute(
            select(User).offset(offset).limit(batch_size)
        )
        user_list = users.scalars().all()

        if not user_list:
            break

        for user in user_list:
            yield user

        offset += batch_size
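
Repeated OFFSET/LIMIT scans get progressively slower as the offset grows. If you are on SQLAlchemy 2.x async, an alternative worth considering is server-side streaming with AsyncSession.stream and the yield_per execution option; treat this as a sketch under that assumption rather than a drop-in replacement for the version above.

from sqlalchemy import select

async def get_users_stream():
    # Stream rows from the server in chunks of 1000 instead of paging with OFFSET
    result = await db.stream(
        select(User).execution_options(yield_per=1000)
    )
    async for user in result.scalars():
        yield user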

Database Optimization

1. Connection Pool Tuning

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker

# Tuned database engine configuration
engine = create_async_engine(
    DATABASE_URL,
    # Connection pool settings (async engines use AsyncAdaptedQueuePool by default;
    # passing QueuePool explicitly is not supported for asyncio engines)
    pool_size=20,          # pool size
    max_overflow=30,       # maximum overflow connections
    pool_pre_ping=True,    # validate connections before use
    pool_recycle=3600,     # recycle connections after this many seconds

    # Performance
    echo=False,            # disable SQL logging in production
    future=True,

    # Driver-level connection options (asyncpg / PostgreSQL)
    connect_args={
        "server_settings": {
            "application_name": "fastapi_app",
            "jit": "off",  # for simple queries, disabling JIT can be faster
        }
    }
)

AsyncSessionLocal = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False
)
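
To actually hand sessions to route handlers, the usual pattern is a dependency that opens one session per request from this factory. A minimal sketch (the `get_db` name is just a convention, not something defined in the original post):

from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession

async def get_db():
    # One session per request, closed automatically when the response is done
    async with AsyncSessionLocal() as session:
        yield session

@app.get("/users/{user_id}")
async def read_user(user_id: int, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(User).where(User.id == user_id))
    return result.scalar_one()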

2. Query Optimization

from sqlalchemy.orm import selectinload, joinedload

# ❌ The N+1 query problem
@app.get("/posts")
async def get_posts():
    posts = await db.execute(select(Post))
    result = []
    for post in posts.scalars():
        # Each post triggers one extra query
        author = await db.execute(select(User).where(User.id == post.author_id))
        result.append({
            "post": post,
            "author": author.scalar_one()
        })
    return result

# ✅ Solve N+1 with eager loading
@app.get("/posts")
async def get_posts():
    posts = await db.execute(
        select(Post)
        .options(joinedload(Post.author))  # fetch posts and authors in one query
        .limit(50)  # cap the number of rows returned
    )

    return [
        {
            "post": post,
            "author": post.author
        }
        for post in posts.scalars().unique()
    ]

# Optimizing more complex relationships
@app.get("/users/{user_id}/full-profile")
async def get_user_full_profile(user_id: int):
    user = await db.execute(
        select(User)
        .options(
            selectinload(User.posts).selectinload(Post.comments),
            selectinload(User.followers),
            joinedload(User.profile)
        )
        .where(User.id == user_id)
    )

    return user.scalar_one()
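
Eager loading fixes the query count, but wide ORM rows still cost time to fetch and serialize. When an endpoint only needs a few columns, loading just those can help. A small sketch using SQLAlchemy's load_only; the column choices and route path here are illustrative, not from the original article.

from sqlalchemy.orm import load_only

@app.get("/posts/titles")
async def get_post_titles():
    posts = await db.execute(
        select(Post)
        .options(load_only(Post.id, Post.title))  # skip large body/content columns
        .order_by(Post.created_at.desc())
        .limit(50)
    )
    return [{"id": p.id, "title": p.title} for p in posts.scalars()]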

3. Batch Operation Optimization

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import insert

@app.post("/users/batch")
async def create_users_batch(users: List[UserCreate]):
    # ❌ Insert one row at a time
    # for user_data in users:
    #     user = User(**user_data.dict())
    #     db.add(user)
    # await db.commit()

    # ✅ Bulk insert
    user_dicts = [user.dict() for user in users]
    await db.execute(insert(User), user_dicts)
    await db.commit()

# PostgreSQL-specific UPSERT
@app.post("/users/upsert")
async def upsert_users(users: List[UserCreate]):
    stmt = insert(User).values([user.dict() for user in users])
    stmt = stmt.on_conflict_do_update(
        index_elements=['email'],
        set_=dict(
            name=stmt.excluded.name,
            updated_at=func.now()
        )
    )
    await db.execute(stmt)
    await db.commit()
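
The same executemany idea applies to updates. In SQLAlchemy 2.x, passing update() together with a list of parameter dicts that each contain the primary key performs a bulk UPDATE by primary key; this sketch assumes that version and a hypothetical UserUpdate schema that includes the id field.

from sqlalchemy import update

@app.post("/users/batch-update")
async def update_users_batch(users: List[UserUpdate]):
    # Each dict must include the primary key ("id") plus the columns to change
    await db.execute(
        update(User),
        [user.dict() for user in users]
    )
    await db.commit()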

Caching Strategies

1. Redis Caching

import redis.asyncio as redis
import json
from functools import wraps
import hashlib

# Redis connection pool
redis_pool = redis.ConnectionPool.from_url(
    "redis://localhost:6379",
    max_connections=20,
    retry_on_timeout=True
)
redis_client = redis.Redis(connection_pool=redis_pool)

def cache_result(expire_time: int = 300):
    """Caching decorator."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Build the cache key from the function name and its arguments
            cache_key = f"{func.__name__}:{hashlib.md5(str(args + tuple(kwargs.items())).encode()).hexdigest()}"

            # Try the cache first
            cached_result = await redis_client.get(cache_key)
            if cached_result:
                return json.loads(cached_result)

            # Run the function and cache the result
            result = await func(*args, **kwargs)
            await redis_client.setex(
                cache_key,
                expire_time,
                json.dumps(result, default=str)
            )

            return result
        return wrapper
    return decorator

@cache_result(expire_time=600)  # cache for 10 minutes
async def get_popular_posts():
    posts = await db.execute(
        select(Post)
        .where(Post.view_count > 1000)
        .order_by(Post.view_count.desc())
        .limit(10)
    )
    return [post.dict() for post in posts.scalars()]
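
One practical wrinkle: the decorator hashes its arguments into opaque keys, so invalidating a specific entry later is awkward. A simple option is to clear everything cached for a given function by matching on its name prefix; a sketch using redis-py's async scan_iter, not something from the original post.

async def invalidate_cached(func_name: str):
    # Delete every entry written by @cache_result for this function
    async for key in redis_client.scan_iter(match=f"{func_name}:*"):
        await redis_client.delete(key)

# e.g. after posts change significantly:
# await invalidate_cached("get_popular_posts")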

2. Tiered Caching

from typing import Any, Optional
import time

class CacheManager:
    def __init__(self):
        self.memory_cache = {}            # in-process memory cache
        self.redis_client = redis_client  # Redis cache

    async def get(self, key: str) -> Optional[Any]:
        # L1: in-process memory cache
        if key in self.memory_cache:
            data, expire_time = self.memory_cache[key]
            if time.time() < expire_time:
                return data
            else:
                del self.memory_cache[key]

        # L2: Redis cache
        redis_data = await self.redis_client.get(key)
        if redis_data:
            data = json.loads(redis_data)
            # Backfill the memory cache
            self.memory_cache[key] = (data, time.time() + 60)  # keep in memory for 1 minute
            return data

        return None

    async def set(self, key: str, value: Any, expire_time: int = 300):
        # Write both cache tiers
        self.memory_cache[key] = (value, time.time() + min(expire_time, 60))
        await self.redis_client.setex(key, expire_time, json.dumps(value, default=str))

cache_manager = CacheManager()

@app.get("/posts/{post_id}")
async def get_post(post_id: int):
    cache_key = f"post:{post_id}"

    # Try the cache first
    cached_post = await cache_manager.get(cache_key)
    if cached_post:
        return cached_post

    # Fall back to the database
    post = await db.execute(select(Post).where(Post.id == post_id))
    post_data = post.scalar_one().dict()

    # Cache the result
    await cache_manager.set(cache_key, post_data, expire_time=600)

    return post_data
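
One caveat with the plain dict used as the L1 tier: entries are only evicted when they happen to be read after expiring, so the cache can grow without bound under many distinct keys. If you are willing to add the third-party cachetools package (an assumption, not part of the original setup), a size-bounded TTLCache is a natural substitute.

from cachetools import TTLCache

# Bounded L1 cache: at most 10,000 entries, each expiring after 60 seconds.
# With TTLCache handling expiry, get/set can store values directly instead of
# (value, expire_time) tuples.
memory_cache = TTLCache(maxsize=10_000, ttl=60)

memory_cache["post:1"] = {"id": 1, "title": "hello"}
print(memory_cache.get("post:1"))  # {'id': 1, 'title': 'hello'}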

3. Cache Warm-up and Invalidation

from apscheduler.schedulers.asyncio import AsyncIOScheduler

scheduler = AsyncIOScheduler()

@scheduler.scheduled_job('interval', minutes=30)
async def warm_up_cache():
    """Periodically pre-warm hot data."""

    # Warm popular posts
    popular_posts = await get_popular_posts()
    await cache_manager.set("popular_posts", popular_posts, expire_time=1800)

    # Warm user statistics
    user_stats = await get_user_statistics()
    await cache_manager.set("user_stats", user_stats, expire_time=3600)

# Cache invalidation
@app.post("/posts/{post_id}")
async def update_post(post_id: int, post_data: PostUpdate):
    # Update the database
    await db.execute(
        update(Post)
        .where(Post.id == post_id)
        .values(**post_data.dict())
    )
    await db.commit()

    # Invalidate related cache entries
    await redis_client.delete(f"post:{post_id}")
    await redis_client.delete("popular_posts")  # the update may affect the popular list

    return {"message": "Post updated successfully"}

Response Optimization

1. Response Compression

from fastapi.middleware.gzip import GZipMiddleware

app.add_middleware(GZipMiddleware, minimum_size=1000)

# Custom compression middleware (Brotli with gzip fallback).
# Note: use either GZipMiddleware or the custom middleware below, not both.
import gzip
import brotli
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

class CompressionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        response = await call_next(request)

        # call_next returns a streaming response, so collect the body first
        body = b"".join([chunk async for chunk in response.body_iterator])

        # Check which encodings the client accepts
        accept_encoding = request.headers.get("accept-encoding", "")

        if "br" in accept_encoding and len(body) > 1000:
            # Brotli generally achieves a better compression ratio
            body = brotli.compress(body)
            response.headers["content-encoding"] = "br"
        elif "gzip" in accept_encoding and len(body) > 1000:
            body = gzip.compress(body)
            response.headers["content-encoding"] = "gzip"

        response.headers["content-length"] = str(len(body))
        return Response(
            content=body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type
        )

app.add_middleware(CompressionMiddleware)

2. Pagination Optimization

from typing import Optional
from fastapi import Depends, Query
from sqlalchemy import func

class PaginationParams:
    def __init__(
        self,
        page: int = Query(1, ge=1, description="Page number"),
        size: int = Query(20, ge=1, le=100, description="Page size"),
        cursor: Optional[str] = Query(None, description="Cursor for cursor-based pagination")
    ):
        self.page = page
        self.size = size
        self.cursor = cursor
        self.offset = (page - 1) * size

# Offset pagination (fine for small datasets)
@app.get("/posts")
async def get_posts(pagination: PaginationParams = Depends()):
    posts = await db.execute(
        select(Post)
        .order_by(Post.created_at.desc())
        .offset(pagination.offset)
        .limit(pagination.size)
    )

    # Total count (beware: this query can be slow on large tables;
    # see the count-caching sketch after this block)
    total = await db.execute(select(func.count(Post.id)))
    total_count = total.scalar()

    return {
        "posts": posts.scalars().all(),
        "pagination": {
            "page": pagination.page,
            "size": pagination.size,
            "total": total_count,
            "pages": (total_count + pagination.size - 1) // pagination.size
        }
    }

# Cursor pagination (better for large datasets)
@app.get("/posts/cursor")
async def get_posts_cursor(pagination: PaginationParams = Depends()):
    query = select(Post).order_by(Post.id.desc()).limit(pagination.size + 1)

    if pagination.cursor:
        # Decode the cursor
        cursor_id = int(pagination.cursor)
        query = query.where(Post.id < cursor_id)

    posts = await db.execute(query)
    post_list = posts.scalars().all()

    # Is there another page?
    has_next = len(post_list) > pagination.size
    if has_next:
        post_list = post_list[:-1]

    # Cursor for the next page
    next_cursor = str(post_list[-1].id) if post_list and has_next else None

    return {
        "posts": post_list,
        "pagination": {
            "size": pagination.size,
            "has_next": has_next,
            "next_cursor": next_cursor
        }
    }
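
As noted in the offset version, COUNT(*) over a large table is often the most expensive part of the endpoint. If a slightly stale total is acceptable, caching it for a short TTL removes that cost from most requests. A sketch reusing the Redis client from the caching section; the key name and TTL are arbitrary choices.

async def get_post_count() -> int:
    cached = await redis_client.get("post_count")
    if cached is not None:
        return int(cached)

    result = await db.execute(select(func.count(Post.id)))
    count = result.scalar()
    # A 60-second TTL keeps the number roughly current without hammering the table
    await redis_client.setex("post_count", 60, count)
    return count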

Monitoring and Performance Analysis

1. Performance Monitoring Middleware

import time
import logging
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

class PerformanceMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        start_time = time.time()

        # Log the incoming request
        logger.info(f"Request started: {request.method} {request.url}")

        response = await call_next(request)

        # Compute processing time
        process_time = time.time() - start_time

        # Expose it as a response header
        response.headers["X-Process-Time"] = str(process_time)

        # Log slow requests
        if process_time > 1.0:  # requests slower than 1 second
            logger.warning(
                f"Slow request: {request.method} {request.url} "
                f"took {process_time:.2f}s"
            )

        # Ship metrics to a monitoring system
        # metrics_client.timing('api.request_duration', process_time * 1000)
        # metrics_client.increment(f'api.requests.{response.status_code}')

        return response

app.add_middleware(PerformanceMiddleware)
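
The metrics_client lines above are placeholders; one concrete option is the prometheus_client library. The sketch below records the same duration into a histogram and exposes the metrics on a separate endpoint. This is an assumption about tooling, not what the original article uses, and the metric names and route path are illustrative.

from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from fastapi import Response

REQUEST_DURATION = Histogram(
    "api_request_duration_seconds", "Request duration", ["method", "path"]
)
REQUEST_COUNT = Counter(
    "api_requests_total", "Request count", ["method", "path", "status"]
)

# Inside PerformanceMiddleware.dispatch, after process_time is computed:
# REQUEST_DURATION.labels(request.method, request.url.path).observe(process_time)
# REQUEST_COUNT.labels(request.method, request.url.path, str(response.status_code)).inc()

@app.get("/prometheus")
async def prometheus_metrics():
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)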

2. Health Check and Metrics Endpoints

import psutil
from sqlalchemy import text

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    checks = {}

    # Database connectivity
    try:
        await db.execute(text("SELECT 1"))
        checks["database"] = {"status": "healthy"}
    except Exception as e:
        checks["database"] = {"status": "unhealthy", "error": str(e)}

    # Redis connectivity
    try:
        await redis_client.ping()
        checks["redis"] = {"status": "healthy"}
    except Exception as e:
        checks["redis"] = {"status": "unhealthy", "error": str(e)}

    # System resources
    cpu_percent = psutil.cpu_percent()
    memory_percent = psutil.virtual_memory().percent

    checks["system"] = {
        "status": "healthy" if cpu_percent < 80 and memory_percent < 80 else "warning",
        "cpu_percent": cpu_percent,
        "memory_percent": memory_percent
    }

    overall_status = "healthy" if all(
        check["status"] == "healthy" for check in checks.values()
    ) else "unhealthy"

    return {
        "status": overall_status,
        "checks": checks,
        "timestamp": time.time()
    }

@app.get("/metrics")
async def get_metrics():
    """Application metrics endpoint."""
    return {
        "active_connections": engine.pool.checkedout(),    # returns a count
        "pool_size": engine.pool.size(),
        "checked_in_connections": engine.pool.checkedin(), # returns a count
        "memory_usage": psutil.Process().memory_info().rss / 1024 / 1024,  # MB
        "cpu_percent": psutil.Process().cpu_percent(),
    }

Deployment Optimization

1. Gunicorn Configuration

# gunicorn.conf.py
import multiprocessing

# Basic settings
bind = "0.0.0.0:8000"
workers = multiprocessing.cpu_count() * 2 + 1
worker_class = "uvicorn.workers.UvicornWorker"
worker_connections = 1000

# Performance
max_requests = 1000        # restart a worker after this many requests
max_requests_jitter = 100  # randomize restart timing
preload_app = True         # preload the application

# Timeouts
timeout = 30
keepalive = 5
graceful_timeout = 30

# Logging
accesslog = "-"
errorlog = "-"
loglevel = "info"
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" %(D)s'

# Process naming
proc_name = "fastapi_app"

def when_ready(server):
    server.log.info("Server is ready. Spawning workers")

def worker_int(worker):
    worker.log.info("worker received INT or QUIT signal")

def pre_fork(server, worker):
    server.log.info("Worker spawned (pid: %s)", worker.pid)

2. Docker Optimization

# Dockerfile
FROM python:3.11-slim

# System dependencies (curl is required by the HEALTHCHECK below)
RUN apt-get update && apt-get install -y \
    gcc \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Working directory
WORKDIR /app

# Copy the dependency manifest first to leverage layer caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Run as a non-root user
RUN useradd --create-home --shell /bin/bash app \
    && chown -R app:app /app
USER app

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Start command
CMD ["gunicorn", "-c", "gunicorn.conf.py", "main:app"]

Summary

With these techniques, our FastAPI application's performance improved significantly:

  1. Async programming: use async/await correctly and avoid blocking the event loop
  2. Database optimization: connection pool tuning, query optimization, batch operations
  3. Caching strategies: tiered caches, warm-up, invalidation
  4. Response optimization: compression, pagination, streaming responses
  5. Monitoring and analysis: performance middleware, health checks, metrics collection
  6. Deployment optimization: Gunicorn configuration, Docker tuning

Remember that performance optimization is an ongoing process: adjust your strategy based on real business scenarios and monitoring data. Don't optimize prematurely; get the functionality right first, then target the actual bottlenecks.

Do you have your own FastAPI performance tips, or problems you've run into? Share them in the comments!
