FastAPI Async Database Optimization: From SQLAlchemy to Performance Tuning

When building high-performance FastAPI applications, database access is often where the bottleneck lives. As a developer maintaining several FastAPI projects in production, I want to share some hands-on experience and optimization techniques for async database operations.

Getting Async Database Connections Right

1. Choose the Right Async Database Driver

First, we need a driver that supports async operation. Taking PostgreSQL as an example:

```text
# requirements.txt
fastapi==0.104.1
sqlalchemy==2.0.23
asyncpg==0.29.0     # async driver for PostgreSQL
databases==0.8.0
alembic==1.12.1
```

2. Configure the Async SQLAlchemy Engine

```python
# database.py
import os

from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import declarative_base

DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql+asyncpg://user:password@localhost/dbname",
)

# Create the async engine
engine = create_async_engine(
    DATABASE_URL,
    echo=True,            # log SQL in development only
    pool_size=20,         # connection pool size
    max_overflow=0,       # no extra connections beyond pool_size
    pool_pre_ping=True,   # validate connections before use
    pool_recycle=3600,    # recycle connections after this many seconds
)

# Async session factory
AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

Base = declarative_base()

# Dependency: provide a database session per request
async def get_db():
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
```
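
To show how these pieces fit together, here is a minimal wiring sketch; the `/healthz` route and table creation at startup are illustrative (in production you would run Alembic migrations instead of `create_all`):

```python
# main.py -- minimal wiring sketch for the engine and session dependency
from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI
from sqlalchemy.ext.asyncio import AsyncSession

from database import Base, engine, get_db


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Create tables on startup (prefer Alembic migrations in production)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    # Dispose of the connection pool on shutdown
    await engine.dispose()


app = FastAPI(lifespan=lifespan)


@app.get("/healthz")
async def healthcheck(db: AsyncSession = Depends(get_db)):
    # The session is opened per request and cleaned up by get_db()
    return {"status": "ok"}
```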

Model Definitions and Relationships

1. Define Async-Friendly Models

```python
# models.py
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Text, Boolean
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func

from database import Base

class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String(50), unique=True, index=True, nullable=False)
    email = Column(String(100), unique=True, index=True, nullable=False)
    hashed_password = Column(String(255), nullable=False)
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # Relationships
    posts = relationship("Post", back_populates="author", lazy="selectin")
    profile = relationship("UserProfile", back_populates="user", uselist=False)

class Post(Base):
    __tablename__ = "posts"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(200), nullable=False, index=True)
    content = Column(Text, nullable=False)
    author_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    is_published = Column(Boolean, default=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # Relationships (Tag and the post_tags association table are defined elsewhere)
    author = relationship("User", back_populates="posts")
    tags = relationship("Tag", secondary="post_tags", back_populates="posts")

class UserProfile(Base):
    __tablename__ = "user_profiles"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"), unique=True, nullable=False)
    bio = Column(Text)
    avatar_url = Column(String(255))

    user = relationship("User", back_populates="profile")
```

2. Tackle the N+1 Query Problem

The N+1 query is a classic ORM performance trap, and it deserves extra attention in async code:

```python
# crud.py
from datetime import datetime, timedelta
from typing import List, Optional

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from models import Post, User

class UserCRUD:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_user_with_posts(self, user_id: int) -> Optional[User]:
        """Fetch a user together with their posts (avoids N+1 queries)."""
        stmt = select(User).options(
            selectinload(User.posts),    # eager-load posts
            selectinload(User.profile),  # eager-load the profile
        ).where(User.id == user_id)

        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_users_with_post_count(self) -> list:
        """Fetch users and their post counts; each row is a (User, post_count) pair."""
        stmt = select(
            User,
            func.count(Post.id).label("post_count"),
        ).outerjoin(Post).group_by(User.id)

        result = await self.db.execute(stmt)
        return result.all()

    async def get_active_users_with_recent_posts(self) -> List[User]:
        """Fetch active users along with their posts from the last 30 days."""
        recent_date = datetime.utcnow() - timedelta(days=30)

        stmt = select(User).options(
            selectinload(User.posts.and_(Post.created_at >= recent_date))
        ).where(User.is_active == True)

        result = await self.db.execute(stmt)
        return result.scalars().all()
```
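
For contrast, here is a sketch of the pitfall the eager-loading options above guard against; it assumes a relationship using SQLAlchemy's default `lazy="select"` strategy:

```python
# n_plus_one_pitfall.py -- a sketch of what eager loading protects against
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from models import User


async def n_plus_one_example(db: AsyncSession) -> None:
    # One query for all users...
    result = await db.execute(select(User))
    users = result.scalars().all()

    for user in users:
        # ...and, for a relationship with the default lazy="select" strategy,
        # one extra query per user. In async SQLAlchemy that implicit IO raises
        # MissingGreenlet instead, so load eagerly (selectinload/joinedload) or
        # call await db.refresh(user, attribute_names=["posts"]) explicitly.
        # (User.posts above is declared with lazy="selectin", so it is safe.)
        print(user.username, len(user.posts))
```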

Advanced Query Optimization Techniques

1. Optimize Bulk Operations

```python
# bulk_operations.py
from datetime import datetime, timedelta
from typing import Any, Dict, List

from sqlalchemy import delete, func, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from models import User

class BulkOperations:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def bulk_create_users(self, users_data: List[Dict[str, Any]]) -> List[User]:
        """Create users in bulk."""
        users = [User(**user_data) for user_data in users_data]
        self.db.add_all(users)
        await self.db.commit()

        # Refresh to pick up the generated IDs
        for user in users:
            await self.db.refresh(user)

        return users

    async def bulk_update_users(self, updates: List[Dict[str, Any]]) -> int:
        """Update users in bulk; each dict must include the primary key ("id")."""
        stmt = update(User)
        result = await self.db.execute(stmt, updates)
        await self.db.commit()
        return result.rowcount

    async def upsert_users(self, users_data: List[Dict[str, Any]]) -> None:
        """PostgreSQL-specific UPSERT (INSERT ... ON CONFLICT DO UPDATE)."""
        stmt = insert(User).values(users_data)
        stmt = stmt.on_conflict_do_update(
            index_elements=["email"],
            set_=dict(
                username=stmt.excluded.username,
                updated_at=func.now(),
            ),
        )
        await self.db.execute(stmt)
        await self.db.commit()

    async def bulk_delete_inactive_users(self, days: int = 365) -> int:
        """Delete users that have been inactive for the given number of days."""
        cutoff_date = datetime.utcnow() - timedelta(days=days)

        stmt = delete(User).where(
            User.is_active == False,
            User.updated_at < cutoff_date,
        )

        result = await self.db.execute(stmt)
        await self.db.commit()
        return result.rowcount
```
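
For very large imports, one giant commit holds a long transaction and keeps every object in memory. Here is a small usage sketch that feeds the bulk helper in fixed-size chunks; the chunk size of 1000 is an arbitrary value to tune for your workload:

```python
# usage sketch: import a large dataset in fixed-size chunks
from typing import Any, Dict, List

from sqlalchemy.ext.asyncio import AsyncSession

from bulk_operations import BulkOperations


async def import_users(db: AsyncSession, users_data: List[Dict[str, Any]], chunk_size: int = 1000) -> int:
    ops = BulkOperations(db)
    created = 0
    for start in range(0, len(users_data), chunk_size):
        chunk = users_data[start:start + chunk_size]
        created += len(await ops.bulk_create_users(chunk))  # commits once per chunk
    return created
```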

2. Optimize Pagination

```python
# pagination.py
from typing import Any, List, Optional, Tuple

from pydantic import BaseModel
from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from models import Post

class PaginationParams(BaseModel):
    page: int = 1
    size: int = 20

    @property
    def offset(self) -> int:
        return (self.page - 1) * self.size

class PaginatedResponse(BaseModel):
    items: List[Any]
    total: int
    page: int
    size: int
    pages: int

class PostCRUD:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_posts_paginated(
        self,
        pagination: PaginationParams,
        search: Optional[str] = None,
        author_id: Optional[int] = None,
    ) -> PaginatedResponse:
        """Fetch a paginated list of posts."""

        # Base query
        base_query = select(Post).options(
            selectinload(Post.author),
            selectinload(Post.tags),
        )

        # Filters
        if search:
            base_query = base_query.where(
                Post.title.ilike(f"%{search}%") |
                Post.content.ilike(f"%{search}%")
            )

        if author_id is not None:
            base_query = base_query.where(Post.author_id == author_id)

        # Only published posts
        base_query = base_query.where(Post.is_published == True)

        # Total count
        count_query = select(func.count()).select_from(
            base_query.subquery()
        )
        total_result = await self.db.execute(count_query)
        total = total_result.scalar()

        # Page of data
        paginated_query = base_query.order_by(
            desc(Post.created_at)
        ).offset(pagination.offset).limit(pagination.size)

        result = await self.db.execute(paginated_query)
        posts = result.scalars().all()

        return PaginatedResponse(
            items=posts,
            total=total,
            page=pagination.page,
            size=pagination.size,
            pages=(total + pagination.size - 1) // pagination.size,
        )

    async def get_posts_cursor_paginated(
        self,
        cursor: Optional[int] = None,
        size: int = 20,
    ) -> Tuple[List[Post], Optional[int]]:
        """Cursor-based pagination (better suited to large data sets)."""
        query = select(Post).options(
            selectinload(Post.author)
        ).where(Post.is_published == True)

        if cursor:
            query = query.where(Post.id < cursor)

        query = query.order_by(desc(Post.id)).limit(size + 1)

        result = await self.db.execute(query)
        posts = result.scalars().all()

        next_cursor = None
        if len(posts) > size:
            posts = posts[:-1]  # drop the extra row used only to detect the next page
            next_cursor = posts[-1].id

        return posts, next_cursor
```
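
A sketch of exposing cursor pagination over HTTP; the `/feed` route and the response shape are illustrative, not part of the CRUD code above:

```python
# routers/feed.py -- illustrative endpoint for cursor pagination
from typing import Optional

from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession

from database import get_db
from pagination import PostCRUD

router = APIRouter(prefix="/feed", tags=["feed"])


@router.get("/")
async def get_feed(
    cursor: Optional[int] = Query(None, description="Return posts with id < cursor"),
    size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
):
    posts, next_cursor = await PostCRUD(db).get_posts_cursor_paginated(cursor=cursor, size=size)
    return {
        "items": [{"id": p.id, "title": p.title} for p in posts],
        "next_cursor": next_cursor,  # None when there is no further page
    }
```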

Connection Pooling and Transaction Management

1. Monitor and Tune the Connection Pool

```python
# monitoring.py
import asyncio
import logging
from contextlib import asynccontextmanager

from sqlalchemy.ext.asyncio import AsyncEngine

from database import engine

logger = logging.getLogger(__name__)

class DatabaseMonitor:
    def __init__(self, engine: AsyncEngine):
        self.engine = engine

    async def get_pool_status(self) -> dict:
        """Return the current connection pool status."""
        pool = self.engine.pool
        return {
            "size": pool.size(),
            "checked_in": pool.checkedin(),
            "checked_out": pool.checkedout(),
            "overflow": pool.overflow(),
        }

    async def monitor_pool(self, interval: int = 60):
        """Periodically log the connection pool status."""
        while True:
            try:
                status = await self.get_pool_status()
                logger.info(f"Database pool status: {status}")

                # Warn when the pool is close to saturation
                capacity = status["size"] + max(status["overflow"], 0)
                if capacity and status["checked_out"] / capacity > 0.8:
                    logger.warning("Database connection pool usage is high!")

                await asyncio.sleep(interval)
            except Exception as e:
                logger.error(f"Error while monitoring the pool: {e}")
                await asyncio.sleep(interval)

# Start the monitor when the FastAPI application starts
@asynccontextmanager
async def lifespan(app):
    # Startup
    monitor = DatabaseMonitor(engine)
    monitor_task = asyncio.create_task(monitor.monitor_pool())

    yield

    # Shutdown
    monitor_task.cancel()
    try:
        await monitor_task
    except asyncio.CancelledError:
        pass
    await engine.dispose()
```
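
FastAPI picks the handler up when the application object is created; assuming the lifespan above lives in `monitoring.py`, the wiring looks like this:

```python
# main.py -- register the lifespan handler
from fastapi import FastAPI

from monitoring import lifespan

app = FastAPI(lifespan=lifespan)
```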

2. Transaction Management Best Practices

```python
# transaction_manager.py
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession

from models import Post, User, UserProfile

logger = logging.getLogger(__name__)

@asynccontextmanager
async def transaction_manager(
    db: AsyncSession,
    rollback_on_exception: bool = True,
) -> AsyncGenerator[AsyncSession, None]:
    """Commit on success, roll back on failure; the session lifecycle stays with get_db()."""
    try:
        yield db
        await db.commit()
        logger.debug("Transaction committed")
    except Exception as e:
        if rollback_on_exception:
            await db.rollback()
            logger.error(f"Transaction rolled back: {e}")
        raise

class TransactionalService:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_user_with_profile(
        self,
        user_data: dict,
        profile_data: dict,
    ) -> User:
        """Create a user and their profile in a single transaction."""
        async with transaction_manager(self.db) as session:
            # Create the user
            user = User(**user_data)
            session.add(user)
            await session.flush()  # populate user.id without committing

            # Create the profile
            profile_data["user_id"] = user.id
            profile = UserProfile(**profile_data)
            session.add(profile)

            # If anything raises here, the whole transaction rolls back
            return user

    async def transfer_posts(
        self,
        from_user_id: int,
        to_user_id: int,
    ) -> int:
        """Transfer post ownership between users in a single transaction."""
        async with transaction_manager(self.db) as session:
            # Make sure both users exist
            from_user = await session.get(User, from_user_id)
            to_user = await session.get(User, to_user_id)

            if not from_user or not to_user:
                raise ValueError("User does not exist")

            # Reassign the posts
            stmt = update(Post).where(
                Post.author_id == from_user_id
            ).values(author_id=to_user_id)

            result = await session.execute(stmt)
            return result.rowcount
```
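
For many cases SQLAlchemy's built-in `AsyncSession.begin()` context manager does the same job with less code; a minimal sketch, assuming the session has no transaction already in progress:

```python
# sketch: the same pattern with AsyncSession.begin()
from sqlalchemy.ext.asyncio import AsyncSession

from models import User, UserProfile


async def create_user_with_profile(db: AsyncSession, user_data: dict, profile_data: dict) -> User:
    async with db.begin():  # commits on success, rolls back on any exception
        user = User(**user_data)
        db.add(user)
        await db.flush()  # populate user.id without committing

        db.add(UserProfile(user_id=user.id, **profile_data))
    return user
```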

Hands-On Performance Optimization

1. Query Caching Strategy

```python
# cache.py
import hashlib
import logging
import pickle
from functools import wraps
from typing import Any, Callable, List, Optional

import redis.asyncio as redis
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from models import Post

logger = logging.getLogger(__name__)

class AsyncRedisCache:
    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis = redis.from_url(redis_url)

    async def get(self, key: str) -> Optional[Any]:
        """Read a value from the cache."""
        try:
            data = await self.redis.get(key)
            if data:
                return pickle.loads(data)
        except Exception as e:
            logger.error(f"Cache read failed: {e}")
        return None

    async def set(
        self,
        key: str,
        value: Any,
        expire: int = 3600,
    ) -> bool:
        """Write a value to the cache."""
        try:
            data = pickle.dumps(value)
            return await self.redis.set(key, data, ex=expire)
        except Exception as e:
            logger.error(f"Cache write failed: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete a cached value."""
        try:
            return await self.redis.delete(key) > 0
        except Exception as e:
            logger.error(f"Cache delete failed: {e}")
            return False

# Caching decorator
def cache_result(expire: int = 3600, key_prefix: str = ""):
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Build the cache key from the function name and arguments.
            # Note: for methods, args[0] is `self`, whose default repr differs per
            # instance; exclude it or give the class a stable __repr__ for stable keys.
            cache_key = f"{key_prefix}:{func.__name__}:"
            key_data = str(args) + str(sorted(kwargs.items()))
            cache_key += hashlib.md5(key_data.encode()).hexdigest()

            # Try the cache first
            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            # Run the function and cache its result
            result = await func(*args, **kwargs)
            await cache.set(cache_key, result, expire)
            return result

        return wrapper
    return decorator

# Global cache instance
cache = AsyncRedisCache()

# A service that uses the cache
class CachedPostService:
    def __init__(self, db: AsyncSession):
        self.db = db

    @cache_result(expire=1800, key_prefix="posts")
    async def get_popular_posts(self, limit: int = 10) -> List[Post]:
        """Fetch popular posts (cached for 30 minutes).

        Assumes Post has a view_count column (not shown in models.py above).
        """
        stmt = select(Post).options(
            selectinload(Post.author)
        ).where(
            Post.is_published == True
        ).order_by(
            desc(Post.view_count)
        ).limit(limit)

        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def invalidate_post_cache(self, post_id: int):
        """Invalidate cached entries related to a post."""
        patterns = [
            "posts:get_popular_posts:*",
            f"posts:get_post_by_id:*{post_id}*",
        ]

        for pattern in patterns:
            # KEYS blocks Redis on large keyspaces; scan_iter() scales better
            keys = await cache.redis.keys(pattern)
            if keys:
                await cache.redis.delete(*keys)
```
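
A sketch of calling the cached service from an endpoint; cache hits return `Post` instances detached from any session, so this pattern works best for read-only response data (the `/popular` route is illustrative):

```python
# usage sketch for CachedPostService
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from cache import CachedPostService
from database import get_db

router = APIRouter(prefix="/popular", tags=["posts"])


@router.get("/")
async def popular_posts(db: AsyncSession = Depends(get_db)):
    service = CachedPostService(db)
    posts = await service.get_popular_posts(limit=10)
    # Cached hits are detached Post instances; lazy-loading on them would fail,
    # which is why get_popular_posts() eager-loads the author relationship.
    return [{"id": p.id, "title": p.title, "author": p.author.username} for p in posts]
```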

2. Database Index Optimization

```python
# migrations/versions/add_performance_indexes.py
"""Add performance-oriented indexes

Revision ID: 001
Create Date: 2023-08-15 14:30:00.000000
"""

from alembic import op
import sqlalchemy as sa

# Revision identifiers used by Alembic
# (down_revision is assumed to be None here; point it at your previous revision)
revision = "001"
down_revision = None
branch_labels = None
depends_on = None

def upgrade():
    # Composite index: querying posts by author and publication status
    op.create_index(
        'idx_posts_author_published',
        'posts',
        ['author_id', 'is_published', 'created_at'],
        postgresql_using='btree'
    )

    # Partial index: only index published posts
    op.execute("""
        CREATE INDEX idx_posts_published_created
        ON posts (created_at DESC)
        WHERE is_published = true
    """)

    # Full-text search index
    op.execute("""
        CREATE INDEX idx_posts_fulltext
        ON posts
        USING gin(to_tsvector('english', title || ' ' || content))
    """)

    # Unique index on user email (if one does not already exist)
    op.create_index(
        'idx_users_email_unique',
        'users',
        ['email'],
        unique=True
    )

def downgrade():
    op.drop_index('idx_posts_author_published')
    op.execute("DROP INDEX IF EXISTS idx_posts_published_created")
    op.execute("DROP INDEX IF EXISTS idx_posts_fulltext")
    op.drop_index('idx_users_email_unique')
```
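
After adding indexes it's worth confirming that the planner actually uses them; one way to check from an async session is to run EXPLAIN ANALYZE and look for the index in the plan (a sketch):

```python
# explain_check.py -- sketch: inspect a query plan from an async session
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


async def explain_recent_published_posts(db: AsyncSession) -> list[str]:
    result = await db.execute(text("""
        EXPLAIN ANALYZE
        SELECT id, title FROM posts
        WHERE is_published = true
        ORDER BY created_at DESC
        LIMIT 20
    """))
    plan = [row[0] for row in result]
    # Look for "Index Scan using idx_posts_published_created" in the output
    return plan
```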

FastAPI Router Integration

1. Implement Async Routes

```python
# routers/posts.py
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession

from auth import get_current_user   # your authentication dependency
from cache import CachedPostService
from database import get_db
from models import User
from pagination import PaginationParams, PostCRUD
from schemas import PostCreate, PostResponse, PaginatedPostResponse

router = APIRouter(prefix="/posts", tags=["posts"])

@router.get("/", response_model=PaginatedPostResponse)
async def get_posts(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    search: Optional[str] = Query(None),
    author_id: Optional[int] = Query(None),
    db: AsyncSession = Depends(get_db)
):
    """List posts with pagination."""
    post_crud = PostCRUD(db)
    pagination = PaginationParams(page=page, size=size)

    result = await post_crud.get_posts_paginated(
        pagination=pagination,
        search=search,
        author_id=author_id
    )

    return result

@router.get("/{post_id}", response_model=PostResponse)
async def get_post(
    post_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Fetch a single post."""
    post_crud = PostCRUD(db)
    post = await post_crud.get_post_by_id(post_id)

    if not post:
        raise HTTPException(status_code=404, detail="Post not found")

    return post

@router.post("/", response_model=PostResponse, status_code=201)
async def create_post(
    post_data: PostCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a post."""
    post_crud = PostCRUD(db)

    # Build the post payload (use model_dump() on Pydantic v2)
    post_dict = post_data.dict()
    post_dict["author_id"] = current_user.id

    post = await post_crud.create_post(post_dict)

    # Invalidate related caches
    cached_service = CachedPostService(db)
    await cached_service.invalidate_post_cache(post.id)

    return post
```

2. Error Handling and Logging

```python
# error_handlers.py
import logging

from fastapi import Request
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError, SQLAlchemyError

logger = logging.getLogger(__name__)

async def database_exception_handler(request: Request, exc: SQLAlchemyError):
    """Exception handler for database errors."""
    logger.error(f"Database error: {exc}")

    if isinstance(exc, IntegrityError):
        return JSONResponse(
            status_code=400,
            content={
                "detail": "Data integrity error, likely a duplicate value in a unique field",
                "type": "integrity_error"
            }
        )

    return JSONResponse(
        status_code=500,
        content={
            "detail": "Database operation failed",
            "type": "database_error"
        }
    )

# Register the handler in main.py
from fastapi import FastAPI
from sqlalchemy.exc import SQLAlchemyError

app = FastAPI()
app.add_exception_handler(SQLAlchemyError, database_exception_handler)
```

Performance Monitoring and Debugging

1. Query Performance Profiling

```python
# profiling.py
import logging
import time

from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncEngine

from database import engine

logger = logging.getLogger(__name__)

class QueryProfiler:
    def __init__(self):
        self.slow_query_threshold = 1.0  # slow-query threshold in seconds

    def setup_query_logging(self, engine: AsyncEngine):
        """Attach query timing hooks to the engine."""

        @event.listens_for(engine.sync_engine, "before_cursor_execute")
        def receive_before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            context._query_start_time = time.time()

        @event.listens_for(engine.sync_engine, "after_cursor_execute")
        def receive_after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            total = time.time() - context._query_start_time

            if total > self.slow_query_threshold:
                logger.warning(
                    f"Slow query detected - took {total:.2f}s\n"
                    f"SQL: {statement}\n"
                    f"Parameters: {parameters}"
                )
            else:
                logger.debug(f"Query took {total:.2f}s")

# Enable the profiler
profiler = QueryProfiler()
profiler.setup_query_logging(engine)
```

2. Application Performance Monitoring

```python
# middleware.py (kept separate from the pool monitor in monitoring.py above)
import logging
from time import time

from fastapi import Request

logger = logging.getLogger(__name__)

async def add_process_time_header(request: Request, call_next):
    """Attach a processing-time header to every response."""
    start_time = time()
    response = await call_next(request)
    process_time = time() - start_time

    response.headers["X-Process-Time"] = str(process_time)

    # Log slow requests
    if process_time > 2.0:
        logger.warning(
            f"Slow request detected - {request.method} {request.url} - "
            f"took {process_time:.2f}s"
        )

    return response

# Register the middleware in main.py
app.middleware("http")(add_process_time_header)
```

Summary

With the optimization strategies above, we can significantly improve the database performance of a FastAPI application:

  1. Configure async database connections correctly: pick the right driver and connection pool parameters
  2. Avoid N+1 queries: use eager loading and join queries
  3. Paginate efficiently: choose offset or cursor pagination based on the data volume
  4. Use caching wisely: cache hot data and query results
  5. Optimize database indexes: create composite and partial indexes
  6. Monitor performance metrics: spot and resolve performance problems early

These techniques have all been validated in my production environments, and I hope they help you build faster FastAPI applications. Remember that performance optimization is an ongoing process, and the strategies should be adjusted to your actual business scenarios and data characteristics.

Do you have experience or questions around FastAPI database optimization? Feel free to share and discuss in the comments!
