FastAPI测试策略:从单元测试到集成测试的完整指南
Orion K Lv6

FastAPI测试策略:从单元测试到集成测试的完整指南

作为一个在多个FastAPI项目中实施测试策略的开发者,我深知良好的测试覆盖率对于维护代码质量和项目稳定性的重要性。在这篇文章中,我将分享在FastAPI应用中实施全面测试策略的实战经验。

测试环境搭建

1. 测试依赖配置

1
2
3
4
5
6
7
8
9
10
# requirements-test.txt
pytest==7.4.3
pytest-asyncio==0.21.1
pytest-cov==4.1.0
httpx==0.25.2
pytest-mock==3.12.0
factory-boy==3.3.0
faker==20.1.0
pytest-xdist==3.5.0 # parallel test execution
pytest-html==4.1.1 # HTML reports
aiosqlite==0.19.0 # async SQLite driver behind the sqlite+aiosqlite:// test database
psutil==5.9.6 # memory measurement in the performance tests
# Tools invoked by the CI lint / type-check steps
flake8==6.1.0
black==23.11.0
isort==5.12.0
mypy==1.7.1

2. 测试配置文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# pytest.ini
# In pytest.ini the section must be [pytest]; [tool:pytest] is only recognised
# in setup.cfg, so with the old header none of the options below were applied.
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# The suite's async fixtures (db_session, client, authenticated_user) are plain
# @pytest.fixture coroutines; pytest-asyncio only drives those in "auto" mode
# (its default "strict" mode requires @pytest_asyncio.fixture).
asyncio_mode = auto
addopts =
    --strict-markers
    --strict-config
    --verbose
    --cov=app
    --cov-report=term-missing
    --cov-report=html:htmlcov
    --cov-report=xml
    --cov-fail-under=80
markers =
    unit: Unit tests
    integration: Integration tests
    e2e: End-to-end tests
    slow: Slow running tests
    auth: Authentication related tests
    database: Database related tests

# conftest.py
import asyncio
import os
from typing import AsyncGenerator, Generator

import pytest
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from app.config import settings
from app.database import Base, get_db
from app.main import app

# Test database URL. Each pytest-xdist worker ("gw0", "gw1", ...) gets its own
# SQLite file: db_session creates and drops every table around each test, so
# workers sharing a single file (`pytest -n auto`) would wipe each other's schema.
_XDIST_WORKER = os.environ.get("PYTEST_XDIST_WORKER", "main")
TEST_DATABASE_URL = f"sqlite+aiosqlite:///./test_{_XDIST_WORKER}.db"

# Engine shared by all test sessions.
test_engine = create_async_engine(
    TEST_DATABASE_URL,
    echo=False,
    future=True,
)

