5.6 安全和权限控制
🎯 学习目标:构建企业级的安全和权限控制系统,保护MCP服务器和数据安全
⏱️ 预计时间:60分钟
📊 难度等级:⭐⭐⭐⭐⭐
🔒 安全威胁分析
🚨 常见安全风险
🛡️ 身份认证系统
🔐 多因素认证框架
python
"""
authentication.py - 身份认证系统
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Union, Tuple
from dataclasses import dataclass, field
from enum import Enum
import asyncio
import hashlib
import hmac
import secrets
import jwt
import pyotp
from datetime import datetime, timedelta
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
from loguru import logger
class AuthMethod(Enum):
    """Authentication mechanisms a credential can be presented with.

    Each member's value is the string identifier of the mechanism.
    """

    PASSWORD = "password"
    API_KEY = "api_key"
    JWT_TOKEN = "jwt_token"
    OAUTH2 = "oauth2"
    # Backward-compatible alias: the member was originally spelled ``OAuth2``,
    # unlike every other member. Enum turns a repeated value into an alias, so
    # ``AuthMethod.OAuth2`` keeps working and resolves to ``OAUTH2``.
    OAuth2 = "oauth2"
    CERTIFICATE = "certificate"
    BIOMETRIC = "biometric"
    MFA = "mfa"
class AuthStatus(Enum):
    """Possible outcomes of an authentication or session check."""

    SUCCESS = "success"          # credentials or session accepted
    FAILED = "failed"            # credentials or session rejected
    EXPIRED = "expired"          # token or session is past its expiry
    LOCKED = "locked"            # account temporarily locked
    PENDING_MFA = "pending_mfa"  # first factor accepted, second factor still required
    REVOKED = "revoked"          # token or session was explicitly revoked
@dataclass
class AuthCredentials:
    """Credentials submitted for a single authentication attempt."""

    # Mechanism these credentials are meant for.
    method: AuthMethod
    # Username, e-mail address, certificate fingerprint, etc.
    identifier: str
    # Password, key, token, etc. Kept out of the generated repr() so that
    # logging or printing a credentials object never exposes the secret.
    secret: str = field(repr=False)
    # Free-form extra data supplied with the attempt.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Optional moment after which the credentials are no longer valid.
    expires_at: Optional[datetime] = None
@dataclass
class AuthResult:
    """Outcome of an authentication, MFA or session-validation call."""

    # Overall verdict; the only required field.
    status: AuthStatus
    # Identity and session the verdict refers to, when known.
    user_id: Optional[str] = None
    session_id: Optional[str] = None
    # Permission names attached to the authenticated identity.
    permissions: List[str] = field(default_factory=list)
    # When the session or token stops being valid, if it expires at all.
    expires_at: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Human-readable explanation for non-success results.
    error_message: Optional[str] = None
    # Second-factor information, filled in when a further step is required.
    mfa_required: bool = False
    mfa_methods: List[str] = field(default_factory=list)
class AuthProvider(ABC):
    """Interface that every authentication backend implements."""

    @abstractmethod
    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Check ``credentials`` and report the outcome as an ``AuthResult``."""

    @abstractmethod
    async def validate_session(self, session_id: str) -> AuthResult:
        """Report whether the session identified by ``session_id`` is valid."""

    @abstractmethod
    async def revoke_session(self, session_id: str) -> bool:
        """Revoke the session identified by ``session_id``; return a success flag."""
