FastAPI WebSocket实战:构建高性能实时应用
Orion K Lv6

FastAPI WebSocket实战:构建高性能实时应用

在现代Web应用中,实时通信已经成为必不可少的功能。无论是聊天应用、实时通知、协作编辑还是实时数据监控,WebSocket都是首选的技术方案。作为一名使用FastAPI开发实时应用超过18个月的开发者,我想分享一些实战经验和最佳实践。

WebSocket基础实现

1. 简单的WebSocket连接

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List
import json
import time

app = FastAPI()

# Active sockets for the simple /ws broadcast example below.
# A module-level list is process-local state: it is not shared between worker processes.
active_connections: List[WebSocket] = []

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Broadcast every client message to all connected clients, sender included.

    Clients send JSON text frames such as ``{"message": "hi"}``. The
    ``message`` field is re-sent to every active connection inside a
    ``{"type": "broadcast", ...}`` envelope. Malformed frames are ignored.
    """
    await websocket.accept()
    active_connections.append(websocket)

    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                # Ignore malformed frames instead of tearing down the connection.
                continue
            if not isinstance(message, dict):
                continue

            # Serialize once; every recipient gets the same payload and timestamp.
            payload = json.dumps({
                "type": "broadcast",
                "message": message.get("message", ""),
                "timestamp": time.time(),
            })

            # Iterate over a snapshot. Dead peers are removed below, and other
            # handlers may mutate the list while this one awaits send_text.
            # Removing from a list while iterating it silently skips elements.
            for connection in list(active_connections):
                try:
                    await connection.send_text(payload)
                except Exception:
                    # Sending failed, so the peer is gone; drop it once.
                    if connection in active_connections:
                        active_connections.remove(connection)
    except WebSocketDisconnect:
        pass  # normal client close; cleanup happens in `finally`
    finally:
        # Another handler's failed send may already have removed this socket.
        if websocket in active_connections:
            active_connections.remove(websocket)
        print(f"Client disconnected. Active connections: {len(active_connections)}")

2. 连接管理器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from typing import Dict, Set
import uuid
import logging

# Module-level logger; ConnectionManager below logs connect/disconnect/room events to it.
logger = logging.getLogger(__name__)

class ConnectionManager:
    """In-memory registry of WebSocket connections, user bindings and rooms.

    All state lives in this process, so it only works with a single worker.

    Attributes:
        active_connections: connection_id -> accepted WebSocket.
        rooms: room_id -> set of connection_ids in that room. A room is deleted
            as soon as its last member leaves or disconnects.
        user_connections: user_id -> connection_id (at most one per user).
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        self.rooms: Dict[str, Set[str]] = {}
        self.user_connections: Dict[str, str] = {}

    async def connect(self, websocket: WebSocket, user_id: str = None) -> str:
        """Accept *websocket*, register it, and return its new connection id.

        If *user_id* already has a connection, the older one is disconnected
        first, so each user holds at most one socket in this manager.
        """
        await websocket.accept()

        connection_id = str(uuid.uuid4())
        self.active_connections[connection_id] = websocket

        if user_id:
            old_connection_id = self.user_connections.get(user_id)
            if old_connection_id is not None:
                await self.disconnect(old_connection_id)
            self.user_connections[user_id] = connection_id

        logger.info("New connection: %s, User: %s", connection_id, user_id)
        return connection_id

    async def disconnect(self, connection_id: str):
        """Unregister *connection_id* everywhere and close its socket.

        Safe to call more than once, including concurrently, for the same id.
        """
        # Pop before any await. A concurrent call for the same id then sees it as
        # already gone, instead of both passing a membership check and the
        # second one failing on `del` with KeyError.
        websocket = self.active_connections.pop(connection_id, None)
        if websocket is None:
            return

        for room_id in list(self.rooms):
            members = self.rooms[room_id]
            if connection_id in members:
                members.remove(connection_id)
                if not members:  # drop rooms that became empty
                    del self.rooms[room_id]

        user_id = next(
            (uid for uid, cid in self.user_connections.items() if cid == connection_id),
            None,
        )
        if user_id is not None:
            del self.user_connections[user_id]

        try:
            await websocket.close()
        except Exception:
            # Typically the peer already closed (e.g. after WebSocketDisconnect).
            pass

        logger.info("Connection disconnected: %s, User: %s", connection_id, user_id)

    async def send_personal_message(self, message: dict, connection_id: str):
        """Send *message* as JSON to one connection; disconnect it if the send fails."""
        websocket = self.active_connections.get(connection_id)
        if websocket is None:
            return
        # Serialize outside the try. A non-serializable message is a caller bug
        # and must not be mistaken for a dead peer.
        payload = json.dumps(message)
        try:
            await websocket.send_text(payload)
        except Exception:
            await self.disconnect(connection_id)

    async def send_to_user(self, message: dict, user_id: str):
        """Send *message* to *user_id*'s connection; a no-op if the user is not connected."""
        connection_id = self.user_connections.get(user_id)
        if connection_id is not None:
            await self.send_personal_message(message, connection_id)

    async def join_room(self, connection_id: str, room_id: str):
        """Add *connection_id* to *room_id*, creating the room if needed."""
        self.rooms.setdefault(room_id, set()).add(connection_id)
        logger.info("Connection %s joined room %s", connection_id, room_id)

    async def leave_room(self, connection_id: str, room_id: str):
        """Remove *connection_id* from *room_id*; delete the room when it empties."""
        members = self.rooms.get(room_id)
        if members is None or connection_id not in members:
            return
        members.remove(connection_id)
        if not members:
            del self.rooms[room_id]
        logger.info("Connection %s left room %s", connection_id, room_id)

    async def broadcast_to_room(self, message: dict, room_id: str, exclude_connection: str = None):
        """Send *message* to every member of *room_id* except *exclude_connection*.

        Members whose send fails are disconnected after the broadcast.
        """
        members = self.rooms.get(room_id)
        if not members:
            return

        payload = json.dumps(message)
        failed = []
        # Iterate over a snapshot. Each await lets other coroutines join, leave
        # or disconnect, and mutating a set during iteration raises RuntimeError.
        for connection_id in list(members):
            if connection_id == exclude_connection:
                continue
            websocket = self.active_connections.get(connection_id)
            if websocket is None:
                continue
            try:
                await websocket.send_text(payload)
            except Exception:
                failed.append(connection_id)

        for connection_id in failed:
            await self.disconnect(connection_id)

# Process-wide connection manager used by the /ws/{user_id} endpoint,
# handle_message and the health/stats routes.
manager = ConnectionManager()

@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    """Room-chat endpoint: register the socket under *user_id* and dispatch messages.

    NOTE(review): user_id comes straight from the URL path and is not authenticated.
    """
    connection_id = await manager.connect(websocket, user_id)

    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                # Tell the client instead of dropping the whole connection.
                await manager.send_personal_message({
                    "type": "error",
                    "message": "Invalid JSON",
                    "timestamp": time.time(),
                }, connection_id)
                continue

            await handle_message(message, connection_id, user_id)
    except WebSocketDisconnect:
        pass  # normal client close; cleanup happens in `finally`
    finally:
        # Always unregister, even if a handler raised. Otherwise the manager
        # keeps a dead socket in its rooms and user map.
        await manager.disconnect(connection_id)

async def handle_message(message: dict, connection_id: str, user_id: str):
    """Dispatch one decoded client message received on the /ws/{user_id} endpoint.

    Supported ``type`` values:

    * ``"join_room"``: add the connection to ``room_id`` and notify the other
      members with a ``user_joined`` event.
    * ``"room_message"``: broadcast ``content`` to every member of ``room_id``,
      sender included. The sender must already be a member of the room.

    Invalid requests get an ``{"type": "error", ...}`` reply. Any other
    ``type`` is ignored.
    """

    async def reply_error(text: str):
        await manager.send_personal_message({
            "type": "error",
            "message": text,
            "timestamp": time.time(),
        }, connection_id)

    if not isinstance(message, dict):
        await reply_error("Message must be a JSON object")
        return

    message_type = message.get("type")
    if message_type not in ("join_room", "room_message"):
        return

    room_id = message.get("room_id")
    # Reject a missing room_id (None) and unhashable values such as lists.
    # These would otherwise create a room keyed by None or raise TypeError.
    if not isinstance(room_id, (str, int)):
        await reply_error("Missing or invalid room_id")
        return

    if message_type == "join_room":
        await manager.join_room(connection_id, room_id)

        # Notify the other members of the room.
        await manager.broadcast_to_room({
            "type": "user_joined",
            "user_id": user_id,
            "room_id": room_id,
            "timestamp": time.time(),
        }, room_id, exclude_connection=connection_id)
        return

    # room_message: only members may post, otherwise any client could inject
    # messages into rooms it never joined.
    if connection_id not in manager.rooms.get(room_id, ()):
        await reply_error("Join the room before sending messages")
        return

    await manager.broadcast_to_room({
        "type": "room_message",
        "user_id": user_id,
        "room_id": room_id,
        "content": message.get("content"),
        "timestamp": time.time(),
    }, room_id)

实时聊天室实现

1. 聊天室数据模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from sqlalchemy import Column, Integer, String, DateTime, Text, ForeignKey
from sqlalchemy.orm import relationship
from datetime import datetime

class ChatRoom(Base):
    """A chat room, related to its messages and its membership rows."""
    __tablename__ = "chat_rooms"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)
    description = Column(Text)
    # Set by the Python-side default at insert time.
    # NOTE(review): datetime.utcnow returns naive datetimes and is deprecated since Python 3.12.
    created_at = Column(DateTime, default=datetime.utcnow)
    # Id of the user who created the room (users.id).
    created_by = Column(Integer, ForeignKey("users.id"))

    messages = relationship("ChatMessage", back_populates="room")
    members = relationship("RoomMember", back_populates="room")

class ChatMessage(Base):
    """A single message posted by a user in a chat room."""
    __tablename__ = "chat_messages"

    id = Column(Integer, primary_key=True, index=True)
    # Room the message belongs to (chat_rooms.id).
    # NOTE(review): not indexed, although message history is filtered by room_id
    # and ordered by created_at; consider a composite index.
    room_id = Column(Integer, ForeignKey("chat_rooms.id"))
    # Author of the message (users.id).
    user_id = Column(Integer, ForeignKey("users.id"))
    content = Column(Text, nullable=False)
    # Intended values: "text", "image" or "file"; defaults to "text".
    message_type = Column(String(20), default="text")
    # NOTE(review): datetime.utcnow returns naive datetimes and is deprecated since Python 3.12.
    created_at = Column(DateTime, default=datetime.utcnow)

    room = relationship("ChatRoom", back_populates="messages")
    user = relationship("User")

class RoomMember(Base):
    """Membership of a user in a chat room, with a role."""
    __tablename__ = "room_members"

    id = Column(Integer, primary_key=True, index=True)
    # NOTE(review): no unique constraint on (room_id, user_id) here, so the same
    # user can be added to a room more than once.
    room_id = Column(Integer, ForeignKey("chat_rooms.id"))
    user_id = Column(Integer, ForeignKey("users.id"))
    joined_at = Column(DateTime, default=datetime.utcnow)
    # Intended values: "admin", "moderator" or "member"; defaults to "member".
    role = Column(String(20), default="member")

    room = relationship("ChatRoom", back_populates="members")
    user = relationship("User")

2. 聊天室服务

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_

class ChatService:
    """Persistence operations for chat rooms and messages on an ``AsyncSession``."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_room(self, name: str, description: str, creator_id: int) -> ChatRoom:
        """Create a room and make *creator_id* its admin in one transaction.

        Both rows are committed together, so a failure cannot leave behind a
        room without its admin membership. On error the transaction is rolled
        back and the exception re-raised.

        Returns:
            The persisted ``ChatRoom``, refreshed after the commit.
        """
        room = ChatRoom(name=name, description=description, created_by=creator_id)
        self.db.add(room)
        try:
            # flush() emits the INSERT so room.id is assigned, without committing.
            await self.db.flush()
            self.db.add(RoomMember(room_id=room.id, user_id=creator_id, role="admin"))
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise
        # Reload after the commit. Committed instances are expired by default
        # (expire_on_commit=True), and an AsyncSession cannot lazily reload
        # attributes when the caller touches them later.
        await self.db.refresh(room)
        return room

    async def save_message(self, room_id: int, user_id: int, content: str, message_type: str = "text") -> ChatMessage:
        """Persist one chat message and return it refreshed from the database.

        On error the transaction is rolled back and the exception re-raised.
        """
        message = ChatMessage(
            room_id=room_id,
            user_id=user_id,
            content=content,
            message_type=message_type,
        )
        self.db.add(message)
        try:
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise
        await self.db.refresh(message)
        return message

    async def get_room_messages(self, room_id: int, limit: int = 50, offset: int = 0) -> List[ChatMessage]:
        """Return up to *limit* messages of *room_id*, newest first, skipping *offset*.

        Messages with equal ``created_at`` are ordered by descending id, so
        pagination is stable. A non-positive *limit* returns an empty list; a
        negative *offset* is treated as 0.
        """
        if limit <= 0:
            return []
        result = await self.db.execute(
            select(ChatMessage)
            .where(ChatMessage.room_id == room_id)
            .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
            .offset(max(offset, 0))
            .limit(limit)
        )
        return list(result.scalars().all())

实时通知系统

1. 通知管理器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from enum import Enum

class NotificationType(Enum):
    """Category of a notification; its ``value`` is sent to clients as ``notification_type``."""

    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    SUCCESS = "success"

class NotificationManager:
    """Send notification payloads to users through a ConnectionManager.

    By default it creates its own private ConnectionManager. Only sockets that
    connect through this manager (the /notifications endpoint) then receive
    notifications. Pass *connection_manager* to share an existing manager.
    """

    def __init__(self, connection_manager: "ConnectionManager" = None):
        if connection_manager is None:
            connection_manager = ConnectionManager()
        self.connection_manager = connection_manager

    async def send_notification(
        self,
        user_id: str,
        title: str,
        message: str,
        notification_type: NotificationType = NotificationType.INFO,
        data: dict = None,
    ):
        """Send one notification to *user_id*; nothing is sent if the user is not connected."""
        notification = {
            "type": "notification",
            "notification_type": notification_type.value,
            "title": title,
            "message": message,
            "data": data or {},
            "timestamp": time.time(),
            "id": str(uuid.uuid4()),
        }

        await self.connection_manager.send_to_user(notification, user_id)

    async def broadcast_notification(
        self,
        title: str,
        message: str,
        notification_type: NotificationType = NotificationType.INFO,
    ):
        """Send one notification to every connection bound to a user id.

        Anonymous connections (registered without a user id) are not reached.
        """
        notification = {
            "type": "notification",
            "notification_type": notification_type.value,
            "title": title,
            "message": message,
            "timestamp": time.time(),
            "id": str(uuid.uuid4()),
        }

        # Snapshot the ids. send_personal_message() disconnects peers whose send
        # fails, which deletes from user_connections while we iterate. Iterating
        # the live dict would then raise "dictionary changed size".
        for connection_id in list(self.connection_manager.user_connections.values()):
            await self.connection_manager.send_personal_message(notification, connection_id)


# Process-wide notification manager used by the /notifications/{user_id} endpoint.
notification_manager = NotificationManager()

@app.websocket("/notifications/{user_id}")
async def notification_websocket(websocket: WebSocket, user_id: str):
    """Notification channel for *user_id*: greet on connect, then read client messages.

    Malformed frames are ignored. The connection is always unregistered when
    the loop ends, however it ends.
    """
    connections = notification_manager.connection_manager
    connection_id = await connections.connect(websocket, user_id)

    try:
        # Confirm to the client that notifications are active.
        await notification_manager.send_notification(
            user_id,
            "连接成功",
            "实时通知已启用",
            NotificationType.SUCCESS,
        )

        while True:
            # Reading also keeps the connection alive.
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                continue  # ignore malformed frames; keep the channel open
            if not isinstance(message, dict):
                continue

            if message.get("type") == "mark_read":
                # TODO: persist the read state of message["notification_id"] (e.g. in the DB).
                pass
    except WebSocketDisconnect:
        pass  # normal client close; cleanup happens in `finally`
    finally:
        await connections.disconnect(connection_id)

性能优化和监控

1. 连接限制和速率限制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from collections import defaultdict, deque

class RateLimiter:
    """Sliding-window limiter: at most *max_requests* per *window_seconds* per identifier.

    Timestamps come from *clock*, which defaults to ``time.monotonic`` so
    wall-clock adjustments cannot widen or shrink the window. Identifiers idle
    for a whole window are swept at most once per window, which keeps memory
    bounded by the set of recently active identifiers.
    """

    def __init__(self, max_requests: int = 10, window_seconds: int = 60, clock=time.monotonic):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # identifier -> deque of accepted-request timestamps, oldest first.
        self.requests = defaultdict(deque)
        self._clock = clock
        self._last_sweep = clock()

    def is_allowed(self, identifier: str) -> bool:
        """Record a request for *identifier*; return False if it exceeds the limit.

        Rejected requests are not recorded, so they do not extend the penalty.
        """
        now = self._clock()
        window_start = now - self.window_seconds

        # Amortized cleanup: at most one O(#identifiers) sweep per window.
        if now - self._last_sweep >= self.window_seconds:
            self._sweep(window_start)
            self._last_sweep = now

        timestamps = self.requests[identifier]
        # Drop records that fell out of the window.
        while timestamps and timestamps[0] < window_start:
            timestamps.popleft()

        if len(timestamps) >= self.max_requests:
            return False

        timestamps.append(now)
        return True

    def _sweep(self, window_start: float) -> None:
        """Forget identifiers whose newest request is older than *window_start*."""
        stale = [key for key, stamps in self.requests.items() if not stamps or stamps[-1] < window_start]
        for key in stale:
            del self.requests[key]

class EnhancedConnectionManager(ConnectionManager):
    """ConnectionManager with a per-user connection cap and per-sender rate limiting.

    NOTE(review): the base class replaces a user's previous connection on
    connect, so a user's live count never exceeds 1. The cap therefore only
    rejects anything when max_connections_per_user <= 1.
    """

    def __init__(self, max_connections_per_user: int = 3):
        super().__init__()
        self.max_connections_per_user = max_connections_per_user
        self.rate_limiter = RateLimiter(max_requests=30, window_seconds=60)
        # user_id -> number of currently open connections for that user.
        self.user_connection_count = defaultdict(int)
        # connection_id -> user_id, so disconnect() knows whose count to release.
        self._connection_owner = {}

    async def connect(self, websocket: WebSocket, user_id: str = None) -> str:
        """Register *websocket*, enforcing the per-user connection cap.

        Raises:
            Exception: if *user_id* is at its cap. The socket is closed with
                code 1008 before raising.
        """
        # .get() avoids creating a zero entry for every user that is merely checked.
        if user_id and self.user_connection_count.get(user_id, 0) >= self.max_connections_per_user:
            await websocket.close(code=1008, reason="Too many connections")
            raise Exception("Connection limit exceeded")

        connection_id = await super().connect(websocket, user_id)

        if user_id:
            self._connection_owner[connection_id] = user_id
            self.user_connection_count[user_id] += 1

        return connection_id

    async def disconnect(self, connection_id: str):
        """Unregister *connection_id* and release its slot in the per-user count.

        Without this release the count only ever grew, and a user was
        permanently rejected after max_connections_per_user connects.
        """
        user_id = self._connection_owner.pop(connection_id, None)
        if user_id is not None:
            remaining = self.user_connection_count.get(user_id, 0) - 1
            if remaining > 0:
                self.user_connection_count[user_id] = remaining
            else:
                self.user_connection_count.pop(user_id, None)
        await super().disconnect(connection_id)

    async def handle_message_with_rate_limit(self, message: dict, connection_id: str, user_id: str):
        """Run handle_message() unless the sender exceeded the rate limit.

        Over-limit messages are dropped, and the sender gets an error reply.
        """
        # Key anonymous senders by connection so they do not share one `None` bucket.
        limiter_key = user_id if user_id else connection_id
        if not self.rate_limiter.is_allowed(limiter_key):
            await self.send_personal_message({
                "type": "error",
                "message": "Rate limit exceeded. Please slow down.",
                "timestamp": time.time(),
            }, connection_id)
            return

        await handle_message(message, connection_id, user_id)

2. 前端集成示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
class WebSocketClient {
constructor(url, userId) {
this.url = url;
this.userId = userId;
this.ws = null;
this.reconnectAttempts = 0;
this.maxReconnectAttempts = 5;
this.reconnectDelay = 1000;
this.messageHandlers = new Map();
}

connect() {
try {
this.ws = new WebSocket(`${this.url}/${this.userId}`);

this.ws.onopen = (event) => {
console.log('WebSocket connected');
this.reconnectAttempts = 0;
this.onConnect(event);
};

this.ws.onmessage = (event) => {
const message = JSON.parse(event.data);
this.handleMessage(message);
};

this.ws.onclose = (event) => {
console.log('WebSocket disconnected');
this.onDisconnect(event);
this.attemptReconnect();
};

this.ws.onerror = (error) => {
console.error('WebSocket error:', error);
this.onError(error);
};

} catch (error) {
console.error('Failed to connect:', error);
this.attemptReconnect();
}
}

send(message) {
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify(message));
} else {
console.warn('WebSocket not connected');
}
}

attemptReconnect() {
if (this.reconnectAttempts < this.maxReconnectAttempts) {
this.reconnectAttempts++;
console.log(`Attempting to reconnect (${this.reconnectAttempts}/${this.maxReconnectAttempts})`);

setTimeout(() => {
this.connect();
}, this.reconnectDelay * this.reconnectAttempts);
}
}

handleMessage(message) {
const handler = this.messageHandlers.get(message.type);
if (handler) {
handler(message);
} else {
console.log('Unhandled message:', message);
}
}

onMessage(messageType, handler) {
this.messageHandlers.set(messageType, handler);
}

// 事件钩子
onConnect(event) {}
onDisconnect(event) {}
onError(error) {}
}

// Usage example
// Chat client: the room-chat endpoint is /ws/{user_id}; the client appends the id.
const chatClient = new WebSocketClient('ws://localhost:8000/ws', 'user123');

// Register handlers for the chat server's message types.
chatClient.onMessage('room_message', (message) => {
  displayChatMessage(message);
});

chatClient.onMessage('user_joined', (message) => {
  showUserJoined(message.user_id);
});

// Notifications are pushed on the separate /notifications/{user_id} endpoint,
// which uses its own connection manager. They need their own connection; the
// chat socket never receives them.
const notificationClient = new WebSocketClient('ws://localhost:8000/notifications', 'user123');

notificationClient.onMessage('notification', (message) => {
  showNotification(message.title, message.message, message.notification_type);
});

// Connect both channels.
chatClient.connect();
notificationClient.connect();

// Join a room. The server tells the other members with a 'user_joined' message.
function joinRoom(roomId) {
  chatClient.send({
    type: 'join_room',
    room_id: roomId,
  });
}

// Send a chat message to a room. The server broadcasts it to the room's
// members, so call joinRoom(roomId) first.
function sendChatMessage(roomId, content) {
  chatClient.send({
    type: 'room_message',
    room_id: roomId,
    content: content,
  });
}

部署和运维

1. 生产环境配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 生产环境WebSocket配置
import os
from fastapi.middleware.cors import CORSMiddleware

# CORS for the HTTP routes. ALLOWED_ORIGINS is a comma-separated list such as
# "https://app.example.com, https://admin.example.com". Entries are stripped and
# blanks dropped, so an unset variable yields an empty allow-list instead of [""].
# NOTE: Starlette's CORSMiddleware only processes HTTP requests, and browsers do
# not enforce CORS on WebSocket handshakes. WebSocket endpoints must check the
# "origin" header themselves if they need origin restrictions.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in os.getenv("ALLOWED_ORIGINS", "").split(",")
        if origin.strip()
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Health check: liveness plus a snapshot of the chat manager's load.
@app.get("/health")
async def health_check():
    """Report service health together with chat connection and room counts."""
    snapshot = {"status": "healthy"}
    snapshot["active_connections"] = len(manager.active_connections)
    snapshot["active_rooms"] = len(manager.rooms)
    snapshot["timestamp"] = time.time()
    return snapshot

# WebSocket statistics for the room-chat connection manager.
@app.get("/ws/stats")
async def websocket_stats():
    """Return the number of open connections, live rooms and user-bound connections."""
    tracked = (
        ("active_connections", manager.active_connections),
        ("active_rooms", manager.rooms),
        ("connected_users", manager.user_connections),
    )
    return {label: len(container) for label, container in tracked}

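2. Docker配置

注意:本文示例中的连接、房间和用户映射都保存在进程内存中,因此只能以单个worker运行(--workers 1);如需横向扩展,需要借助Redis Pub/Sub等共享消息中间件在多个进程之间转发消息。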

1
2
3
4
5
6
7
8
9
10
11
12
FROM python:3.11-slim

# Flush stdout/stderr immediately so container logs are not buffered,
# and do not write .pyc files into the image.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR /app

# Install dependencies first so this layer is cached across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Unprivileged runtime user instead of root.
RUN useradd --create-home --uid 1000 appuser

COPY . .

USER appuser

EXPOSE 8000

# Keep exactly one worker: connection, room and user state lives in process
# memory, so multiple workers would each see only their own clients.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]

总结

通过本文的实战案例,我们学习了如何使用FastAPI构建高性能的实时应用:

核心要点

  1. 连接管理:实现健壮的连接管理器,支持房间、用户映射等功能
  2. 消息处理:设计清晰的消息类型和处理流程
  3. 数据持久化:结合数据库存储聊天记录和用户信息
  4. 性能优化:实现连接限制、速率限制和监控
  5. 错误处理:优雅处理连接断开和异常情况

最佳实践

  1. 分离关注点:将WebSocket逻辑、业务逻辑和数据访问分离
  2. 状态管理:合理管理连接状态和用户状态
  3. 安全考虑:实现认证、授权和速率限制(本文示例直接信任URL路径中的user_id,生产环境应在WebSocket握手阶段校验令牌并检查Origin头)
  4. 监控告警:添加必要的监控指标和日志
  5. 客户端重连:实现客户端自动重连机制

WebSocket开发需要考虑很多细节,但掌握了这些模式和技巧,你就能构建出稳定可靠的实时应用。记住,实时通信的关键在于连接管理和消息处理的设计,这决定了应用的性能和用户体验。

你在WebSocket开发中遇到过什么挑战吗?欢迎在评论中分享讨论!

本站由 提供部署服务