FastAPI Performance Optimization in Practice: From Beginner to Advanced
After more than two years of running FastAPI in production, I want to share some performance-optimization lessons learned from real projects. These techniques helped us cut average API response time from 300ms to under 50ms and increase concurrent throughput roughly fivefold.
Async Programming Optimization
1. Use Async Functions Correctly
```python
import requests
import httpx

# Bad: a blocking HTTP call inside an async endpoint stalls the entire event loop
@app.get("/users/{user_id}")
async def get_user(user_id: int):
    user = requests.get(f"https://api.example.com/users/{user_id}")
    return user.json()

# Good: use an async HTTP client and await the request
@app.get("/users/{user_id}")
async def get_user(user_id: int):
    async with httpx.AsyncClient() as client:
        response = await client.get(f"https://api.example.com/users/{user_id}")
        return response.json()
```
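Sometimes a synchronous-only library cannot be avoided. In that case one option is to push the blocking call onto a worker thread so the event loop stays responsive. Below is a minimal sketch using Starlette's `run_in_threadpool` helper; the `fetch_user_sync` function and the `/users/{user_id}/sync-lib` route are hypothetical, not part of the example above.

```python
import requests
from fastapi.concurrency import run_in_threadpool

def fetch_user_sync(user_id: int) -> dict:
    # Hypothetical blocking client that we cannot rewrite as async
    return requests.get(f"https://api.example.com/users/{user_id}").json()

@app.get("/users/{user_id}/sync-lib")
async def get_user_via_sync_lib(user_id: int):
    # Run the blocking call in a thread pool so the event loop is not blocked
    return await run_in_threadpool(fetch_user_sync, user_id)
```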
2. Run Multiple Async Operations Concurrently
```python
import asyncio

@app.get("/dashboard")
async def get_dashboard_data(user_id: int):
    # Bad: sequential awaits, total latency is the sum of the three calls
    # user_info = await get_user_info(user_id)
    # user_orders = await get_user_orders(user_id)
    # user_notifications = await get_user_notifications(user_id)

    # Good: the three calls are independent, so run them concurrently
    user_info, user_orders, user_notifications = await asyncio.gather(
        get_user_info(user_id),
        get_user_orders(user_id),
        get_user_notifications(user_id)
    )
    return {
        "user": user_info,
        "orders": user_orders,
        "notifications": user_notifications
    }

async def fetch_user_data_with_timeout(user_id: int):
    # Guard the whole fan-out with a timeout and fall back to cached data
    try:
        async with asyncio.timeout(5.0):  # asyncio.timeout requires Python 3.11+
            return await asyncio.gather(
                get_user_info(user_id),
                get_user_orders(user_id),
                get_user_notifications(user_id),
                return_exceptions=True
            )
    except asyncio.TimeoutError:
        return await get_cached_user_data(user_id)
```
3. Use Async Generators for Large Datasets
```python
from fastapi.responses import StreamingResponse
import json

@app.get("/export/users")
async def export_users():
    """Stream the user export so the full dataset never sits in memory."""
    async def generate_user_data():
        yield '{"users": ['
        first = True
        async for user in get_users_stream():
            if not first:
                yield ','
            else:
                first = False
            yield json.dumps({
                "id": user.id,
                "name": user.name,
                "email": user.email
            })
        yield ']}'

    return StreamingResponse(
        generate_user_data(),
        media_type="application/json",
        headers={"Content-Disposition": "attachment; filename=users.json"}
    )

async def get_users_stream():
    """Fetch users in batches instead of loading everything at once."""
    offset = 0
    batch_size = 1000
    while True:
        users = await db.execute(
            select(User).offset(offset).limit(batch_size)
        )
        user_list = users.scalars().all()
        if not user_list:
            break
        for user in user_list:
            yield user
        offset += batch_size
```
Database Optimization
1. Connection Pool Tuning
```python
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker

engine = create_async_engine(
    DATABASE_URL,
    # Async engines use AsyncAdaptedQueuePool by default;
    # the sync QueuePool cannot be combined with an asyncio engine.
    pool_size=20,
    max_overflow=30,
    pool_pre_ping=True,   # check connections before handing them out
    pool_recycle=3600,    # recycle connections after one hour
    echo=False,
    connect_args={
        # asyncpg-specific server settings
        "server_settings": {
            "application_name": "fastapi_app",
            "jit": "off",
        }
    }
)

AsyncSessionLocal = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False
)
```
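The snippets in this post use a module-level `db` session for brevity; in a real application the pooled engine is usually consumed through a per-request session dependency. A minimal sketch of that wiring, assuming the `User` model from the other examples (the `get_db` name and the route are my own, illustrative only):

```python
from collections.abc import AsyncIterator
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession

async def get_db() -> AsyncIterator[AsyncSession]:
    # One session per request, returned to the pool when the request ends
    async with AsyncSessionLocal() as session:
        yield session

@app.get("/db/users/{user_id}")
async def read_user(user_id: int, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(User).where(User.id == user_id))
    return result.scalar_one()
```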
2. Query Optimization
```python
from sqlalchemy.orm import selectinload, joinedload

# Bad: N+1 queries -- one extra round trip per post to load its author
@app.get("/posts")
async def get_posts():
    posts = await db.execute(select(Post))
    result = []
    for post in posts.scalars():
        author = await db.execute(select(User).where(User.id == post.author_id))
        result.append({
            "post": post,
            "author": author.scalar_one()
        })
    return result

# Good: eager-load the author in the same query
@app.get("/posts")
async def get_posts():
    posts = await db.execute(
        select(Post)
        .options(joinedload(Post.author))
        .limit(50)
    )
    return [
        {"post": post, "author": post.author}
        for post in posts.scalars().unique()
    ]

# Combine loading strategies for nested relationships
@app.get("/users/{user_id}/full-profile")
async def get_user_full_profile(user_id: int):
    user = await db.execute(
        select(User)
        .options(
            selectinload(User.posts).selectinload(Post.comments),
            selectinload(User.followers),
            joinedload(User.profile)
        )
        .where(User.id == user_id)
    )
    return user.scalar_one()
```
3. Bulk Operation Optimization
```python
from typing import List
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import insert

# Bad: one INSERT per user
# for user_data in users:
#     db.add(User(**user_data.dict()))
# await db.commit()

# Good: a single bulk INSERT
@app.post("/users/batch")
async def create_users_batch(users: List[UserCreate]):
    user_dicts = [user.dict() for user in users]
    await db.execute(insert(User), user_dicts)
    await db.commit()

# Upsert: insert or update in one statement (PostgreSQL ON CONFLICT)
@app.post("/users/upsert")
async def upsert_users(users: List[UserCreate]):
    stmt = insert(User).values([user.dict() for user in users])
    stmt = stmt.on_conflict_do_update(
        index_elements=['email'],
        set_=dict(
            name=stmt.excluded.name,
            updated_at=func.now()
        )
    )
    await db.execute(stmt)
    await db.commit()
```
Caching Strategy
1. Redis Caching
```python
import redis.asyncio as redis
import json
import hashlib
from functools import wraps

redis_pool = redis.ConnectionPool.from_url(
    "redis://localhost:6379",
    max_connections=20,
    retry_on_timeout=True
)
redis_client = redis.Redis(connection_pool=redis_pool)

def cache_result(expire_time: int = 300):
    """Cache a coroutine's JSON-serializable result in Redis."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Key derived from the function name plus a hash of its arguments
            cache_key = f"{func.__name__}:{hashlib.md5(str(args + tuple(kwargs.items())).encode()).hexdigest()}"
            cached_result = await redis_client.get(cache_key)
            if cached_result:
                return json.loads(cached_result)
            result = await func(*args, **kwargs)
            await redis_client.setex(
                cache_key,
                expire_time,
                json.dumps(result, default=str)
            )
            return result
        return wrapper
    return decorator

@cache_result(expire_time=600)
async def get_popular_posts():
    posts = await db.execute(
        select(Post)
        .where(Post.view_count > 1000)
        .order_by(Post.view_count.desc())
        .limit(10)
    )
    return [post.dict() for post in posts.scalars()]
```
2. Layered Caching Strategy
```python
import time
from typing import Any, Optional

class CacheManager:
    """L1 in-process cache in front of an L2 Redis cache."""

    def __init__(self):
        self.memory_cache = {}
        self.redis_client = redis_client

    async def get(self, key: str) -> Optional[Any]:
        # L1: check the in-process cache first
        if key in self.memory_cache:
            data, expire_time = self.memory_cache[key]
            if time.time() < expire_time:
                return data
            del self.memory_cache[key]
        # L2: fall back to Redis and backfill the memory cache
        redis_data = await self.redis_client.get(key)
        if redis_data:
            data = json.loads(redis_data)
            self.memory_cache[key] = (data, time.time() + 60)
            return data
        return None

    async def set(self, key: str, value: Any, expire_time: int = 300):
        # Memory entries live at most 60 seconds; Redis keeps the full TTL
        self.memory_cache[key] = (value, time.time() + min(expire_time, 60))
        await self.redis_client.setex(key, expire_time, json.dumps(value, default=str))

cache_manager = CacheManager()

@app.get("/posts/{post_id}")
async def get_post(post_id: int):
    cache_key = f"post:{post_id}"
    cached_post = await cache_manager.get(cache_key)
    if cached_post:
        return cached_post
    post = await db.execute(select(Post).where(Post.id == post_id))
    post_data = post.scalar_one().dict()
    await cache_manager.set(cache_key, post_data, expire_time=600)
    return post_data
```
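One caveat worth adding: the invalidation in the next subsection deletes keys from Redis only, so each worker's in-memory layer can keep serving a stale entry for up to 60 seconds. A minimal sketch of an explicit `delete` method on `CacheManager` (my own addition, not in the original) that evicts from both layers:

```python
class CacheManager:
    # ... __init__, get and set as defined above ...

    async def delete(self, key: str) -> None:
        """Evict a key from both cache layers."""
        self.memory_cache.pop(key, None)      # L1: in-process cache
        await self.redis_client.delete(key)   # L2: Redis
```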
3. Cache Warm-Up and Invalidation
```python
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from sqlalchemy import update

scheduler = AsyncIOScheduler()

@scheduler.scheduled_job('interval', minutes=30)
async def warm_up_cache():
    """Periodically refresh hot data in the cache."""
    popular_posts = await get_popular_posts()
    await cache_manager.set("popular_posts", popular_posts, expire_time=1800)
    user_stats = await get_user_statistics()
    await cache_manager.set("user_stats", user_stats, expire_time=3600)

@app.post("/posts/{post_id}")
async def update_post(post_id: int, post_data: PostUpdate):
    await db.execute(
        update(Post)
        .where(Post.id == post_id)
        .values(**post_data.dict())
    )
    await db.commit()
    # Invalidate cached entries that may now be stale
    await redis_client.delete(f"post:{post_id}")
    await redis_client.delete("popular_posts")
    return {"message": "Post updated successfully"}
```
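The snippet above only registers the job; the scheduler still has to be started when the application boots. A minimal sketch using startup/shutdown hooks (newer FastAPI versions prefer the `lifespan` context manager, but the idea is the same):

```python
@app.on_event("startup")
async def start_scheduler():
    # Start the warm-up scheduler once the event loop is running
    scheduler.start()

@app.on_event("shutdown")
async def stop_scheduler():
    scheduler.shutdown()
```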
Response Optimization
1. Response Compression
```python
from fastapi.middleware.gzip import GZipMiddleware

# Built-in gzip compression for responses larger than 1 KB
app.add_middleware(GZipMiddleware, minimum_size=1000)

# Custom middleware with Brotli support, falling back to gzip
import gzip
import brotli
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

class CompressionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        response = await call_next(request)
        accept_encoding = request.headers.get("accept-encoding", "")
        # Responses from call_next are streamed, so collect the body first
        body = b"".join([chunk async for chunk in response.body_iterator])
        if "br" in accept_encoding and len(body) > 1000:
            body = brotli.compress(body)
            response.headers["content-encoding"] = "br"
        elif "gzip" in accept_encoding and len(body) > 1000:
            body = gzip.compress(body)
            response.headers["content-encoding"] = "gzip"
        response.headers["content-length"] = str(len(body))
        return Response(
            content=body,
            status_code=response.status_code,
            headers=dict(response.headers),
        )

app.add_middleware(CompressionMiddleware)
```

Note that the custom middleware buffers the entire response body in memory before compressing, so for most APIs the built-in GZipMiddleware is the simpler and safer choice.
2. Pagination Optimization
```python
from typing import Optional
from fastapi import Depends, Query

class PaginationParams:
    def __init__(
        self,
        page: int = Query(1, ge=1, description="Page number"),
        size: int = Query(20, ge=1, le=100, description="Page size"),
        cursor: Optional[str] = Query(None, description="Cursor for cursor-based pagination")
    ):
        self.page = page
        self.size = size
        self.cursor = cursor
        self.offset = (page - 1) * size

# Offset pagination: simple, but the count query and deep offsets get expensive
@app.get("/posts")
async def get_posts(pagination: PaginationParams = Depends()):
    posts = await db.execute(
        select(Post)
        .order_by(Post.created_at.desc())
        .offset(pagination.offset)
        .limit(pagination.size)
    )
    total = (await db.execute(select(func.count(Post.id)))).scalar()
    return {
        "posts": posts.scalars().all(),
        "pagination": {
            "page": pagination.page,
            "size": pagination.size,
            "total": total,
            "pages": (total + pagination.size - 1) // pagination.size
        }
    }

# Cursor pagination: stable and cheap for large tables
@app.get("/posts/cursor")
async def get_posts_cursor(pagination: PaginationParams = Depends()):
    query = select(Post).order_by(Post.id.desc()).limit(pagination.size + 1)
    if pagination.cursor:
        cursor_id = int(pagination.cursor)
        query = query.where(Post.id < cursor_id)
    posts = await db.execute(query)
    post_list = posts.scalars().all()
    # Fetch one extra row to detect whether a next page exists
    has_next = len(post_list) > pagination.size
    if has_next:
        post_list = post_list[:-1]
    next_cursor = str(post_list[-1].id) if post_list and has_next else None
    return {
        "posts": post_list,
        "pagination": {
            "size": pagination.size,
            "has_next": has_next,
            "next_cursor": next_cursor
        }
    }
```
Monitoring and Performance Analysis
1. Performance Monitoring Middleware
```python
import time
import logging
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

class PerformanceMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        start_time = time.time()
        logger.info(f"Request started: {request.method} {request.url}")
        response = await call_next(request)
        process_time = time.time() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        if process_time > 1.0:
            logger.warning(
                f"Slow request: {request.method} {request.url} "
                f"took {process_time:.2f}s"
            )
        return response

app.add_middleware(PerformanceMiddleware)
```
2. Health Check and Metrics Endpoints
```python
import time
import psutil
from sqlalchemy import text

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    checks = {}

    # Database connectivity
    try:
        await db.execute(text("SELECT 1"))
        checks["database"] = {"status": "healthy"}
    except Exception as e:
        checks["database"] = {"status": "unhealthy", "error": str(e)}

    # Redis connectivity
    try:
        await redis_client.ping()
        checks["redis"] = {"status": "healthy"}
    except Exception as e:
        checks["redis"] = {"status": "unhealthy", "error": str(e)}

    # System resources
    cpu_percent = psutil.cpu_percent()
    memory_percent = psutil.virtual_memory().percent
    checks["system"] = {
        "status": "healthy" if cpu_percent < 80 and memory_percent < 80 else "warning",
        "cpu_percent": cpu_percent,
        "memory_percent": memory_percent
    }

    overall_status = "healthy" if all(
        check["status"] == "healthy" for check in checks.values()
    ) else "unhealthy"

    return {
        "status": overall_status,
        "checks": checks,
        "timestamp": time.time()
    }

@app.get("/metrics")
async def get_metrics():
    """Application metrics endpoint."""
    return {
        # Pool counters are plain integers in SQLAlchemy's pool API
        "active_connections": engine.pool.checkedout(),
        "pool_size": engine.pool.size(),
        "checked_in_connections": engine.pool.checkedin(),
        "memory_usage": psutil.Process().memory_info().rss / 1024 / 1024,  # MB
        "cpu_percent": psutil.Process().cpu_percent(),
    }
```
Deployment Optimization
1. Gunicorn Configuration
```python
# gunicorn.conf.py -- referenced by the Docker CMD below
import multiprocessing

bind = "0.0.0.0:8000"
workers = multiprocessing.cpu_count() * 2 + 1
worker_class = "uvicorn.workers.UvicornWorker"
worker_connections = 1000

max_requests = 1000
max_requests_jitter = 100
preload_app = True

timeout = 30
keepalive = 5
graceful_timeout = 30

accesslog = "-"
errorlog = "-"
loglevel = "info"
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" %(D)s'

proc_name = "fastapi_app"

def when_ready(server):
    server.log.info("Server is ready. Spawning workers")

def worker_int(worker):
    worker.log.info("worker received INT or QUIT signal")

def pre_fork(server, worker):
    server.log.info("Worker spawned (pid: %s)", worker.pid)
```
2. Docker Optimization
```dockerfile
FROM python:3.11-slim

# Install build deps plus curl (needed by the HEALTHCHECK below)
RUN apt-get update && apt-get install -y \
    gcc \
    curl \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements first so the dependency layer is cached between builds
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Run as an unprivileged user
RUN useradd --create-home --shell /bin/bash app \
    && chown -R app:app /app
USER app

HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

CMD ["gunicorn", "-c", "gunicorn.conf.py", "main:app"]
```
Summary
With these optimization techniques, our FastAPI application's performance improved significantly:
- Async programming: use async/await correctly and avoid blocking the event loop
- Database optimization: connection pool tuning, query optimization, bulk operations
- Caching strategy: multi-level caches, warm-up, invalidation
- Response optimization: compression, pagination, streaming responses
- Monitoring and analysis: performance middleware, health checks, metrics collection
- Deployment optimization: Gunicorn configuration, Docker optimization
Remember that performance optimization is an ongoing process: adjust your strategy based on real business scenarios and monitoring data. Don't optimize prematurely; get the functionality right first, then target the actual bottlenecks.
Do you have your own FastAPI performance-optimization experience, or problems you've run into? Share them in the comments!