class PasswordAuthProvider(AuthProvider):
    """Password authentication with failed-attempt lockout and in-memory sessions.

    Failed attempts are counted per identifier within ``attempt_window``; once
    ``max_attempts`` is reached the identifier is locked for
    ``lockout_duration``. Sessions issued by this provider are tracked in
    ``self.sessions`` so that ``validate_session`` and ``revoke_session``
    reflect real state instead of accepting any session id.
    """

    def __init__(self):
        self.failed_attempts: Dict[str, List[datetime]] = {}
        self.locked_accounts: Dict[str, datetime] = {}
        self.max_attempts = 5
        self.lockout_duration = timedelta(minutes=30)
        self.attempt_window = timedelta(minutes=15)
        # Lifetime of a session issued after a successful login.
        self.session_lifetime = timedelta(hours=24)
        # session_id -> {"user_id", "permissions", "expires_at"}.
        # In-memory only; a real deployment should use a database or cache.
        self.sessions: Dict[str, Dict[str, Any]] = {}

    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Authenticate with an identifier/password pair.

        Returns:
            ``SUCCESS`` with a new session, ``PENDING_MFA`` when the user has
            MFA enabled, ``LOCKED`` while the identifier is locked out, or
            ``FAILED`` for wrong credentials, an unsupported method, or an
            internal error.
        """
        if credentials.method != AuthMethod.PASSWORD:
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="不支持的认证方式"
            )

        identifier = credentials.identifier
        password = credentials.secret

        # Refuse early while the identifier is locked out.
        if await self._is_account_locked(identifier):
            return AuthResult(
                status=AuthStatus.LOCKED,
                error_message="账户已被锁定,请稍后再试"
            )

        try:
            user_info = await self._get_user_info(identifier)
            # Unknown user and wrong password produce the same message so the
            # response does not reveal which identifiers exist.
            if not user_info:
                await self._record_failed_attempt(identifier)
                return AuthResult(
                    status=AuthStatus.FAILED,
                    error_message="用户名或密码错误"
                )

            stored_hash = user_info["password_hash"]
            salt = user_info["salt"]
            if not await self._verify_password(password, stored_hash, salt):
                await self._record_failed_attempt(identifier)
                return AuthResult(
                    status=AuthStatus.FAILED,
                    error_message="用户名或密码错误"
                )

            # Password accepted: forget earlier failures.
            self.failed_attempts.pop(identifier, None)

            # No session is issued until the second factor has been verified.
            if user_info.get("mfa_enabled"):
                return AuthResult(
                    status=AuthStatus.PENDING_MFA,
                    user_id=user_info["user_id"],
                    mfa_required=True,
                    mfa_methods=user_info.get("mfa_methods", ["totp"])
                )

            permissions = user_info.get("permissions", [])
            session_id = await self._create_session(user_info["user_id"], permissions)
            return AuthResult(
                status=AuthStatus.SUCCESS,
                user_id=user_info["user_id"],
                session_id=session_id,
                permissions=permissions,
                expires_at=self.sessions[session_id]["expires_at"]
            )
        except Exception as e:
            logger.error(f"密码认证失败: {e}")
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="认证服务异常"
            )

    async def _is_account_locked(self, identifier: str) -> bool:
        """Return True while ``identifier`` is inside its lockout period."""
        lock_time = self.locked_accounts.get(identifier)
        if lock_time is None:
            return False
        if datetime.now() - lock_time < self.lockout_duration:
            return True
        # Lockout period is over: unlock the account.
        del self.locked_accounts[identifier]
        return False

    async def _record_failed_attempt(self, identifier: str):
        """Record a failed attempt and lock the identifier once the limit is hit."""
        now = datetime.now()
        # Keep only attempts that are still inside the counting window.
        recent = [
            attempt for attempt in self.failed_attempts.get(identifier, [])
            if now - attempt < self.attempt_window
        ]
        recent.append(now)
        self.failed_attempts[identifier] = recent

        if len(recent) >= self.max_attempts:
            self.locked_accounts[identifier] = now
            logger.warning(f"账户因多次登录失败被锁定: {identifier}")

    async def _get_user_info(self, identifier: str) -> Optional[Dict[str, Any]]:
        """Look up the user record for ``identifier``.

        Example implementation backed by a hard-coded dict; a real
        implementation should query a database or another user store.
        """
        users_db = {
            "admin": {
                "user_id": "admin",
                "password_hash": "hashed_password",
                "salt": "random_salt",
                "permissions": ["admin", "read", "write"],
                "mfa_enabled": True,
                "mfa_methods": ["totp", "sms"]
            }
        }
        return users_db.get(identifier)

    async def _verify_password(self, password: str, stored_hash: str, salt: str) -> bool:
        """Derive a PBKDF2-HMAC-SHA256 key from ``password`` and compare it.

        The derived key is url-safe base64 encoded and compared with
        ``stored_hash`` in constant time.
        """
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt.encode(),
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
        return hmac.compare_digest(key.decode(), stored_hash)

    async def _create_session(self, user_id: str,
                              permissions: Optional[List[str]] = None) -> str:
        """Create and remember a session for ``user_id``; return its id."""
        session_id = secrets.token_urlsafe(32)
        self.sessions[session_id] = {
            "user_id": user_id,
            "permissions": list(permissions or []),
            "expires_at": datetime.now() + self.session_lifetime,
        }
        return session_id

    async def validate_session(self, session_id: str) -> AuthResult:
        """Validate a session previously issued by this provider.

        Unknown sessions are rejected and expired sessions are removed.
        """
        session = self.sessions.get(session_id)
        if session is None:
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="会话不存在"
            )
        if datetime.now() > session["expires_at"]:
            del self.sessions[session_id]
            return AuthResult(
                status=AuthStatus.EXPIRED,
                error_message="会话已过期"
            )
        return AuthResult(
            status=AuthStatus.SUCCESS,
            user_id=session["user_id"],
            session_id=session_id,
            permissions=list(session["permissions"]),
            expires_at=session["expires_at"]
        )

    async def revoke_session(self, session_id: str) -> bool:
        """Delete the session; return True only if it existed."""
        return self.sessions.pop(session_id, None) is not None
class JWTAuthProvider(AuthProvider):
    """JWT bearer-token authentication with an in-memory revocation list.

    Tokens are signed and verified with ``secret_key`` using ``algorithm``.
    Revoked session ids are kept in ``self.revoked_sessions``; the set lives
    in process memory only, so multi-process deployments need a shared store.
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        # An empty signing key would let anyone forge valid tokens.
        if not secret_key:
            raise ValueError("secret_key must not be empty")
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.token_expiry = timedelta(hours=24)
        self.revoked_sessions: set = set()

    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Verify the JWT carried in ``credentials.secret``.

        Returns:
            ``SUCCESS`` with the claims from the token, ``EXPIRED`` for an
            expired token, ``REVOKED`` when its session id was revoked, and
            ``FAILED`` for an unsupported method or an invalid token.
        """
        if credentials.method != AuthMethod.JWT_TOKEN:
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="不支持的认证方式"
            )

        token = credentials.secret
        try:
            # Signature and ``exp`` are checked by jwt.decode. Requiring the
            # ``exp`` claim rejects tokens that would otherwise never expire.
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"require": ["exp"]}
            )
        except jwt.ExpiredSignatureError:
            return AuthResult(
                status=AuthStatus.EXPIRED,
                error_message="令牌已过期"
            )
        except jwt.InvalidTokenError as e:
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message=f"无效令牌: {str(e)}"
            )

        session_id = payload.get("session_id")
        if session_id in self.revoked_sessions:
            return AuthResult(
                status=AuthStatus.REVOKED,
                error_message="令牌已被撤销"
            )

        exp = payload.get("exp")
        return AuthResult(
            status=AuthStatus.SUCCESS,
            user_id=payload.get("user_id"),
            session_id=session_id,
            permissions=payload.get("permissions", []),
            expires_at=datetime.fromtimestamp(exp) if exp else None
        )

    async def create_token(self, user_id: str, permissions: List[str]) -> str:
        """Create a signed token for ``user_id`` valid for ``token_expiry``.

        The token carries a random ``session_id`` claim that can later be
        passed to ``revoke_session``.
        """
        now = datetime.now()
        payload = {
            "user_id": user_id,
            "permissions": permissions,
            "iat": now.timestamp(),
            "exp": (now + self.token_expiry).timestamp(),
            "session_id": secrets.token_urlsafe(16)
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    async def validate_session(self, session_id: str) -> AuthResult:
        """Check ``session_id`` against the revocation list.

        JWTs are stateless, so the only server-side state is whether the
        session id has been revoked.
        """
        if session_id in self.revoked_sessions:
            return AuthResult(
                status=AuthStatus.REVOKED,
                error_message="令牌已被撤销"
            )
        return AuthResult(status=AuthStatus.SUCCESS)

    async def revoke_session(self, session_id: str) -> bool:
        """Add ``session_id`` to the revocation list; always returns True."""
        self.revoked_sessions.add(session_id)
        return True
class MFAProvider:
    """Second-factor helper supporting TOTP and SMS one-time codes.

    Pending SMS codes are kept in memory in ``pending_verifications``, keyed
    by user id; each user has at most one pending code at a time.
    """

    def __init__(self):
        self.pending_verifications: Dict[str, Dict[str, Any]] = {}
        # Wrong guesses allowed per SMS code before it is discarded. Without a
        # limit a six-digit code can be brute-forced inside its validity window.
        self.max_sms_attempts = 5

    async def setup_totp(self, user_id: str) -> Tuple[str, str]:
        """Generate a new TOTP secret for ``user_id``.

        Returns:
            ``(secret, provisioning_uri)``; the URI can be rendered as a QR
            code for authenticator apps.
        """
        secret = pyotp.random_base32()
        totp = pyotp.TOTP(secret)
        qr_uri = totp.provisioning_uri(
            name=user_id,
            issuer_name="MCP Server"
        )
        return secret, qr_uri

    async def verify_totp(self, user_id: str, secret: str, token: str) -> bool:
        """Verify a TOTP ``token`` against ``secret`` with ``valid_window=1``."""
        totp = pyotp.TOTP(secret)
        return totp.verify(token, valid_window=1)

    async def send_sms_code(self, user_id: str, phone_number: str) -> str:
        """Create a six-digit code for ``user_id`` valid for five minutes.

        Any earlier pending code for the same user is replaced. The code is
        returned so the caller can hand it to an SMS gateway; this method
        does not send anything itself.
        """
        code_str = f"{secrets.randbelow(1000000):06d}"

        self.pending_verifications[user_id] = {
            "type": "sms",
            "code": code_str,
            "phone": phone_number,
            "expires_at": datetime.now() + timedelta(minutes=5),
            "attempts": 0
        }

        # An SMS service should be called here to deliver the code.
        # The code itself must never be logged, and the phone number is masked
        # so that logs do not become a source of credentials or personal data.
        masked_phone = "*" * max(len(phone_number) - 4, 0) + phone_number[-4:]
        logger.info(f"已向 {masked_phone} 发送验证码")
        return code_str

    async def verify_sms_code(self, user_id: str, code: str) -> bool:
        """Check ``code`` against the pending SMS code of ``user_id``.

        The pending code is consumed on success, on expiry, and after
        ``max_sms_attempts`` wrong guesses.
        """
        verification = self.pending_verifications.get(user_id)
        if not verification or verification["type"] != "sms":
            return False

        if datetime.now() > verification["expires_at"]:
            del self.pending_verifications[user_id]
            return False

        # Compare as bytes in constant time; bytes also accept non-ASCII input,
        # which compare_digest rejects for str arguments.
        expected = str(verification["code"]).encode()
        supplied = str(code).encode()
        if not hmac.compare_digest(expected, supplied):
            attempts = verification.get("attempts", 0) + 1
            verification["attempts"] = attempts
            if attempts >= self.max_sms_attempts:
                del self.pending_verifications[user_id]
            return False

        # Codes are single-use.
        del self.pending_verifications[user_id]
        return True
class AuthenticationManager:
    """Front door for authentication: dispatches to providers and tracks sessions.

    Sessions are held in memory in ``active_sessions``. Users whose first
    factor succeeded but who still owe a second factor are tracked in
    ``pending_mfa`` until they verify or the entry times out.
    """

    def __init__(self):
        self.providers: Dict[AuthMethod, AuthProvider] = {}
        self.mfa_provider = MFAProvider()
        self.active_sessions: Dict[str, Dict[str, Any]] = {}
        # user_id -> {"permissions", "expires_at"} for logins awaiting MFA.
        self.pending_mfa: Dict[str, Dict[str, Any]] = {}
        # How long a user has to complete the second factor.
        self.mfa_timeout = timedelta(minutes=5)

    def register_provider(self, method: AuthMethod, provider: AuthProvider):
        """Register ``provider`` as the handler for ``method``."""
        self.providers[method] = provider
        logger.info(f"已注册认证提供者: {method.value}")

    @staticmethod
    def _is_expired(record: Dict[str, Any], now: datetime) -> bool:
        """Return True if ``record`` has an expiry that lies before ``now``.

        A missing or ``None`` expiry means the record does not expire.
        """
        expires_at = record.get("expires_at")
        return expires_at is not None and now > expires_at

    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Authenticate via the provider registered for ``credentials.method``.

        Successful results that carry a session id are recorded in
        ``active_sessions``; ``PENDING_MFA`` results are remembered so that
        ``verify_mfa`` can finish the login.
        """
        provider = self.providers.get(credentials.method)
        if not provider:
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message=f"不支持的认证方式: {credentials.method.value}"
            )

        try:
            result = await provider.authenticate(credentials)
            now = datetime.now()

            if result.status == AuthStatus.SUCCESS and result.session_id:
                self.active_sessions[result.session_id] = {
                    "user_id": result.user_id,
                    "permissions": result.permissions,
                    "created_at": now,
                    "expires_at": result.expires_at,
                    "last_accessed": now
                }
            elif result.status == AuthStatus.PENDING_MFA and result.user_id:
                # Remember that the first factor succeeded; verify_mfa refuses
                # to issue a session without this record.
                self.pending_mfa[result.user_id] = {
                    "permissions": list(result.permissions),
                    "expires_at": now + self.mfa_timeout
                }
            return result
        except Exception as e:
            logger.error(f"认证过程中发生错误: {e}")
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="认证服务异常"
            )

    async def verify_mfa(self,
                         user_id: str,
                         method: str,
                         token: str,
                         secret: Optional[str] = None) -> AuthResult:
        """Verify the second factor and, on success, create the session.

        Args:
            user_id: User completing the login.
            method: ``"totp"`` (requires ``secret``) or ``"sms"``.
            token: Code entered by the user.
            secret: TOTP secret of the user; only used for ``"totp"``.

        Returns:
            ``SUCCESS`` with a new session, or ``FAILED`` when there is no
            pending first-factor login, the method is unsupported, the code
            is wrong, or an internal error occurs.
        """
        try:
            # A second factor alone must never be enough to obtain a session.
            pending = self.pending_mfa.get(user_id)
            if pending is None or self._is_expired(pending, datetime.now()):
                self.pending_mfa.pop(user_id, None)
                return AuthResult(
                    status=AuthStatus.FAILED,
                    error_message="没有待处理的MFA验证"
                )

            if method == "totp" and secret:
                verified = await self.mfa_provider.verify_totp(user_id, secret, token)
            elif method == "sms":
                verified = await self.mfa_provider.verify_sms_code(user_id, token)
            else:
                return AuthResult(
                    status=AuthStatus.FAILED,
                    error_message="不支持的MFA方式"
                )

            if not verified:
                return AuthResult(
                    status=AuthStatus.FAILED,
                    error_message="MFA验证失败"
                )

            # Login complete: consume the pending record and open the session
            # with the permissions reported by the first-factor provider.
            del self.pending_mfa[user_id]
            permissions = list(pending["permissions"])
            session_id = secrets.token_urlsafe(32)
            now = datetime.now()
            expires_at = now + timedelta(hours=24)
            self.active_sessions[session_id] = {
                "user_id": user_id,
                "permissions": permissions,
                "created_at": now,
                "expires_at": expires_at,
                "last_accessed": now,
                "mfa_verified": True
            }
            return AuthResult(
                status=AuthStatus.SUCCESS,
                user_id=user_id,
                session_id=session_id,
                permissions=permissions,
                expires_at=expires_at
            )
        except Exception as e:
            logger.error(f"MFA验证过程中发生错误: {e}")
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="MFA验证服务异常"
            )

    async def validate_session(self, session_id: str) -> AuthResult:
        """Validate a tracked session and refresh its last-access time."""
        session = self.active_sessions.get(session_id)
        if not session:
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="会话不存在"
            )

        now = datetime.now()
        # ``expires_at`` may be None (providers are allowed to return no
        # expiry); such sessions are treated as non-expiring.
        if self._is_expired(session, now):
            del self.active_sessions[session_id]
            return AuthResult(
                status=AuthStatus.EXPIRED,
                error_message="会话已过期"
            )

        session["last_accessed"] = now
        return AuthResult(
            status=AuthStatus.SUCCESS,
            user_id=session["user_id"],
            session_id=session_id,
            permissions=session["permissions"],
            expires_at=session["expires_at"]
        )

    async def revoke_session(self, session_id: str) -> bool:
        """Remove a tracked session; return True only if it existed."""
        return self.active_sessions.pop(session_id, None) is not None

    async def cleanup_expired_sessions(self):
        """Drop expired sessions and timed-out pending MFA logins."""
        now = datetime.now()
        expired_sessions = [
            session_id for session_id, session in self.active_sessions.items()
            if self._is_expired(session, now)
        ]
        for session_id in expired_sessions:
            del self.active_sessions[session_id]

        stale_logins = [
            user_id for user_id, pending in self.pending_mfa.items()
            if self._is_expired(pending, now)
        ]
        for user_id in stale_logins:
            del self.pending_mfa[user_id]

        if expired_sessions:
            logger.info(f"清理了 {len(expired_sessions)} 个过期会话")