TestingSessionLocal = sessionmaker(
    test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


@pytest.fixture(scope="session")
def event_loop():
    """Provide a single event loop for the whole test session."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Create all tables, yield a session, then drop all tables.

    Schema creation and teardown each open their own connection: the
    connection used for create_all is closed as soon as its ``async with``
    block exits, so it cannot be reused for drop_all.
    """
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    try:
        async with TestingSessionLocal() as session:
            yield session
    finally:
        async with test_engine.begin() as connection:
            await connection.run_sync(Base.metadata.drop_all)


@pytest.fixture(scope="function")
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """In-process HTTP client whose get_db dependency yields ``db_session``."""

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        # ASGITransport is httpx's supported way to mount an ASGI app; the
        # AsyncClient(app=...) shortcut is deprecated in later httpx releases.
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()


@pytest.fixture
def sync_client() -> Generator[TestClient, None, None]:
    """Synchronous test client for the rare cases that need one.

    No dependency overrides are installed, so it talks to whatever database
    ``get_db`` provides by default.
    """
    with TestClient(app) as client:
        yield client

单元测试实践

1. 模型测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# tests/test_models.py
import pytest
from datetime import datetime
from app.models.user import User
from app.models.post import Post
from app.auth.password import PasswordManager

class TestUserModel:
    """Unit tests for the User model (no database session involved)."""

    def test_user_creation(self):
        """A new User keeps the given fields and starts active, non-superuser."""
        user = User(
            username="testuser",
            email="test@example.com",
            hashed_password="hashed_password",
        )

        assert user.username == "testuser"
        assert user.email == "test@example.com"
        # NOTE(review): SQLAlchemy Column(default=...) values are applied at
        # INSERT time, not by the constructor. These two asserts only hold if
        # the User model sets them when it is constructed; confirm against
        # app.models.user.
        assert user.is_active is True
        assert user.is_superuser is False

    def test_password_hashing(self):
        """set_password stores a hash that verify_password accepts."""
        user = User(
            username="testuser",
            email="test@example.com",
        )

        password = "testpassword123"
        user.set_password(password)

        # A hash must actually be stored: `None != password` alone would pass.
        assert user.hashed_password
        assert user.hashed_password != password
        assert user.verify_password(password) is True
        assert user.verify_password("wrongpassword") is False

    def test_user_is_locked_property(self):
        """is_locked is True only while locked_until lies in the future."""
        from datetime import timedelta

        user = User(
            username="testuser",
            email="test@example.com",
            locked_until=None,
        )

        assert user.is_locked is False

        # Naive UTC timestamps; keep them consistent with whatever
        # User.is_locked compares against.
        user.locked_until = datetime.utcnow() + timedelta(hours=1)
        assert user.is_locked is True

        user.locked_until = datetime.utcnow() - timedelta(hours=1)
        assert user.is_locked is False

class TestPostModel:
    """Unit tests for the Post model."""

    def test_post_creation(self):
        """Constructor arguments land on the instance; a new post is unpublished."""
        fields = {
            "title": "Test Post",
            "content": "This is a test post",
            "author_id": 1,
        }
        post = Post(**fields)

        for attr_name, expected_value in fields.items():
            assert getattr(post, attr_name) == expected_value
        # NOTE(review): relies on is_published being False before any flush;
        # confirm the model sets it in its constructor.
        assert post.is_published is False

    def test_post_slug_generation(self):
        """generate_slug lowercases the title, drops punctuation and hyphenates words."""
        post = Post(title="This is a Test Post!", content="Content", author_id=1)

        # Assumes Post exposes a generate_slug() helper.
        assert post.generate_slug() == "this-is-a-test-post"

2. 服务层测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# tests/test_services.py
import pytest
from unittest.mock import AsyncMock, MagicMock
from app.services.user_service import UserService
from app.services.post_service import PostService
from app.models.user import User
from app.models.post import Post
from app.schemas.user import UserCreate

class TestUserService:
    """Unit tests for UserService with a mocked AsyncSession."""

    @pytest.fixture
    def mock_db(self):
        """Return a mocked AsyncSession.

        ``AsyncSession.execute`` is a coroutine, but the Result it resolves to
        has synchronous methods (``scalar_one_or_none``, ``scalars``, ...).
        Attributes and return values of an AsyncMock are themselves AsyncMocks,
        so without the explicit MagicMock below ``result.scalar_one_or_none()``
        would return an un-awaited coroutine: always truthy and never equal to
        the value a test configures. ``AsyncSession.add`` is synchronous too.
        """
        db = AsyncMock()
        db.execute.return_value = MagicMock()
        db.add = MagicMock()
        return db

    @pytest.fixture
    def user_service(self, mock_db):
        """UserService bound to the mocked session."""
        return UserService(mock_db)

    @pytest.mark.asyncio
    async def test_create_user_success(self, user_service, mock_db):
        """A user with an unused email is added, committed and returned."""
        user_data = UserCreate(
            username="testuser",
            email="test@example.com",
            password="testpassword123",
        )

        # The lookup for an existing user finds nothing.
        mock_db.execute.return_value.scalar_one_or_none.return_value = None
        mock_db.commit = AsyncMock()
        mock_db.refresh = AsyncMock()

        result = await user_service.create_user(user_data)

        assert result.username == "testuser"
        assert result.email == "test@example.com"
        mock_db.add.assert_called_once()
        # assert_awaited_* also catches a commit() that was called but never awaited.
        mock_db.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_user_duplicate_email(self, user_service, mock_db):
        """An already registered email raises ValueError and nothing is committed."""
        user_data = UserCreate(
            username="testuser",
            email="test@example.com",
            password="testpassword123",
        )

        # The lookup finds a user that already owns this email.
        existing_user = User(
            id=1,
            username="existing",
            email="test@example.com",
        )
        mock_db.execute.return_value.scalar_one_or_none.return_value = existing_user

        with pytest.raises(ValueError, match="Email already registered"):
            await user_service.create_user(user_data)
        mock_db.commit.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_authenticate_user_success(self, user_service, mock_db):
        """Valid credentials return the stored user."""
        user = User(
            id=1,
            username="testuser",
            email="test@example.com",
            hashed_password="hashed_password",
            is_active=True,
        )
        user.verify_password = MagicMock(return_value=True)

        mock_db.execute.return_value.scalar_one_or_none.return_value = user

        result = await user_service.authenticate_user("testuser", "password")

        assert result == user
        user.verify_password.assert_called_once_with("password")

    @pytest.mark.asyncio
    async def test_authenticate_user_wrong_password(self, user_service, mock_db):
        """A wrong password makes authenticate_user return None."""
        user = User(
            id=1,
            username="testuser",
            email="test@example.com",
            hashed_password="hashed_password",
            is_active=True,
        )
        user.verify_password = MagicMock(return_value=False)

        mock_db.execute.return_value.scalar_one_or_none.return_value = user

        result = await user_service.authenticate_user("testuser", "wrongpassword")

        assert result is None
        user.verify_password.assert_called_once_with("wrongpassword")

class TestPostService:
    """Unit tests for PostService with a mocked AsyncSession."""

    @pytest.fixture
    def mock_db(self):
        """Return a mocked AsyncSession.

        ``AsyncSession.execute`` is a coroutine whose Result has synchronous
        methods (``scalar``, ``scalars().all()``, ...). Children of an AsyncMock
        are AsyncMocks, so the awaited value must be replaced by a MagicMock;
        otherwise those calls return un-awaited coroutines instead of the
        configured values. ``AsyncSession.add`` is synchronous as well.
        """
        db = AsyncMock()
        db.execute.return_value = MagicMock()
        db.add = MagicMock()
        return db

    @pytest.fixture
    def post_service(self, mock_db):
        """PostService bound to the mocked session."""
        return PostService(mock_db)

    @pytest.mark.asyncio
    async def test_get_posts_with_pagination(self, post_service, mock_db):
        """get_posts returns one page of items plus paging metadata."""
        posts = [
            Post(id=1, title="Post 1", content="Content 1", author_id=1),
            Post(id=2, title="Post 2", content="Content 2", author_id=1),
        ]

        # Every execute() call resolves to the same Result mock, so both the
        # list query (scalars().all()) and the count query (scalar()) are served.
        mock_db.execute.return_value.scalars.return_value.all.return_value = posts
        mock_db.execute.return_value.scalar.return_value = 10  # total row count

        result = await post_service.get_posts(page=1, size=2)

        assert len(result.items) == 2
        assert result.total == 10
        assert result.page == 1
        assert result.size == 2

    @pytest.mark.asyncio
    async def test_create_post(self, post_service, mock_db):
        """create_post persists a post owned by the given author."""
        from app.schemas.post import PostCreate

        post_data = PostCreate(
            title="New Post",
            content="This is a new post",
            tags=["python", "fastapi"],
        )

        mock_db.commit = AsyncMock()
        mock_db.refresh = AsyncMock()

        result = await post_service.create_post(post_data, author_id=1)

        assert result.title == "New Post"
        assert result.content == "This is a new post"
        assert result.author_id == 1
        mock_db.add.assert_called_once()
        # Fails if commit() was called without being awaited.
        mock_db.commit.assert_awaited_once()

3. 工具函数测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# tests/test_utils.py
import pytest
from datetime import datetime, timedelta
from app.utils.security import create_access_token, verify_token
from app.utils.email import validate_email, send_email
from app.utils.text import slugify, truncate_text

class TestSecurityUtils:
    """Tests for access-token creation and verification."""

    def test_create_and_verify_token(self):
        """A fresh token round-trips its claims and carries exp/iat."""
        issued = create_access_token({"sub": "user123", "email": "test@example.com"})

        assert issued is not None
        assert isinstance(issued, str)

        decoded = verify_token(issued)
        for claim, expected in (("sub", "user123"), ("email", "test@example.com")):
            assert decoded[claim] == expected
        for claim in ("exp", "iat"):
            assert claim in decoded

    def test_verify_expired_token(self):
        """A token whose expiry is already in the past verifies to None."""
        already_expired = create_access_token(
            {"sub": "user123"},
            expires_delta=timedelta(seconds=-1),
        )

        assert verify_token(already_expired) is None

    def test_verify_invalid_token(self):
        """A malformed token verifies to None."""
        assert verify_token("invalid.token.here") is None

class TestEmailUtils:
    """Tests for email validation and sending."""

    @pytest.mark.parametrize(
        "email,expected",
        [
            ("test@example.com", True),
            ("user.name@domain.co.uk", True),
            ("invalid-email", False),
            ("@domain.com", False),
            ("user@", False),
            ("", False),
        ],
    )
    def test_validate_email(self, email, expected):
        """validate_email accepts well-formed addresses and rejects the rest."""
        outcome = validate_email(email)
        assert outcome == expected

    @pytest.mark.asyncio
    async def test_send_email_success(self, mocker):
        """send_email reports success when the SMTP client accepts the message."""
        # Patch the SMTP client so no real mail leaves the test run.
        send_message = mocker.patch(
            'app.utils.email.smtp_client.send_message',
            return_value=True,
        )

        delivered = await send_email(
            to="test@example.com",
            subject="Test Subject",
            body="Test Body",
        )

        assert delivered is True
        send_message.assert_called_once()

class TestTextUtils:
    """Tests for text helpers."""

    @pytest.mark.parametrize("text,expected", [
        ("Hello World", "hello-world"),
        ("Python & FastAPI", "python-fastapi"),
        ("Test with 123 Numbers!", "test-with-123-numbers"),
        ("中文测试", "中文测试"),  # non-ASCII (CJK) text is kept
        ("", ""),
    ])
    def test_slugify(self, text, expected):
        """slugify lowercases, drops punctuation and joins words with hyphens."""
        assert slugify(text) == expected

    @pytest.mark.parametrize("text,length,expected", [
        ("Short text", 20, "Short text"),
        # NOTE(review): the expected value is 19 chars + "..." (22 chars) for a
        # limit of 20, so this assumes the ellipsis is appended after cutting
        # (trailing space stripped) rather than counted toward the limit; confirm.
        ("This is a very long text that should be truncated", 20, "This is a very long..."),
        ("", 10, ""),
        # Exactly `length` (20) characters: the boundary case, returned unchanged.
        ("Exact length text!!!", 20, "Exact length text!!!"),
    ])
    def test_truncate_text(self, text, length, expected):
        """truncate_text leaves text of at most `length` chars unchanged."""
        assert truncate_text(text, length) == expected

集成测试实践

1. API端点测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# tests/test_api.py
import pytest
from httpx import AsyncClient
from app.models.user import User
from app.models.post import Post

class TestAuthAPI:
    """API tests for the /auth endpoints."""

    @pytest.mark.asyncio
    async def test_register_user(self, client: AsyncClient):
        """Registration returns 201 and only public user fields."""
        user_data = {
            "username": "testuser",
            "email": "test@example.com",
            "password": "testpassword123",
        }

        response = await client.post("/auth/register", json=user_data)

        assert response.status_code == 201
        data = response.json()
        assert data["username"] == "testuser"
        assert data["email"] == "test@example.com"
        assert "id" in data
        # Neither the plain password nor the stored hash may be exposed.
        assert "password" not in data
        assert "hashed_password" not in data

    @pytest.mark.asyncio
    async def test_register_duplicate_email(self, client: AsyncClient, db_session):
        """Registering an email that already exists is rejected with 400."""
        user = User(
            username="existing",
            email="test@example.com",
            hashed_password="hashed",
        )
        db_session.add(user)
        await db_session.commit()

        user_data = {
            "username": "newuser",
            "email": "test@example.com",
            "password": "testpassword123",
        }

        response = await client.post("/auth/register", json=user_data)

        assert response.status_code == 400
        assert "already registered" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_login_success(self, client: AsyncClient, db_session):
        """Valid form credentials yield access and refresh tokens."""
        user = User(
            username="testuser",
            email="test@example.com",
        )
        user.set_password("testpassword123")
        db_session.add(user)
        await db_session.commit()

        login_data = {
            "username": "testuser",
            "password": "testpassword123",
        }

        # The login endpoint takes form data, not JSON.
        response = await client.post("/auth/login", data=login_data)

        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"

    @pytest.mark.asyncio
    async def test_login_invalid_credentials(self, client: AsyncClient):
        """Unknown users get 401 with a generic error message."""
        login_data = {
            "username": "nonexistent",
            "password": "wrongpassword",
        }

        response = await client.post("/auth/login", data=login_data)

        assert response.status_code == 401
        assert "Incorrect username or password" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_get_current_user(self, client: AsyncClient, db_session):
        """/auth/me returns the user identified by the bearer token."""
        user = User(
            username="testuser",
            email="test@example.com",
        )
        user.set_password("testpassword123")
        db_session.add(user)
        await db_session.commit()

        login_response = await client.post("/auth/login", data={
            "username": "testuser",
            "password": "testpassword123",
        })
        # Fail here with the server's message instead of a KeyError below.
        assert login_response.status_code == 200, login_response.text
        token = login_response.json()["access_token"]

        headers = {"Authorization": f"Bearer {token}"}
        response = await client.get("/auth/me", headers=headers)

        assert response.status_code == 200
        data = response.json()
        assert data["username"] == "testuser"
        assert data["email"] == "test@example.com"

    @pytest.mark.asyncio
    async def test_protected_route_without_token(self, client: AsyncClient):
        """A protected route without a token answers 401."""
        response = await client.get("/auth/me")

        assert response.status_code == 401

class TestPostAPI:
    """API tests for the /posts endpoints."""

    @staticmethod
    async def _create_author(db_session) -> User:
        """Add an "author" user and flush so that its primary key is assigned."""
        user = User(username="author", email="author@example.com")
        user.set_password("password")
        db_session.add(user)
        await db_session.flush()
        return user

    @pytest.fixture
    async def authenticated_user(self, client: AsyncClient, db_session):
        """Create "testuser", log in, and return the user, token and auth headers.

        This is an async fixture declared with @pytest.fixture, so it needs
        pytest-asyncio's auto mode.
        """
        user = User(
            username="testuser",
            email="test@example.com",
        )
        user.set_password("testpassword123")
        db_session.add(user)
        await db_session.commit()

        login_response = await client.post("/auth/login", data={
            "username": "testuser",
            "password": "testpassword123",
        })
        # Fail in the fixture, with the server's message, rather than with a
        # confusing KeyError when reading the token.
        assert login_response.status_code == 200, login_response.text
        token = login_response.json()["access_token"]

        return {
            "user": user,
            "token": token,
            "headers": {"Authorization": f"Bearer {token}"},
        }

    @pytest.mark.asyncio
    async def test_create_post(self, client: AsyncClient, authenticated_user):
        """An authenticated user can create a post owned by themselves."""
        post_data = {
            "title": "Test Post",
            "content": "This is a test post content",
            "tags": ["python", "fastapi"],
        }

        response = await client.post(
            "/posts/",
            json=post_data,
            headers=authenticated_user["headers"],
        )

        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "Test Post"
        assert data["content"] == "This is a test post content"
        assert data["author_id"] == authenticated_user["user"].id
        assert len(data["tags"]) == 2

    @pytest.mark.asyncio
    async def test_get_posts(self, client: AsyncClient, db_session):
        """The list endpoint paginates published posts."""
        user = await self._create_author(db_session)

        posts = [
            Post(
                title=f"Post {i}",
                content=f"Content {i}",
                author_id=user.id,
                is_published=True,
            )
            for i in range(5)
        ]
        db_session.add_all(posts)
        await db_session.commit()

        response = await client.get("/posts/?page=1&size=3")

        assert response.status_code == 200
        data = response.json()
        assert len(data["items"]) == 3
        assert data["total"] == 5
        assert data["page"] == 1
        assert data["size"] == 3

    @pytest.mark.asyncio
    async def test_get_post_by_id(self, client: AsyncClient, db_session):
        """A single post can be fetched by its id."""
        user = await self._create_author(db_session)

        post = Post(
            title="Test Post",
            content="Test Content",
            author_id=user.id,
            is_published=True,
        )
        db_session.add(post)
        await db_session.commit()
        await db_session.refresh(post)

        response = await client.get(f"/posts/{post.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "Test Post"
        assert data["content"] == "Test Content"

    @pytest.mark.asyncio
    async def test_get_nonexistent_post(self, client: AsyncClient):
        """An unknown id answers 404."""
        response = await client.get("/posts/999")

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    @pytest.mark.asyncio
    async def test_update_post(self, client: AsyncClient, authenticated_user, db_session):
        """The author can update title and content."""
        post = Post(
            title="Original Title",
            content="Original Content",
            author_id=authenticated_user["user"].id,
        )
        db_session.add(post)
        await db_session.commit()
        await db_session.refresh(post)

        update_data = {
            "title": "Updated Title",
            "content": "Updated Content",
        }

        response = await client.put(
            f"/posts/{post.id}",
            json=update_data,
            headers=authenticated_user["headers"],
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "Updated Title"
        assert data["content"] == "Updated Content"

    @pytest.mark.asyncio
    async def test_delete_post(self, client: AsyncClient, authenticated_user, db_session):
        """The author can delete a post, after which it is no longer found."""
        post = Post(
            title="To Delete",
            content="Content",
            author_id=authenticated_user["user"].id,
        )
        db_session.add(post)
        await db_session.commit()
        await db_session.refresh(post)

        response = await client.delete(
            f"/posts/{post.id}",
            headers=authenticated_user["headers"],
        )

        assert response.status_code == 204

        get_response = await client.get(f"/posts/{post.id}")
        assert get_response.status_code == 404

2. 数据库集成测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# tests/test_database.py
import pytest
from sqlalchemy import select
from app.models.user import User
from app.models.post import Post

class TestDatabaseOperations:
    """Integration tests that talk to the test database directly."""

    @pytest.mark.asyncio
    async def test_user_crud_operations(self, db_session):
        """A user can be created, read, updated and deleted."""
        # Create
        user = User(
            username="testuser",
            email="test@example.com",
        )
        user.set_password("password123")

        db_session.add(user)
        await db_session.commit()
        await db_session.refresh(user)

        assert user.id is not None
        assert user.username == "testuser"

        # Read
        stmt = select(User).where(User.username == "testuser")
        result = await db_session.execute(stmt)
        found_user = result.scalar_one_or_none()

        assert found_user is not None
        assert found_user.email == "test@example.com"

        # Update
        found_user.email = "updated@example.com"
        await db_session.commit()

        await db_session.refresh(found_user)
        assert found_user.email == "updated@example.com"

        # Delete
        await db_session.delete(found_user)
        await db_session.commit()

        stmt = select(User).where(User.username == "testuser")
        result = await db_session.execute(stmt)
        deleted_user = result.scalar_one_or_none()

        assert deleted_user is None

    @pytest.mark.asyncio
    async def test_user_post_relationship(self, db_session):
        """User.posts contains every post whose author_id points at the user."""
        from sqlalchemy.orm import selectinload

        user = User(
            username="author",
            email="author@example.com",
        )
        user.set_password("password")
        db_session.add(user)
        await db_session.flush()

        posts = [
            Post(
                title=f"Post {i}",
                content=f"Content {i}",
                author_id=user.id,
            )
            for i in range(3)
        ]
        db_session.add_all(posts)
        await db_session.commit()

        # With AsyncSession an implicit lazy load on attribute access fails, so
        # the collection is loaded eagerly with selectinload. `user` is already
        # in this session's identity map; populate_existing makes the query
        # (re)populate that instance with the eager-loaded collection.
        stmt = (
            select(User)
            .where(User.id == user.id)
            .options(selectinload(User.posts))
            .execution_options(populate_existing=True)
        )
        result = await db_session.execute(stmt)
        user_with_posts = result.scalar_one()

        assert len(user_with_posts.posts) == 3
        assert all(post.author_id == user.id for post in user_with_posts.posts)

    @pytest.mark.asyncio
    async def test_database_constraints(self, db_session):
        """Inserting a second user with the same username violates a unique constraint."""
        from sqlalchemy.exc import IntegrityError

        user1 = User(
            username="unique_user",
            email="unique@example.com",
        )
        user1.set_password("password")
        db_session.add(user1)
        await db_session.commit()

        user2 = User(
            username="unique_user",  # same username as user1
            email="different@example.com",
        )
        user2.set_password("password")
        db_session.add(user2)

        # Assumes User.username carries a unique constraint.
        with pytest.raises(IntegrityError):
            await db_session.commit()
        # The failed flush leaves the session needing a rollback before reuse.
        await db_session.rollback()

端到端测试

1. 完整用户流程测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# tests/test_e2e.py
import pytest
from httpx import AsyncClient

@pytest.mark.e2e
class TestUserJourney:
    """End-to-end walk through the main user flow."""

    @pytest.mark.asyncio
    async def test_complete_user_journey(self, client: AsyncClient):
        """Register, log in, manage a post, and log out."""

        # 1. Register
        register_data = {
            "username": "journeyuser",
            "email": "journey@example.com",
            "password": "SecurePass123!",
        }

        register_response = await client.post("/auth/register", json=register_data)
        assert register_response.status_code == 201
        user_data = register_response.json()

        # 2. Log in (form-encoded credentials)
        login_response = await client.post("/auth/login", data={
            "username": "journeyuser",
            "password": "SecurePass123!",
        })
        assert login_response.status_code == 200
        tokens = login_response.json()
        headers = {"Authorization": f"Bearer {tokens['access_token']}"}

        # 3. The profile belongs to the user that just registered
        profile_response = await client.get("/auth/me", headers=headers)
        assert profile_response.status_code == 200
        profile = profile_response.json()
        assert profile["username"] == "journeyuser"
        assert profile["email"] == user_data["email"]

        # 4. Create a post, owned by that user
        post_data = {
            "title": "My First Post",
            "content": "This is my first blog post!",
            "tags": ["introduction", "blog"],
        }

        create_post_response = await client.post(
            "/posts/",
            json=post_data,
            headers=headers,
        )
        assert create_post_response.status_code == 201
        post = create_post_response.json()
        assert post["author_id"] == user_data["id"]

        # 5. List posts
        # NOTE(review): if the list endpoint only returns published posts,
        # this requires newly created posts to be published by default.
        posts_response = await client.get("/posts/")
        assert posts_response.status_code == 200
        posts_data = posts_response.json()
        assert len(posts_data["items"]) >= 1

        # 6. Update the post
        update_data = {
            "title": "My Updated First Post",
            "content": "This is my updated first blog post!",
        }

        update_response = await client.put(
            f"/posts/{post['id']}",
            json=update_data,
            headers=headers,
        )
        assert update_response.status_code == 200
        updated_post = update_response.json()
        assert updated_post["title"] == "My Updated First Post"
        assert updated_post["content"] == "This is my updated first blog post!"

        # 7. Delete the post
        delete_response = await client.delete(
            f"/posts/{post['id']}",
            headers=headers,
        )
        assert delete_response.status_code == 204

        # 8. The deleted post is gone
        get_deleted_response = await client.get(f"/posts/{post['id']}")
        assert get_deleted_response.status_code == 404

        # 9. Log out
        logout_response = await client.post("/auth/logout", headers=headers)
        assert logout_response.status_code == 200

测试数据工厂

1. 使用Factory Boy创建测试数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# tests/factories.py
import functools

import factory
import pytest
from factory import Faker, SubFactory

from app.auth.password import PasswordManager
from app.models.post import Post
from app.models.user import User

@functools.lru_cache(maxsize=None)
def _default_password_hash() -> str:
    """Hash the shared factory password once and reuse the result.

    Password hashing is presumably a deliberately slow KDF (e.g. bcrypt), so
    hashing once per instance makes create_batch(1000) needlessly expensive.
    """
    return PasswordManager.hash_password('testpassword123')


class UserFactory(factory.Factory):
    """Build in-memory User instances (callers add them to a session)."""

    class Meta:
        model = User

    # Sequences guarantee unique values. Faker('user_name') / Faker('email')
    # can repeat within large batches and trip the unique constraints.
    username = factory.Sequence(lambda n: f"user{n}")
    email = factory.LazyAttribute(lambda obj: f"{obj.username}@example.com")
    # Every factory user has the password 'testpassword123'.
    hashed_password = factory.LazyFunction(_default_password_hash)
    is_active = True
    is_superuser = False
    is_verified = False

class SuperUserFactory(UserFactory):
    """UserFactory variant that produces verified superusers."""

    is_verified = True
    is_superuser = True

class PostFactory(factory.Factory):
    """Build in-memory Post instances.

    By default a fresh author is built through ``SubFactory(UserFactory)``.
    To attach the post to an existing user, pass ``author=<user>`` rather than
    ``author_id=...``: the default ``author`` would otherwise still be set,
    and on flush its id may overwrite the explicit ``author_id``.
    """

    class Meta:
        model = Post

    title = Faker("sentence", nb_words=4)
    content = Faker("text", max_nb_chars=1000)
    is_published = True
    author = SubFactory(UserFactory)

    @factory.post_generation
    def tags(self, create, extracted, **kwargs):
        """Append the items given as ``PostFactory(tags=[...])`` after creation.

        NOTE(review): items are appended as-is; if Post.tags is a relationship,
        they must be Tag objects rather than plain strings.
        """
        if not (create and extracted):
            return
        for item in extracted:
            self.tags.append(item)

# 使用工厂的测试示例
class TestWithFactories:
    """Examples of building test data with the factories.

    NOTE: pytest only collects files matching python_files (test_*.py), so
    these tests must live in a test module (e.g. tests/test_factories.py)
    to actually run.
    """

    @pytest.mark.asyncio
    async def test_create_user_with_factory(self, db_session):
        """A factory-built user can be persisted with its defaults."""
        user = UserFactory()
        db_session.add(user)
        await db_session.commit()

        assert user.username is not None
        assert user.email is not None
        assert user.is_active is True

    @pytest.mark.asyncio
    async def test_create_post_with_factory(self, db_session):
        """A factory-built post can be attached to an existing author."""
        author = UserFactory()
        db_session.add(author)
        await db_session.flush()

        # Pass the author object, not author_id: PostFactory's default
        # SubFactory(UserFactory) would otherwise attach a second, new user
        # whose id can replace author_id when the post is flushed.
        post = PostFactory(author=author)
        db_session.add(post)
        await db_session.commit()

        assert post.title is not None
        assert post.content is not None
        assert post.author_id == author.id

    @pytest.mark.asyncio
    async def test_batch_create_with_factory(self, db_session):
        """create_batch builds several users, each with several posts."""
        users = UserFactory.create_batch(5)
        db_session.add_all(users)
        await db_session.flush()

        posts = []
        for user in users:
            posts.extend(PostFactory.create_batch(3, author=user))

        db_session.add_all(posts)
        await db_session.commit()

        assert len(users) == 5
        assert len(posts) == 15

性能测试

1. 负载测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# tests/test_performance.py
import pytest
import asyncio
import time
from httpx import AsyncClient

@pytest.mark.slow
class TestPerformance:
    """Coarse performance smoke tests (wall-clock budgets, machine dependent)."""

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, client: AsyncClient, db_session):
        """100 concurrent GET /posts/ requests all succeed within 5 seconds."""
        from sqlalchemy.ext.asyncio import AsyncSession

        from app.database import get_db
        from app.main import app

        # The client fixture routes every request to the single db_session,
        # but one AsyncSession must not be used by concurrent tasks. Give each
        # request its own session on the same engine instead. (The client
        # fixture clears dependency_overrides on teardown.)
        engine = db_session.bind

        async def session_per_request():
            async with AsyncSession(engine, expire_on_commit=False) as session:
                yield session

        app.dependency_overrides[get_db] = session_per_request

        async def make_request():
            response = await client.get("/posts/")
            return response.status_code

        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()
        results = await asyncio.gather(*(make_request() for _ in range(100)))
        total_time = time.perf_counter() - start_time

        assert all(status == 200 for status in results)
        assert total_time < 5.0

        print(f"100个并发请求耗时: {total_time:.2f}秒")

    @pytest.mark.asyncio
    async def test_database_query_performance(self, db_session):
        """Selecting 100 of 1000 users takes under one second."""
        from sqlalchemy import select

        from app.models.user import User
        # NOTE(review): adjust this import to wherever factories.py lives
        # (e.g. `from factories import UserFactory` if tests/ is not a package).
        from tests.factories import UserFactory

        db_session.add_all(UserFactory.create_batch(1000))
        await db_session.commit()

        start_time = time.perf_counter()
        result = await db_session.execute(select(User).limit(100))
        fetched = result.scalars().all()
        query_time = time.perf_counter() - start_time

        assert len(fetched) == 100
        assert query_time < 1.0

        print(f"查询100个用户耗时: {query_time:.3f}秒")

    @pytest.mark.asyncio
    async def test_memory_usage(self, client: AsyncClient):
        """1000 health-check requests grow RSS by less than 50 MB."""
        import os

        # psutil is optional: skip instead of erroring when it is not installed.
        psutil = pytest.importorskip("psutil")

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        for _ in range(1000):
            await client.get("/health")

        final_memory = process.memory_info().rss / 1024 / 1024  # MB
        memory_increase = final_memory - initial_memory

        assert memory_increase < 50

        print(f"内存增长: {memory_increase:.2f}MB")

测试覆盖率和报告

1. 覆盖率配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# .coveragerc
# Multi-line values must be indented so they are read as continuation lines.
[run]
source = app
omit =
    app/tests/*
    app/migrations/*
    app/__init__.py
    */venv/*
    */virtualenv/*

[report]
exclude_lines =
    pragma: no cover
    def __repr__
    raise AssertionError
    raise NotImplementedError
    if __name__ == .__main__.:
    if TYPE_CHECKING:
    class .*\(Protocol\):
    @(abc\.)?abstractmethod

[html]
directory = htmlcov

2. 测试报告生成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# tests/conftest.py (添加到现有文件)
import pytest

def pytest_configure(config):
    """Register the test-category markers used by this suite."""
    for marker_line in (
        "unit: marks tests as unit tests",
        "integration: marks tests as integration tests",
        "e2e: marks tests as end-to-end tests",
    ):
        config.addinivalue_line("markers", marker_line)

@pytest.fixture(autouse=True)
def setup_test_environment(monkeypatch):
    """Set test-mode environment variables for every test.

    An externally provided DATABASE_URL (for example one exported by a CI
    job) is kept; the local SQLite file is only a fallback.
    NOTE(review): anything read at import time (e.g. a settings object built
    when app.config is imported) will not observe these variables.
    """
    import os

    monkeypatch.setenv("TESTING", "true")
    monkeypatch.setenv(
        "DATABASE_URL",
        os.environ.get("DATABASE_URL", "sqlite+aiosqlite:///./test.db"),
    )

def pytest_collection_modifyitems(config, items):
    """Give every collected test exactly one category marker.

    Tests that already carry unit/integration/e2e (e.g. via a class-level
    ``@pytest.mark.e2e``) keep it. Otherwise the category comes from the
    file:
    - integration for names containing test_api / test_database, or files
      under an ``integration/`` directory;
    - e2e for names containing test_e2e, or files under an ``e2e/``
      directory;
    - unit for everything else.
    """
    categories = ("unit", "integration", "e2e")
    for item in items:
        if any(item.get_closest_marker(name) for name in categories):
            continue

        # item.path (pathlib) supersedes the legacy py.path-based item.fspath.
        path = item.path
        if "integration" in path.parts or any(
            key in path.name for key in ("test_api", "test_database")
        ):
            item.add_marker(pytest.mark.integration)
        elif "e2e" in path.parts or "test_e2e" in path.name:
            item.add_marker(pytest.mark.e2e)
        else:
            item.add_marker(pytest.mark.unit)

持续集成配置

1. GitHub Actions配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# .github/workflows/test.yml
name: Tests

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest

    services:
      postgres:
        image: postgres:13
        env:
          POSTGRES_PASSWORD: postgres
          POSTGRES_DB: test_db
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 5432:5432

      redis:
        image: redis:6
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 6379:6379

    strategy:
      matrix:
        # Quoted on purpose: unquoted, YAML reads 3.10 as the float 3.1.
        python-version: ["3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-test.txt

      - name: Run linting
        run: |
          flake8 app tests
          black --check app tests
          isort --check-only app tests

      - name: Run type checking
        run: |
          mypy app

      # The partial runs below disable the pytest.ini coverage threshold
      # (--cov-fail-under=0); it is enforced once, on the combined data, in
      # the final E2E step that appends to the same coverage file.
      - name: Run unit tests
        run: |
          pytest tests/ -m "unit" --cov=app --cov-report=xml --cov-fail-under=0
        env:
          DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
          REDIS_URL: redis://localhost:6379

      - name: Run integration tests
        run: |
          pytest tests/ -m "integration" --cov=app --cov-append --cov-report=xml --cov-fail-under=0
        env:
          DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
          REDIS_URL: redis://localhost:6379

      - name: Run E2E tests
        run: |
          pytest tests/ -m "e2e" --cov=app --cov-append --cov-report=xml
        env:
          DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
          REDIS_URL: redis://localhost:6379

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
          name: codecov-umbrella

测试最佳实践

1. 测试组织和命名

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 好的测试组织结构
tests/
├── conftest.py # 全局配置和fixture
├── factories.py # 测试数据工厂
├── helpers.py # 测试辅助类(TestDataHelper)
├── unit/ # 单元测试
│ ├── test_models.py
│ ├── test_services.py
│ └── test_utils.py
├── integration/ # 集成测试
│ ├── test_api.py
│ ├── test_database.py
│ └── test_auth.py
├── e2e/ # 端到端测试
│ └── test_user_journey.py
└── performance/ # 性能测试
└── test_load.py

# 测试命名约定
class TestUserService:
    """Naming convention: test_<action>_<condition>_<expected result>."""

    def test_create_user_success(self):
        """Creating a user succeeds."""
        ...

    def test_create_user_duplicate_email_raises_error(self):
        """Creating a user with an already-registered email raises an error."""
        ...

    def test_authenticate_user_with_valid_credentials_returns_user(self):
        """Authenticating with valid credentials returns the user object."""
        ...

2. 测试数据管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# tests/helpers.py
from typing import Dict, Any
from app.models.user import User
from app.models.post import Post

class TestDataHelper:
    """Helpers that build test users/posts and log users in.

    The create_* helpers only add objects to the session; callers flush or
    commit as needed.
    """

    # The name starts with "Test": stop pytest from collecting this class when
    # it is imported into a test module.
    __test__ = False

    @staticmethod
    def create_test_user(db_session, **kwargs) -> User:
        """Add a User to ``db_session`` and return it (not flushed).

        ``username``, ``email`` and ``password`` default to "testuser",
        "test@example.com" and "testpassword123"; the password is hashed via
        ``set_password``. Any other keyword (e.g. ``is_superuser``) is passed
        to the User constructor instead of being silently dropped.
        """
        fields: Dict[str, Any] = {
            "username": "testuser",
            "email": "test@example.com",
            **kwargs,
        }
        password = fields.pop("password", "testpassword123")

        user = User(**fields)
        user.set_password(password)

        db_session.add(user)
        return user

    @staticmethod
    def create_test_post(db_session, author: User, **kwargs) -> Post:
        """Add a Post written by ``author`` to ``db_session`` and return it.

        ``title``, ``content`` and ``is_published`` default to "Test Post",
        "Test Content" and True; any other keyword is passed to the Post
        constructor.
        NOTE(review): reads ``author.id``, so the author must already have been
        flushed (its id is presumably None before that).
        """
        fields: Dict[str, Any] = {
            "title": "Test Post",
            "content": "Test Content",
            "is_published": True,
            **kwargs,
        }
        fields.setdefault("author_id", author.id)

        post = Post(**fields)

        db_session.add(post)
        return post

    @staticmethod
    async def authenticate_user(client, username: str, password: str) -> Dict[str, Any]:
        """Log in through /auth/login and return the tokens plus auth headers.

        Returns an empty dict when the login response is not 200.
        """
        response = await client.post("/auth/login", data={
            "username": username,
            "password": password,
        })

        if response.status_code == 200:
            tokens = response.json()
            return {
                "tokens": tokens,
                "headers": {"Authorization": f"Bearer {tokens['access_token']}"},
            }

        return {}

3. 测试执行策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Makefile
# Every target is a command, not a file, so all of them are declared phony.
.PHONY: test test-unit test-integration test-e2e test-coverage test-fast \
        test-parallel test-file test-function

# Run the whole suite
test:
	pytest tests/ -v

# Unit tests only
test-unit:
	pytest tests/ -m "unit" -v

# Integration tests only
test-integration:
	pytest tests/ -m "integration" -v

# End-to-end tests only
test-e2e:
	pytest tests/ -m "e2e" -v

# Run tests and produce coverage reports
test-coverage:
	pytest tests/ --cov=app --cov-report=html --cov-report=term-missing

# Fast run (skip tests marked slow)
test-fast:
	pytest tests/ -m "not slow" -v

# Parallel run (pytest-xdist)
test-parallel:
	pytest tests/ -n auto

# A single test file
test-file:
	pytest tests/test_api.py -v

# A single test function
test-function:
	pytest tests/test_api.py::TestAuthAPI::test_login_success -v

总结

通过本文的实践,我们建立了一个完整的FastAPI测试策略:

  1. 测试环境搭建:配置pytest、数据库和依赖注入
  2. 单元测试:测试模型、服务和工具函数
  3. 集成测试:测试API端点和数据库操作
  4. 端到端测试:测试完整的用户流程
  5. 测试数据管理:使用工厂模式创建测试数据
  6. 性能测试:验证应用性能和资源使用
  7. 持续集成:自动化测试执行和报告

良好的测试策略不仅能够保证代码质量,还能提高开发效率和系统稳定性。记住,测试不是负担,而是开发过程中的重要投资。

你在FastAPI测试方面有什么经验或问题吗?欢迎在评论中分享讨论!

本站由 提供部署服务