🔐 权限控制系统
🏛️ RBAC权限模型
python
"""
authorization.py - 权限控制系统
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
import asyncio
import json
from datetime import datetime, timedelta
from loguru import logger
class PermissionType(Enum):
    """Kinds of action a permission can allow on a resource."""

    READ = "read"
    WRITE = "write"
    DELETE = "delete"
    EXECUTE = "execute"
    ADMIN = "admin"
    CREATE = "create"
    UPDATE = "update"
    MANAGE = "manage"
class ResourceType(Enum):
    """Categories of resource that permissions can be scoped to."""

    STREAM = "stream"
    FILE = "file"
    DATABASE = "database"
    API = "api"
    TEMPLATE = "template"
    USER = "user"
    ROLE = "role"
    SYSTEM = "system"
@dataclass
class Permission:
    """A grant of one action type on resources of one type.

    ``resource_pattern`` is a shell-style wildcard pattern (``*``, ``?``,
    ``[seq]``) matched against resource ids; the default ``"*"`` matches all.
    """

    name: str
    resource_type: ResourceType
    permission_type: PermissionType
    resource_pattern: str = "*"
    conditions: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def matches_resource(self, resource_id: str) -> bool:
        """Return True if ``resource_id`` matches ``resource_pattern``.

        Matching is always case-sensitive. ``fnmatch.fnmatch`` would apply
        ``os.path.normcase`` first, which makes results differ by operating
        system; an authorization decision must not depend on the platform.
        """
        import fnmatch
        return fnmatch.fnmatchcase(resource_id, self.resource_pattern)
@dataclass
class Role:
    """A named bundle of permissions that can be assigned to users."""

    name: str
    description: str
    # Permissions granted directly by this role.
    permissions: List[Permission] = field(default_factory=list)
    # Names of roles whose permissions this role inherits.
    inherits_from: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    is_active: bool = True
    created_at: datetime = field(default_factory=datetime.now)
@dataclass
class User:
    """An account that permissions are evaluated for."""

    user_id: str
    username: str
    email: str
    # Names of the roles assigned to the user.
    roles: List[str] = field(default_factory=list)
    # Permissions granted to the user outside of any role.
    direct_permissions: List[Permission] = field(default_factory=list)
    # Arbitrary attributes consulted by attribute-based (ABAC) rules.
    attributes: Dict[str, Any] = field(default_factory=dict)
    is_active: bool = True
    created_at: datetime = field(default_factory=datetime.now)
    last_login: Optional[datetime] = None
@dataclass
class AuthorizationContext:
    """Everything needed to decide a single access request."""

    # Who is asking.
    user_id: str
    # What is being accessed.
    resource_type: ResourceType
    resource_id: str
    # Which action is requested.
    permission_type: PermissionType
    # Request details such as "client_ip" and "user_agent".
    request_context: Dict[str, Any] = field(default_factory=dict)
    # Time of the request; defaults to the moment the context is created.
    timestamp: datetime = field(default_factory=datetime.now)
@dataclass
class AuthorizationResult:
    """Verdict of an authorization check."""

    # True when access is allowed.
    granted: bool
    # Human-readable explanation of the verdict.
    reason: str
    # Permissions that justified a grant, when the policy reports them.
    matched_permissions: List[Permission] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
class AuthorizationPolicy(ABC):
    """Interface of a pluggable authorization strategy."""

    @abstractmethod
    async def evaluate(self,
                       context: AuthorizationContext,
                       user: User,
                       permissions: List[Permission]) -> AuthorizationResult:
        """Decide whether ``user`` may perform the access described by ``context``.

        ``permissions`` is the list of permissions collected for the user.
        """
class RBACPolicy(AuthorizationPolicy):
    """Role-based policy: grants when any collected permission fits the request."""

    async def evaluate(self,
                       context: AuthorizationContext,
                       user: User,
                       permissions: List[Permission]) -> AuthorizationResult:
        """Grant access if a permission matches resource type, action and id."""

        def applies(candidate: Permission) -> bool:
            # A permission applies when all three aspects of the request match.
            return (candidate.resource_type == context.resource_type
                    and candidate.permission_type == context.permission_type
                    and candidate.matches_resource(context.resource_id))

        hits = [candidate for candidate in permissions if applies(candidate)]
        if not hits:
            return AuthorizationResult(
                granted=False,
                reason="没有匹配的RBAC权限"
            )
        return AuthorizationResult(
            granted=True,
            reason="RBAC权限匹配",
            matched_permissions=hits
        )
class ABACPolicy(AuthorizationPolicy):
    """Attribute-based policy driven by a list of rule dictionaries.

    A rule has an optional ``"name"`` and a ``"conditions"`` dict with the
    optional sections ``"user"``, ``"resource"`` and ``"environment"``. A rule
    matches when every condition in every section holds; access is granted as
    soon as one rule matches. Conditions that cannot be evaluated (unknown
    keys, unknown operators, incomparable or malformed values) count as not
    satisfied, so a malformed rule can only deny, never grant.
    """

    def __init__(self):
        self.policy_rules: List[Dict[str, Any]] = []

    def add_rule(self, rule: Dict[str, Any]):
        """Append ``rule`` to the list of rules to evaluate."""
        self.policy_rules.append(rule)

    async def evaluate(self,
                       context: AuthorizationContext,
                       user: User,
                       permissions: List[Permission]) -> AuthorizationResult:
        """Grant access if any registered rule matches; ``permissions`` is unused."""
        for rule in self.policy_rules:
            if await self._evaluate_rule(rule, context, user):
                return AuthorizationResult(
                    granted=True,
                    reason=f"ABAC规则匹配: {rule.get('name', 'unnamed')}"
                )
        return AuthorizationResult(
            granted=False,
            reason="没有匹配的ABAC规则"
        )

    async def _evaluate_rule(self,
                             rule: Dict[str, Any],
                             context: AuthorizationContext,
                             user: User) -> bool:
        """Return True only if every condition of ``rule`` is satisfied."""
        conditions = rule.get("conditions", {})

        # User attribute conditions.
        for attr, expected_value in conditions.get("user", {}).items():
            if not self._match_condition(user.attributes.get(attr), expected_value):
                return False

        # Resource conditions: only "type" and "id" are supported.
        for attr, expected_value in conditions.get("resource", {}).items():
            if attr == "type":
                if context.resource_type.value != expected_value:
                    return False
            elif attr == "id":
                if not self._match_pattern(context.resource_id, expected_value):
                    return False
            else:
                # Fail closed: silently ignoring an unknown (for example
                # misspelled) condition would make the rule broader than written.
                logger.warning(f"ABAC规则包含不支持的资源条件: {attr}")
                return False

        # Environment conditions: only "time_range" and "ip_range" are supported.
        for attr, expected_value in conditions.get("environment", {}).items():
            if attr == "time_range":
                if not self._in_time_range(context.timestamp, expected_value):
                    return False
            elif attr == "ip_range":
                client_ip = context.request_context.get("client_ip")
                if not self._in_ip_range(client_ip, expected_value):
                    return False
            else:
                logger.warning(f"ABAC规则包含不支持的环境条件: {attr}")
                return False

        return True

    def _match_condition(self, actual_value: Any, expected_value: Any) -> bool:
        """Compare an attribute value with a condition.

        ``expected_value`` may be a list (membership test), a dict of the form
        ``{"op": ..., "value": ...}`` with op in eq/ne/gt/ge/lt/le/in/contains,
        or any other value (equality test).
        """
        if isinstance(expected_value, list):
            return actual_value in expected_value
        if not isinstance(expected_value, dict):
            return actual_value == expected_value

        op_name = expected_value.get("op")
        value = expected_value.get("value")
        comparators = {
            "eq": lambda actual, ref: actual == ref,
            "ne": lambda actual, ref: actual != ref,
            "gt": lambda actual, ref: actual > ref,
            "ge": lambda actual, ref: actual >= ref,
            "lt": lambda actual, ref: actual < ref,
            "le": lambda actual, ref: actual <= ref,
            "in": lambda actual, ref: actual in ref,
            # A missing attribute contains nothing; without this guard
            # str(None) would be searched as the text "None".
            "contains": lambda actual, ref: actual is not None and ref in str(actual),
        }
        compare = comparators.get(op_name)
        if compare is None:
            # Unknown operator: condition not satisfied.
            return False
        try:
            return bool(compare(actual_value, value))
        except TypeError:
            # Incomparable operands, e.g. a missing attribute (None) vs a number.
            return False

    def _match_pattern(self, value: str, pattern: str) -> bool:
        """Case-sensitive shell-style wildcard match, identical on all platforms."""
        import fnmatch
        return fnmatch.fnmatchcase(value, pattern)

    def _in_time_range(self, timestamp: datetime, time_range: Dict[str, str]) -> bool:
        """Return True if ``timestamp`` lies within the ISO-8601 ``start``/``end`` bounds.

        Missing bounds default to 1970-01-01 and 2099-12-31. Bounds that cannot
        be parsed or compared with ``timestamp`` (for example timezone-aware vs
        naive values) make the condition fail.
        """
        try:
            start_time = datetime.fromisoformat(time_range.get("start", "1970-01-01T00:00:00"))
            end_time = datetime.fromisoformat(time_range.get("end", "2099-12-31T23:59:59"))
            return start_time <= timestamp <= end_time
        except (AttributeError, TypeError, ValueError):
            return False

    def _in_ip_range(self, ip: Optional[str], ip_range: str) -> bool:
        """Return True if ``ip`` belongs to the network ``ip_range`` (CIDR).

        Works for IPv4 and IPv6; an address never belongs to a network of the
        other IP version. A missing or unparsable address or network yields
        False.
        """
        if not ip:
            return False
        import ipaddress
        try:
            network = ipaddress.ip_network(ip_range, strict=False)
            return ipaddress.ip_address(ip) in network
        except (TypeError, ValueError):
            return False
class PermissionManager:
    """In-memory registry of users and roles plus the policy evaluation pipeline.

    Effective permissions per user (direct plus role-derived, including
    inherited roles) are cached for ``cache_ttl``. The cache is invalidated
    whenever a user or role definition changes, so revoked permissions stop
    applying immediately rather than after the TTL.
    """

    def __init__(self):
        self.users: Dict[str, User] = {}
        self.roles: Dict[str, Role] = {}
        self.policies: List[AuthorizationPolicy] = []
        # "permissions:<user_id>" -> {"permissions": [...], "timestamp": datetime}
        self.permission_cache: Dict[str, Dict[str, Any]] = {}
        self.cache_ttl = timedelta(minutes=30)

    def add_policy(self, policy: AuthorizationPolicy):
        """Append ``policy``; policies are evaluated in insertion order."""
        self.policies.append(policy)

    async def create_user(self, user: User):
        """Register ``user``, replacing any existing user with the same id."""
        self.users[user.user_id] = user
        # A replaced user may carry different roles or direct permissions.
        self._invalidate_user_cache(user.user_id)
        logger.info(f"创建用户: {user.username}")

    async def create_role(self, role: Role):
        """Register ``role``, replacing any existing role with the same name."""
        self.roles[role.name] = role
        # Any user may hold this role directly or through inheritance, so all
        # cached permission sets are potentially stale.
        self.permission_cache.clear()
        logger.info(f"创建角色: {role.name}")

    async def assign_role_to_user(self, user_id: str, role_name: str):
        """Assign an existing role to an existing user.

        Raises:
            ValueError: If the user or the role does not exist.
        """
        user = self.users.get(user_id)
        if not user:
            raise ValueError(f"用户不存在: {user_id}")
        if role_name not in self.roles:
            raise ValueError(f"角色不存在: {role_name}")

        if role_name not in user.roles:
            user.roles.append(role_name)
            self._invalidate_user_cache(user_id)
            logger.info(f"为用户 {user_id} 分配角色 {role_name}")

    async def revoke_role_from_user(self, user_id: str, role_name: str):
        """Remove a role from a user; does nothing if either is missing."""
        user = self.users.get(user_id)
        if user and role_name in user.roles:
            user.roles.remove(role_name)
            self._invalidate_user_cache(user_id)
            logger.info(f"撤销用户 {user_id} 的角色 {role_name}")

    async def get_user_permissions(self, user_id: str) -> List[Permission]:
        """Return the effective permissions of ``user_id``.

        The result combines direct permissions with those of all active roles,
        following role inheritance. An unknown user yields an empty list. A new
        list is returned on every call, so callers may modify it freely.
        """
        cache_key = f"permissions:{user_id}"
        cached = self.permission_cache.get(cache_key)
        if cached and datetime.now() - cached["timestamp"] < self.cache_ttl:
            return list(cached["permissions"])

        user = self.users.get(user_id)
        if not user:
            return []

        permissions: List[Permission] = list(user.direct_permissions)
        visited_roles: Set[str] = set()
        await self._collect_role_permissions(user.roles, permissions, visited_roles)

        # Store a private copy so later changes to the returned list cannot
        # alter what is cached.
        self.permission_cache[cache_key] = {
            "permissions": list(permissions),
            "timestamp": datetime.now()
        }
        return permissions

    async def _collect_role_permissions(self,
                                        role_names: List[str],
                                        permissions: List[Permission],
                                        visited_roles: Set[str]):
        """Append the permissions of ``role_names`` and their ancestors.

        ``visited_roles`` guards against cycles in the inheritance graph.
        Unknown and inactive roles contribute nothing, and their parents are
        not followed.
        """
        for role_name in role_names:
            if role_name in visited_roles:
                continue
            visited_roles.add(role_name)

            role = self.roles.get(role_name)
            if not role or not role.is_active:
                continue

            permissions.extend(role.permissions)
            if role.inherits_from:
                await self._collect_role_permissions(
                    role.inherits_from,
                    permissions,
                    visited_roles
                )

    async def check_permission(self, context: AuthorizationContext) -> AuthorizationResult:
        """Evaluate ``context`` against all policies; the first grant wins.

        Every decision, including denials for unknown or disabled users, is
        written to the access log.
        """
        user = self.users.get(context.user_id)
        if not user:
            result = AuthorizationResult(
                granted=False,
                reason="用户不存在"
            )
            await self._log_access(context, result, False)
            return result

        if not user.is_active:
            result = AuthorizationResult(
                granted=False,
                reason="用户已被禁用"
            )
            await self._log_access(context, result, False)
            return result

        permissions = await self.get_user_permissions(context.user_id)

        for policy in self.policies:
            result = await policy.evaluate(context, user, permissions)
            if result.granted:
                await self._log_access(context, result, True)
                return result

        # No policy granted access (also the case when no policy is registered).
        result = AuthorizationResult(
            granted=False,
            reason="所有授权策略都拒绝访问"
        )
        await self._log_access(context, result, False)
        return result

    async def _log_access(self,
                          context: AuthorizationContext,
                          result: AuthorizationResult,
                          granted: bool):
        """Write one JSON access-log line; grants at INFO, denials at WARNING."""
        log_entry = {
            "timestamp": context.timestamp.isoformat(),
            "user_id": context.user_id,
            "resource_type": context.resource_type.value,
            "resource_id": context.resource_id,
            "permission_type": context.permission_type.value,
            "granted": granted,
            "reason": result.reason,
            "client_ip": context.request_context.get("client_ip"),
            "user_agent": context.request_context.get("user_agent")
        }

        if granted:
            logger.info(f"访问授权: {json.dumps(log_entry, ensure_ascii=False)}")
        else:
            logger.warning(f"访问拒绝: {json.dumps(log_entry, ensure_ascii=False)}")

    def _invalidate_user_cache(self, user_id: str):
        """Drop the cached permission set of ``user_id``, if any."""
        cache_key = f"permissions:{user_id}"
        self.permission_cache.pop(cache_key, None)

    async def cleanup_cache(self):
        """Remove cache entries older than ``cache_ttl``."""
        now = datetime.now()
        expired_keys = [
            key for key, value in self.permission_cache.items()
            if now - value["timestamp"] > self.cache_ttl
        ]
        for key in expired_keys:
            del self.permission_cache[key]

        if expired_keys:
            logger.debug(f"清理了 {len(expired_keys)} 个过期缓存项")
🔒 数据加密和安全通信
🛡️ 端到端加密
python
"""
encryption.py - 数据加密和安全通信
"""
import asyncio
import base64
import hashlib
import hmac
import json
import secrets
import ssl
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, Any, Optional, Tuple, Union

import aiohttp
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from loguru import logger
class EncryptionType(Enum):
    """Families of encryption scheme."""

    SYMMETRIC = "symmetric"    # one shared key encrypts and decrypts
    ASYMMETRIC = "asymmetric"  # public key encrypts, private key decrypts
    HYBRID = "hybrid"          # combination of both
class CryptoManager:
    """Cryptographic helper: symmetric/asymmetric encryption, signing, hashing.

    A Fernet master key and a 2048-bit RSA key pair are generated in memory
    for every instance; nothing is persisted, so data encrypted by one
    instance cannot be decrypted by another.
    """

    def __init__(self):
        self.key_store: Dict[str, Any] = {}
        self.session_keys: Dict[str, bytes] = {}

        # Server master key used when no explicit symmetric key is given.
        self.master_key = Fernet.generate_key()
        self.fernet = Fernet(self.master_key)

        # RSA key pair used for asymmetric encryption and signatures.
        self.private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048
        )
        self.public_key = self.private_key.public_key()

    def generate_session_key(self, session_id: str) -> bytes:
        """Create, store and return a random 32-byte key for ``session_id``."""
        key = secrets.token_bytes(32)
        self.session_keys[session_id] = key
        logger.debug(f"为会话 {session_id} 生成了会话密钥")
        return key

    def get_session_key(self, session_id: str) -> Optional[bytes]:
        """Return the key stored for ``session_id``, or None if there is none."""
        return self.session_keys.get(session_id)

    def remove_session_key(self, session_id: str) -> bool:
        """Forget the key of ``session_id``; return True if one was stored.

        Call this when a session ends, otherwise ``session_keys`` keeps key
        material for dead sessions indefinitely.
        """
        return self.session_keys.pop(session_id, None) is not None

    def _fernet_for(self, key: Optional[bytes]) -> Fernet:
        """Return the Fernet for ``key``, or the master Fernet when key is None."""
        if key is None:
            return self.fernet
        # Fernet expects a url-safe base64 encoding of exactly 32 key bytes.
        return Fernet(base64.urlsafe_b64encode(key[:32]))

    def encrypt_symmetric(self, data: Union[str, bytes], key: Optional[bytes] = None) -> bytes:
        """Encrypt ``data`` with Fernet; ``str`` input is UTF-8 encoded first.

        Uses the first 32 bytes of ``key``, or the master key when omitted.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')
        return self._fernet_for(key).encrypt(data)

    def decrypt_symmetric(self, encrypted_data: bytes, key: Optional[bytes] = None) -> bytes:
        """Decrypt a Fernet token produced by ``encrypt_symmetric``."""
        return self._fernet_for(key).decrypt(encrypted_data)

    def encrypt_asymmetric(self, data: Union[str, bytes], public_key: Optional[rsa.RSAPublicKey] = None) -> bytes:
        """Encrypt ``data`` with RSA-OAEP (SHA-256).

        Uses ``public_key``, or this instance's own public key when omitted.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')
        if public_key is None:
            public_key = self.public_key

        return public_key.encrypt(
            data,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )

    def decrypt_asymmetric(self, encrypted_data: bytes) -> bytes:
        """Decrypt RSA-OAEP (SHA-256) ciphertext with this instance's private key."""
        return self.private_key.decrypt(
            encrypted_data,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )

    def sign_data(self, data: Union[str, bytes]) -> bytes:
        """Sign ``data`` with RSA-PSS (SHA-256) using this instance's private key."""
        if isinstance(data, str):
            data = data.encode('utf-8')

        return self.private_key.sign(
            data,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )

    def verify_signature(self, data: Union[str, bytes], signature: bytes, public_key: Optional[rsa.RSAPublicKey] = None) -> bool:
        """Return True if ``signature`` is a valid RSA-PSS signature of ``data``.

        Verifies against ``public_key``, or this instance's own public key
        when omitted. Any failure during verification yields False.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')
        if public_key is None:
            public_key = self.public_key

        try:
            public_key.verify(
                signature,
                data,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256()
            )
            return True
        except Exception:
            # Deliberately broad: a verifier must answer "not valid" for any
            # malformed signature input rather than raise.
            return False

    def get_public_key_pem(self) -> str:
        """Return this instance's public key as PEM (SubjectPublicKeyInfo) text."""
        pem = self.public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        return pem.decode('utf-8')

    def hash_data(self, data: Union[str, bytes], algorithm: str = "sha256") -> str:
        """Return the hex digest of ``data`` using sha256, sha512 or md5.

        NOTE: md5 is kept for compatibility (e.g. checksums) only; it is not
        collision resistant and must not be used for security purposes.

        Raises:
            ValueError: If ``algorithm`` is not one of the supported names.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')

        if algorithm not in ("sha256", "sha512", "md5"):
            raise ValueError(f"不支持的哈希算法: {algorithm}")
        return hashlib.new(algorithm, data).hexdigest()

    def generate_hmac(self, data: Union[str, bytes], key: bytes) -> str:
        """Return the HMAC-SHA256 of ``data`` under ``key`` as a hex string."""
        if isinstance(data, str):
            data = data.encode('utf-8')
        return hmac.new(key, data, hashlib.sha256).hexdigest()

    def verify_hmac(self, data: Union[str, bytes], key: bytes, expected_hmac: str) -> bool:
        """Return True if ``expected_hmac`` equals the HMAC-SHA256 of ``data``.

        The comparison is constant-time. ``expected_hmac`` typically comes
        from an untrusted peer, so a non-string or non-ASCII value is simply
        treated as a mismatch instead of raising.
        """
        if not isinstance(expected_hmac, str):
            return False
        calculated_hmac = self.generate_hmac(data, key)
        # compare_digest raises TypeError for non-ASCII str, so compare bytes.
        return hmac.compare_digest(
            calculated_hmac.encode("ascii"),
            expected_hmac.encode("utf-8", "replace")
        )
class SecureTransport:
    """HTTPS client that signs requests and optionally encrypts their bodies."""

    def __init__(self, crypto_manager: CryptoManager):
        self.crypto_manager = crypto_manager
        self.ssl_context = self._create_ssl_context()

    def _create_ssl_context(self) -> ssl.SSLContext:
        """Create a client-side TLS context that verifies certificate and hostname.

        ``ssl.create_default_context`` already enables hostname checking; it is
        intentionally left on, because a certificate that is valid for some
        other host must not be accepted.
        """
        context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        context.verify_mode = ssl.CERT_REQUIRED
        return context

    async def secure_request(self,
                             url: str,
                             method: str = "GET",
                             data: Optional[Dict[str, Any]] = None,
                             headers: Optional[Dict[str, str]] = None,
                             session_id: Optional[str] = None) -> Dict[str, Any]:
        """Send a signed (and optionally encrypted) JSON request.

        Args:
            url: Target URL.
            method: HTTP method.
            data: JSON-serialisable body. When ``session_id`` has a session
                key, the body is replaced by ``{"encrypted": <base64>}`` and
                the ``X-Encrypted: true`` header is set.
            headers: Extra request headers; the passed dict is not modified.
            session_id: Session whose key is used for encryption/decryption.

        Returns:
            The decoded JSON response, decrypted when the server marked it as
            encrypted and a session key is available.

        Raises:
            ValueError: If the response carries an ``X-Signature`` header that
                does not verify.
        """
        # Work on a copy so the caller's dict does not pick up our headers.
        headers = dict(headers) if headers else {}

        # Encrypt the body with the session key when one is available.
        if session_id and data:
            session_key = self.crypto_manager.get_session_key(session_id)
            if session_key:
                encrypted_data = self.crypto_manager.encrypt_symmetric(
                    json.dumps(data),
                    session_key
                )
                data = {"encrypted": base64.b64encode(encrypted_data).decode()}
                headers["X-Encrypted"] = "true"

        # Serialise once and send exactly the bytes that were signed, so the
        # signature cannot drift from the transmitted body.
        request_body = None
        if data is not None:
            request_body = json.dumps(data)
            headers.setdefault("Content-Type", "application/json")
            signature = self.crypto_manager.sign_data(request_body)
            headers["X-Signature"] = base64.b64encode(signature).decode()

        async with aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(ssl=self.ssl_context)
        ) as session:
            async with session.request(
                method,
                url,
                data=request_body,
                headers=headers
            ) as response:
                # Verify the signature on the raw body before trusting its content.
                response_body = await response.text()
                if "X-Signature" in response.headers:
                    signature = base64.b64decode(response.headers["X-Signature"])
                    if not self.crypto_manager.verify_signature(response_body, signature):
                        raise ValueError("响应签名验证失败")

                response_data = await response.json()

                # Decrypt the payload if the server marked it as encrypted.
                if response.headers.get("X-Encrypted") == "true" and session_id:
                    session_key = self.crypto_manager.get_session_key(session_id)
                    if session_key and "encrypted" in response_data:
                        encrypted_data = base64.b64decode(response_data["encrypted"])
                        decrypted_data = self.crypto_manager.decrypt_symmetric(
                            encrypted_data,
                            session_key
                        )
                        response_data = json.loads(decrypted_data.decode())

                return response_data
class SecurityMiddleware:
    """Inbound request pipeline.

    Order of checks: IP block list, rate limit, optional signature check,
    optional decryption, authentication.
    """

    def __init__(self,
                 crypto_manager: CryptoManager,
                 auth_manager: AuthenticationManager,
                 permission_manager: PermissionManager):
        self.crypto_manager = crypto_manager
        self.auth_manager = auth_manager
        self.permission_manager = permission_manager
        # client_ip -> timestamps of requests inside the current window.
        self.rate_limiter = {}
        self.blocked_ips = set()

    async def process_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Run all inbound checks and return the authenticated request.

        The returned dict has ``user_id`` and ``session_id`` set from the
        authentication result.

        Raises:
            PermissionError: Blocked IP, rate limit exceeded, or failed
                authentication.
            ValueError: Invalid signature or undecryptable request.
        """
        client_ip = request.get("client_ip")

        if client_ip in self.blocked_ips:
            raise PermissionError("IP地址已被阻止")

        if not await self._check_rate_limit(client_ip):
            raise PermissionError("请求频率过高")

        # The signature is only checked when the client supplied one.
        if "signature" in request.get("headers", {}):
            await self._verify_request_signature(request)

        if request.get("encrypted"):
            request = await self._decrypt_request(request)

        auth_result = await self._authenticate_request(request)
        if auth_result.status != AuthStatus.SUCCESS:
            raise PermissionError(f"认证失败: {auth_result.error_message}")

        request["user_id"] = auth_result.user_id
        request["session_id"] = auth_result.session_id
        return request

    async def _check_rate_limit(self, client_ip: str) -> bool:
        """Allow at most 60 requests per minute per client IP (sliding window)."""
        now = datetime.now()
        window = timedelta(minutes=1)
        max_requests = 60

        # Keep only the timestamps that are still inside the window.
        recent = [
            timestamp for timestamp in self.rate_limiter.get(client_ip, [])
            if now - timestamp < window
        ]
        self.rate_limiter[client_ip] = recent

        if len(recent) >= max_requests:
            logger.warning(f"IP {client_ip} 触发速率限制")
            return False

        recent.append(now)
        return True

    async def _verify_request_signature(self, request: Dict[str, Any]):
        """Verify the base64 ``signature`` header over the JSON-encoded ``data``.

        Raises:
            ValueError: If the signature does not verify.
        """
        signature_b64 = request["headers"]["signature"]
        signature = base64.b64decode(signature_b64)

        request_body = json.dumps(request.get("data", {}))
        if not self.crypto_manager.verify_signature(request_body, signature):
            raise ValueError("请求签名验证失败")

    async def _decrypt_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Decrypt ``request["encrypted"]`` with the key of ``request["session_id"]``.

        The decrypted JSON object is merged over the original request (minus
        the ``encrypted`` field). Transport-level fields such as ``headers``
        and ``client_ip`` therefore survive decryption unless the payload
        overrides them; returning the payload alone would drop the
        ``authorization`` header and make authentication fail.

        Raises:
            ValueError: Missing session id, unknown session key, or a
                decrypted payload that is not a JSON object.
        """
        session_id = request.get("session_id")
        if not session_id:
            raise ValueError("加密请求需要会话ID")

        session_key = self.crypto_manager.get_session_key(session_id)
        if not session_key:
            raise ValueError("会话密钥不存在")

        encrypted_data = base64.b64decode(request["encrypted"])
        decrypted_data = self.crypto_manager.decrypt_symmetric(
            encrypted_data,
            session_key
        )

        payload = json.loads(decrypted_data.decode())
        if not isinstance(payload, dict):
            raise ValueError("解密后的请求格式无效")

        decrypted_request = {
            key: value for key, value in request.items() if key != "encrypted"
        }
        decrypted_request.update(payload)
        return decrypted_request

    async def _authenticate_request(self, request: Dict[str, Any]) -> AuthResult:
        """Authenticate using the ``authorization`` header.

        Supports ``Bearer <jwt>`` and ``ApiKey <key>``; anything else fails.
        """
        auth_header = request.get("headers", {}).get("authorization")
        if not auth_header:
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="缺少认证头"
            )

        if auth_header.startswith("Bearer "):
            credentials = AuthCredentials(
                method=AuthMethod.JWT_TOKEN,
                identifier="",
                secret=auth_header[len("Bearer "):]
            )
        elif auth_header.startswith("ApiKey "):
            credentials = AuthCredentials(
                method=AuthMethod.API_KEY,
                identifier="",
                secret=auth_header[len("ApiKey "):]
            )
        else:
            return AuthResult(
                status=AuthStatus.FAILED,
                error_message="不支持的认证方式"
            )

        return await self.auth_manager.authenticate(credentials)
🎯 本节小结
通过这一小节,你已经构建了一个企业级的安全和权限控制系统:
✅ 多因素认证:支持密码、JWT、TOTP、短信等认证方式(API密钥、OAuth2 等方式已在 AuthMethod 中预留,需自行实现并注册对应的 AuthProvider)
✅ RBAC权限模型:基于角色的访问控制,支持角色继承
✅ ABAC策略引擎:基于属性的访问控制,支持复杂条件判断
✅ 数据加密保护:对称/非对称加密、数字签名、HMAC验证
✅ 安全传输:SSL/TLS加密通信、请求签名验证
✅ 安全中间件:速率限制、IP黑名单、请求安全检查
🚀 使用示例
python
import asyncio


async def main():
    """Wire the security components together and run one permission check."""
    # Initialise the security system.
    crypto_manager = CryptoManager()
    auth_manager = AuthenticationManager()
    permission_manager = PermissionManager()

    # Register authentication providers.
    auth_manager.register_provider(
        AuthMethod.PASSWORD,
        PasswordAuthProvider()
    )
    # NOTE: load the signing key from configuration in real deployments;
    # a hard-coded key is for demonstration only.
    auth_manager.register_provider(
        AuthMethod.JWT_TOKEN,
        JWTAuthProvider("your-secret-key")
    )

    # Add authorization policies.
    permission_manager.add_policy(RBACPolicy())
    permission_manager.add_policy(ABACPolicy())

    # Create a role and a user holding it.
    await permission_manager.create_role(Role(
        name="admin",
        description="系统管理员",
        permissions=[
            Permission("admin", ResourceType.SYSTEM, PermissionType.ADMIN)
        ]
    ))
    await permission_manager.create_user(User(
        user_id="user1",
        username="admin",
        email="admin@example.com",
        roles=["admin"]
    ))

    # Permission check.
    context = AuthorizationContext(
        user_id="user1",
        resource_type=ResourceType.STREAM,
        resource_id="stream123",
        permission_type=PermissionType.READ
    )
    result = await permission_manager.check_permission(context)
    if result.granted:
        print("访问被授权")
    else:
        print(f"访问被拒绝: {result.reason}")


# ``await`` is only valid inside a coroutine, so the example runs under
# asyncio.run() instead of using top-level awaits.
if __name__ == "__main__":
    asyncio.run(main())
🔒 安全最佳实践
- 最小权限原则:只授予完成任务所需的最小权限
- 深度防御:多层安全控制,不依赖单一安全机制
- 安全审计:记录所有安全相关操作和访问尝试
- 定期更新:及时更新密钥、证书和安全策略
- 异常监控:监控异常访问模式和潜在威胁
现在你的MCP服务器具备了企业级的安全防护能力!