8.7 安全实现指南 🛡️
“安全不是事后想起来的补丁，而是从设计之初就融入血液的基因。”
想象你的MCP服务器是一座城堡，用户的请求是访客。你需要建立严密的安全防线：城门（身份验证）、哨兵（权限控制）、护城河（输入验证）、瞭望塔（监控预警）。每一道防线都不可或缺。
多重身份验证:守护城门 🏰
渐进式认证体系
typescript
// 🌟 多重身份验证系统
import * as crypto from 'crypto';
import * as jwt from 'jsonwebtoken';
import { promisify } from 'util';
/**
 * Everything known about one login attempt. Once all required factors pass,
 * it becomes an active session.
 */
interface AuthenticationContext {
  /** Random session identifier. */
  sessionId: string;
  /** Identifier of the user who is authenticating. */
  userId: string;
  /** Factors verified so far, in the order they were completed. */
  authMethods: AuthMethod[];
  /** Aggregate login risk, clamped to the 0-100 range. */
  riskScore: number;
  /** Epoch milliseconds when the context was created; session age is measured from here. */
  timestamp: number;
  /** Client network address reported at login. */
  ipAddress: string;
  /** Client User-Agent string reported at login. */
  userAgent: string;
}
/** One authentication factor and its verification outcome. */
interface AuthMethod {
  /** Which kind of factor this is. */
  type: 'password' | 'totp' | 'sms' | 'biometric' | 'hardware_key';
  /** Relative strength on a 1-100 scale; strengths are summed across a login's factors. */
  strength: number;
  /** True once the factor has been successfully checked. */
  verified: boolean;
  /** Epoch milliseconds at which the factor was verified. */
  timestamp: number;
}
/** Tunable rules that drive authentication and session handling. */
interface SecurityPolicy {
  // --- factor requirements ---
  /** Whether a second factor is mandatory. */
  requireMFA: boolean;
  /** Minimum number of verified factors per login. */
  minAuthMethods: number;
  /** Failed attempts tolerated before an account is locked. */
  maxFailedAttempts: number;

  // --- session lifecycle ---
  /** Session lifetime in seconds. */
  sessionTimeout: number;
  /** Names of sensitive operations that demand a recent re-authentication. */
  requireReauth: string[];

  // --- network restrictions ---
  /** Optional allow-list of client IP addresses. */
  ipWhitelist?: string[];
  /** Optional country-based restrictions. */
  geoRestrictions?: {
    allowedCountries: string[];
    blockSuspiciousLocations: boolean;
  };
}
/** A login that passed the password step and is waiting for more MFA factors. */
interface PendingMfaAuth {
  /** Context being built up; it becomes the active session once MFA completes. */
  context: AuthenticationContext;
  /** Username from the password step, used for failed-attempt bookkeeping. */
  username: string;
  /** Random id embedded in the most recently issued temp token; older tokens are rejected. */
  nonce: string;
  /** Epoch milliseconds after which this pending login is discarded. */
  expiresAt: number;
  /** Number of wrong MFA codes submitted for this pending login. */
  failures: number;
}

/**
 * Multi-factor authenticator with password, risk-based step-up and in-memory session tracking.
 *
 * Flow:
 * 1. `initiateAuthentication` verifies the password. It then either issues an access
 *    token, or returns a short-lived temp token when more factors are needed.
 * 2. The client calls `verifyMultiFactor` with that temp token until the policy is met.
 *
 * NOTE(review): getUserByUsername, verifyPassword, isAccountLocked, recordFailedAttempt,
 * requiresMultiFactorAuth, getAvailableAuthMethods, verifyMFACode, getMethodStrength,
 * getRequiredAuthStrength, the assess*Risk / getIPInfo helpers, getUserPermissions,
 * hasPermission, unlockExpiredAccounts and detectAnomalousActivity are not defined in
 * this listing. The implementation must supply them.
 */
class MultiFactorAuthenticator {
  /** Lifetime of an MFA temp token, in seconds. */
  private static readonly TEMP_TOKEN_TTL_SECONDS = 300;
  /** Sensitive operations need a factor verified within this window (milliseconds). */
  private static readonly REAUTH_WINDOW_MS = 30 * 60 * 1000;

  private failedAttempts = new Map<string, number>();
  private lockedAccounts = new Map<string, number>();
  private activeSessions = new Map<string, AuthenticationContext>();
  /** Logins awaiting MFA, keyed by sessionId. */
  private pendingAuthContexts = new Map<string, PendingMfaAuth>();
  /** Interval handles created by startSecurityMonitoring. */
  private monitoringTimers: ReturnType<typeof setInterval>[] = [];

  constructor(
    private policy: SecurityPolicy,
    private secretKey: string
  ) {}

  /**
   * Step one of authentication: lock check, risk scoring and password verification.
   *
   * @returns An access token when a single factor suffices. Otherwise it returns
   *   `requiresMFA: true` with a temp token for `verifyMultiFactor`.
   * @throws AuthenticationError `ACCOUNT_LOCKED` or `INVALID_CREDENTIALS`.
   */
  async initiateAuthentication(
    credentials: {
      username: string;
      password: string;
      ipAddress: string;
      userAgent: string;
    }
  ): Promise<AuthenticationResponse> {
    const { username, password, ipAddress, userAgent } = credentials;

    // 1. Reject locked accounts.
    if (this.isAccountLocked(username)) {
      const lockTime = this.lockedAccounts.get(username) ?? 0;
      // Clamp so a lock that is just expiring never reports a negative wait.
      const remainingTime = Math.max(0, Math.ceil((lockTime - Date.now()) / 1000));
      throw new AuthenticationError(
        `账户已锁定,请在${remainingTime}秒后重试`,
        'ACCOUNT_LOCKED'
      );
    }

    // 2. Risk assessment.
    const riskScore = await this.assessRisk({
      username,
      ipAddress,
      userAgent,
      timestamp: Date.now()
    });
    console.log(`🔍 风险评估:用户${username},风险评分:${riskScore}`);

    // 3. Password check. The same error is used for unknown users and bad passwords.
    const user = await this.getUserByUsername(username);
    if (!user || !(await this.verifyPassword(password, user.passwordHash))) {
      this.recordFailedAttempt(username);
      throw new AuthenticationError(
        '用户名或密码错误',
        'INVALID_CREDENTIALS'
      );
    }

    // 4. Build the authentication context.
    const now = Date.now();
    const sessionId = this.generateSessionId();
    const authContext: AuthenticationContext = {
      userId: user.id,
      sessionId,
      ipAddress,
      userAgent,
      timestamp: now,
      riskScore,
      authMethods: [{
        type: 'password',
        verified: true,
        timestamp: now,
        strength: 60
      }]
    };

    // 5. Step up to MFA when the policy demands it or the risk helper asks for it.
    //    Checking the policy here means requireMFA/minAuthMethods cannot be skipped
    //    on the single-factor path.
    const requiresAdditionalAuth =
      this.policy.requireMFA ||
      authContext.authMethods.length < this.policy.minAuthMethods ||
      this.requiresMultiFactorAuth(riskScore, user);

    if (requiresAdditionalAuth) {
      const tempToken = this.generateTempToken(authContext, username);
      return {
        success: false,
        requiresMFA: true,
        availableMethods: await this.getAvailableAuthMethods(user.id),
        tempToken,
        message: '需要额外的身份验证'
      };
    }

    // 6. Single factor accepted: open the session.
    const accessToken = this.generateAccessToken(authContext);
    this.activeSessions.set(sessionId, authContext);
    this.failedAttempts.delete(username);

    return {
      success: true,
      accessToken,
      sessionId,
      expiresIn: this.policy.sessionTimeout,
      message: '认证成功'
    };
  }

  /**
   * Verify one additional factor for a pending login.
   *
   * Each factor type counts at most once. After `policy.maxFailedAttempts` wrong codes
   * the pending login is discarded and the failure is recorded against the username.
   *
   * @throws AuthenticationError `INVALID_TEMP_TOKEN`, `MFA_METHOD_ALREADY_VERIFIED`
   *   or `INVALID_MFA_CODE`.
   */
  async verifyMultiFactor(
    tempToken: string,
    method: AuthMethod['type'],
    code: string
  ): Promise<AuthenticationResponse> {
    // 1. Resolve the pending login from the temp token.
    const pending = this.verifyTempToken(tempToken);
    if (!pending) {
      throw new AuthenticationError('临时令牌无效或已过期', 'INVALID_TEMP_TOKEN');
    }
    const authContext = pending.context;

    // A factor already satisfied cannot be replayed to inflate strength or method count.
    if (authContext.authMethods.some(m => m.type === method && m.verified)) {
      throw new AuthenticationError('该验证方式已完成,请选择其他验证方式', 'MFA_METHOD_ALREADY_VERIFIED');
    }

    // 2. Check the code, capping guesses per pending login.
    const isValidCode = await this.verifyMFACode(authContext.userId, method, code);
    if (!isValidCode) {
      pending.failures += 1;
      if (pending.failures >= this.policy.maxFailedAttempts) {
        this.pendingAuthContexts.delete(authContext.sessionId);
        this.recordFailedAttempt(pending.username);
      }
      throw new AuthenticationError('验证码错误', 'INVALID_MFA_CODE');
    }

    // 3. Record the verified factor.
    authContext.authMethods.push({
      type: method,
      verified: true,
      timestamp: Date.now(),
      strength: this.getMethodStrength(method)
    });

    // 4. Check the aggregate requirement.
    const totalStrength = authContext.authMethods
      .reduce((sum, m) => sum + m.strength, 0);
    const requiredStrength = this.getRequiredAuthStrength(authContext.riskScore);

    if (totalStrength < requiredStrength ||
        authContext.authMethods.length < this.policy.minAuthMethods) {
      return {
        success: false,
        requiresMFA: true,
        availableMethods: await this.getAvailableAuthMethods(authContext.userId),
        // Issuing a new token rotates the nonce, so the previous temp token stops working.
        tempToken: this.generateTempToken(authContext, pending.username),
        message: `需要更强的身份验证(当前强度:${totalStrength},要求:${requiredStrength})`
      };
    }

    // 5. MFA complete: promote the pending login to an active session.
    this.pendingAuthContexts.delete(authContext.sessionId);
    this.failedAttempts.delete(pending.username);
    const accessToken = this.generateAccessToken(authContext);
    this.activeSessions.set(authContext.sessionId, authContext);

    console.log(`✅ MFA认证成功:用户${authContext.userId},方法:${authContext.authMethods.map(m => m.type).join(', ')}`);

    return {
      success: true,
      accessToken,
      sessionId: authContext.sessionId,
      expiresIn: this.policy.sessionTimeout,
      message: 'MFA认证成功'
    };
  }

  /**
   * Weighted risk score clamped to [0, 100].
   * Weights: IP 30%, geo 20%, device 25%, time 15%, behaviour 10%.
   */
  private async assessRisk(context: {
    username: string;
    ipAddress: string;
    userAgent: string;
    timestamp: number;
  }): Promise<number> {
    // The five signals are independent lookups, so fetch them concurrently.
    const [ipRisk, geoRisk, deviceRisk, timeRisk, behaviorRisk] = await Promise.all([
      this.assessIPRisk(context.ipAddress),
      this.assessGeoRisk(context.ipAddress, context.username),
      this.assessDeviceRisk(context.userAgent, context.username),
      this.assessTimeRisk(context.timestamp, context.username),
      this.assessBehaviorRisk(context.username)
    ]);

    // Weights sum to 1.0. Presumably each signal is on a 0-100 scale (confirm for the
    // helpers defined elsewhere); the final clamp guards against out-of-range inputs.
    const riskScore =
      ipRisk * 0.3 +
      geoRisk * 0.2 +
      deviceRisk * 0.25 +
      timeRisk * 0.15 +
      behaviorRisk * 0.1;

    return Math.min(100, Math.max(0, riskScore));
  }

  /** IP reputation risk, capped at 100 so it stays on the same scale as the other signals. */
  private async assessIPRisk(ipAddress: string): Promise<number> {
    const ipInfo = await this.getIPInfo(ipAddress);
    let risk = 0;

    if (ipInfo.isBlacklisted) risk += 80;
    if (ipInfo.isProxy || ipInfo.isVPN) risk += 40;
    if (ipInfo.isTor) risk += 60;
    if (ipInfo.isHighRiskCountry) risk += 30;

    // Without the cap a blacklisted Tor exit would score 140+ and dominate the weighting.
    return Math.min(100, risk);
  }

  /** Sign an HS256 access token bound to the session (issuer `mcp-server`). */
  private generateAccessToken(context: AuthenticationContext): string {
    const payload = {
      userId: context.userId,
      sessionId: context.sessionId,
      authMethods: context.authMethods.map(m => m.type),
      authStrength: context.authMethods.reduce((sum, m) => sum + m.strength, 0),
      riskScore: context.riskScore,
      iat: Math.floor(Date.now() / 1000),
      exp: Math.floor(Date.now() / 1000) + this.policy.sessionTimeout
    };

    return jwt.sign(payload, this.secretKey, {
      algorithm: 'HS256',
      issuer: 'mcp-server',
      audience: context.userId
    });
  }

  /**
   * Validate an access token and return its live session, or null.
   *
   * The algorithm and issuer are pinned. MFA temp tokens carry no issuer, so they are
   * rejected here even though they share the signing key.
   */
  async validateAccessToken(token: string): Promise<AuthenticationContext | null> {
    try {
      const decoded = jwt.verify(token, this.secretKey, {
        algorithms: ['HS256'],
        issuer: 'mcp-server'
      });
      if (typeof decoded === 'string' || typeof decoded.sessionId !== 'string') {
        return null;
      }
      if (decoded.type === 'temp_mfa') {
        return null;
      }

      const session = this.activeSessions.get(decoded.sessionId);
      if (!session) {
        console.warn(`⚠️ 会话不存在:${decoded.sessionId}`);
        return null;
      }

      // The token must belong to the same user as the stored session.
      if (session.userId !== decoded.userId) {
        return null;
      }

      // Session age is measured from context creation.
      if (Date.now() - session.timestamp > this.policy.sessionTimeout * 1000) {
        console.warn(`⚠️ 会话已过期:${decoded.sessionId}`);
        this.activeSessions.delete(decoded.sessionId);
        return null;
      }

      return session;
    } catch (error: unknown) {
      console.warn('⚠️ 令牌验证失败:', error instanceof Error ? error.message : String(error));
      return null;
    }
  }

  /**
   * Check whether the token's user may perform `operation`.
   *
   * @returns false for an invalid token, otherwise the role-based permission result.
   * @throws AuthenticationError `REAUTHENTICATION_REQUIRED` when a sensitive operation
   *   is attempted more than REAUTH_WINDOW_MS after the last verified factor.
   */
  async checkOperationPermission(
    token: string,
    operation: string,
    resource?: string
  ): Promise<boolean> {
    const context = await this.validateAccessToken(token);
    if (!context) return false;

    if (this.policy.requireReauth.includes(operation)) {
      const lastFactorAt = Math.max(...context.authMethods.map(m => m.timestamp));
      if (Date.now() - lastFactorAt > MultiFactorAuthenticator.REAUTH_WINDOW_MS) {
        throw new AuthenticationError(
          '敏感操作需要重新认证',
          'REAUTHENTICATION_REQUIRED'
        );
      }
    }

    const userPermissions = await this.getUserPermissions(context.userId);
    return this.hasPermission(userPermissions, operation, resource);
  }

  /** 256-bit random session id, hex-encoded. */
  private generateSessionId(): string {
    return crypto.randomBytes(32).toString('hex');
  }

  /**
   * Store or refresh the pending login and sign a temp token for it.
   *
   * A fresh nonce is generated on every call, which invalidates earlier temp tokens.
   * Failure counts carry over from the existing entry.
   */
  private generateTempToken(context: AuthenticationContext, username: string): string {
    const nonce = crypto.randomBytes(16).toString('hex');
    const ttl = MultiFactorAuthenticator.TEMP_TOKEN_TTL_SECONDS;
    const existing = this.pendingAuthContexts.get(context.sessionId);

    this.pendingAuthContexts.set(context.sessionId, {
      context,
      username,
      nonce,
      expiresAt: Date.now() + ttl * 1000,
      failures: existing?.failures ?? 0
    });

    const payload = {
      sessionId: context.sessionId,
      userId: context.userId,
      type: 'temp_mfa',
      jti: nonce,
      exp: Math.floor(Date.now() / 1000) + ttl
    };

    return jwt.sign(payload, this.secretKey, { algorithm: 'HS256' });
  }

  /**
   * Resolve a temp token to its pending login.
   * Returns null when the token is invalid, superseded or expired.
   */
  private verifyTempToken(token: string): PendingMfaAuth | null {
    try {
      const decoded = jwt.verify(token, this.secretKey, { algorithms: ['HS256'] });
      if (typeof decoded === 'string' ||
          decoded.type !== 'temp_mfa' ||
          typeof decoded.sessionId !== 'string') {
        return null;
      }

      const pending = this.pendingAuthContexts.get(decoded.sessionId);
      if (!pending || pending.nonce !== decoded.jti) {
        return null;
      }
      if (pending.expiresAt <= Date.now()) {
        this.pendingAuthContexts.delete(decoded.sessionId);
        return null;
      }
      return pending;
    } catch {
      return null;
    }
  }

  /**
   * Start periodic housekeeping: session cleanup every minute, account unlocking
   * every 5 minutes, and anomaly detection every 30 seconds.
   * Calling it again while running has no effect.
   */
  startSecurityMonitoring() {
    if (this.monitoringTimers.length > 0) return;

    this.monitoringTimers.push(
      setInterval(() => this.cleanupExpiredSessions(), 60000),
      setInterval(() => this.unlockExpiredAccounts(), 300000),
      setInterval(() => this.detectAnomalousActivity(), 30000)
    );
  }

  /** Stop the timers started by startSecurityMonitoring (for example, on shutdown). */
  stopSecurityMonitoring() {
    this.monitoringTimers.forEach(timer => clearInterval(timer));
    this.monitoringTimers = [];
  }

  /** Drop sessions older than sessionTimeout, plus pending MFA logins past their window. */
  private cleanupExpiredSessions() {
    const now = Date.now();
    const timeoutMs = this.policy.sessionTimeout * 1000;
    let removed = 0;

    // Deleting the current entry while iterating a Map is safe in JavaScript.
    for (const [sessionId, context] of this.activeSessions) {
      if (now - context.timestamp > timeoutMs) {
        this.activeSessions.delete(sessionId);
        removed++;
      }
    }

    for (const [sessionId, pending] of this.pendingAuthContexts) {
      if (pending.expiresAt <= now) {
        this.pendingAuthContexts.delete(sessionId);
      }
    }

    if (removed > 0) {
      console.log(`🗑️ 清理了${removed}个过期会话`);
    }
  }

  /**
   * Snapshot of in-memory security state.
   * A "high risk" session has riskScore > 70; MFA usage means more than one factor.
   */
  getSecurityStats(): SecurityStats {
    const sessions = Array.from(this.activeSessions.values());
    const count = sessions.length;

    return {
      activeSessionsCount: count,
      lockedAccountsCount: this.lockedAccounts.size,
      averageRiskScore: count > 0
        ? sessions.reduce((sum, s) => sum + s.riskScore, 0) / count
        : 0,
      highRiskSessionsCount: sessions.filter(s => s.riskScore > 70).length,
      mfaUsageRate: sessions.filter(s => s.authMethods.length > 1).length / Math.max(1, count),
      recentFailedAttempts: Array.from(this.failedAttempts.values()).reduce((sum, n) => sum + n, 0)
    };
  }
}
// 异常检测系统
/**
 * Compares security events against a per-user behavioural baseline.
 *
 * NOTE(review): buildBaselineProfile, detectFrequencyAnomaly, detectDeviceAnomaly,
 * getLocationFromIP and calculateDistance are not defined in this listing.
 */
class SecurityAnomalyDetector {
  /** Upper bound on the number of events kept in recentActivities. */
  private static readonly MAX_RECENT_ACTIVITIES = 1000;
  /** Chinese weekday names indexed by Date#getDay() (0 = Sunday). */
  private static readonly WEEKDAY_NAMES = ['日', '一', '二', '三', '四', '五', '六'];

  private baselineProfiles = new Map<string, UserBaselineProfile>();
  private recentActivities: SecurityEvent[] = [];

  /**
   * Run every anomaly check for one event and record the event.
   *
   * The user's baseline is built lazily on first sight and then cached.
   * @returns Zero or more alerts, in check order: geo, time, frequency, device.
   */
  async detectUserAnomalies(userId: string, event: SecurityEvent): Promise<AnomalyAlert[]> {
    const alerts: AnomalyAlert[] = [];

    let profile = this.baselineProfiles.get(userId);
    if (!profile) {
      profile = await this.buildBaselineProfile(userId);
      this.baselineProfiles.set(userId, profile);
    }

    const geoAnomaly = this.detectGeographicAnomaly(event, profile);
    if (geoAnomaly) alerts.push(geoAnomaly);

    const timeAnomaly = this.detectTimePatternAnomaly(event, profile);
    if (timeAnomaly) alerts.push(timeAnomaly);

    const frequencyAnomaly = this.detectFrequencyAnomaly(event, profile);
    if (frequencyAnomaly) alerts.push(frequencyAnomaly);

    const deviceAnomaly = this.detectDeviceAnomaly(event, profile);
    if (deviceAnomaly) alerts.push(deviceAnomaly);

    // The stored copy is stamped with the recording time, replacing event.timestamp.
    this.recentActivities.push({
      ...event,
      timestamp: Date.now()
    });

    const max = SecurityAnomalyDetector.MAX_RECENT_ACTIVITIES;
    if (this.recentActivities.length > max) {
      this.recentActivities = this.recentActivities.slice(-max);
    }

    return alerts;
  }

  /**
   * Flag logins more than 1000 km from every known location (HIGH above 5000 km).
   * Returns null when the IP cannot be located, or when the user has no known
   * locations yet. Without that guard, Math.min() of an empty list is Infinity,
   * which would flag every first login as HIGH.
   */
  private detectGeographicAnomaly(
    event: SecurityEvent,
    profile: UserBaselineProfile
  ): AnomalyAlert | null {
    const currentLocation = this.getLocationFromIP(event.ipAddress);
    if (!currentLocation) return null;
    if (profile.commonLocations.length === 0) return null;

    const minDistance = Math.min(
      ...profile.commonLocations.map(loc =>
        this.calculateDistance(currentLocation, loc)
      )
    );

    if (minDistance > 1000) {
      return {
        type: 'GEOGRAPHIC_ANOMALY',
        severity: minDistance > 5000 ? 'HIGH' : 'MEDIUM',
        message: `检测到从${currentLocation.country}的异常登录,距离常用位置${minDistance.toFixed(0)}公里`,
        details: {
          currentLocation,
          distance: minDistance,
          commonLocations: profile.commonLocations
        }
      };
    }

    return null;
  }

  /**
   * Flag events outside the user's usual hours or weekdays.
   * Severity is HIGH when both the hour and the weekday are unusual, LOW when only one is.
   * An empty hour or day baseline means "no data", so that dimension is not flagged.
   * Hours and days use the server's local time zone (Date#getHours / Date#getDay).
   */
  private detectTimePatternAnomaly(
    event: SecurityEvent,
    profile: UserBaselineProfile
  ): AnomalyAlert | null {
    const when = new Date(event.timestamp);
    const eventHour = when.getHours();
    const dayOfWeek = when.getDay();

    const isCommonHour = profile.commonHours.length === 0 || profile.commonHours.includes(eventHour);
    const isCommonDay = profile.commonDays.length === 0 || profile.commonDays.includes(dayOfWeek);

    if (isCommonHour && isCommonDay) {
      return null;
    }

    const severity = (!isCommonHour && !isCommonDay) ? 'HIGH' : 'LOW';
    // getDay() returns 0 for Sunday, which would otherwise render as "星期0".
    const weekdayName = SecurityAnomalyDetector.WEEKDAY_NAMES[dayOfWeek];
    return {
      type: 'TIME_PATTERN_ANOMALY',
      severity,
      message: `检测到异常时间的访问:${eventHour}:00,星期${weekdayName}`,
      details: {
        eventHour,
        dayOfWeek,
        commonHours: profile.commonHours,
        commonDays: profile.commonDays
      }
    };
  }
}
// 接口定义
/**
 * Outcome of one authentication step. A completed login fills the session fields;
 * a login that still needs factors fills the MFA fields.
 */
interface AuthenticationResponse {
  /** Human-readable status text. */
  message: string;
  /** True only when the user is fully authenticated. */
  success: boolean;

  // --- populated on success ---
  accessToken?: string;
  sessionId?: string;
  /** Session lifetime in seconds. */
  expiresIn?: number;

  // --- populated when more factors are needed ---
  requiresMFA?: boolean;
  availableMethods?: AuthMethod['type'][];
  tempToken?: string;
}
/** Aggregate counters describing the authenticator's in-memory state. */
interface SecurityStats {
  /** Sessions currently tracked as active. */
  activeSessionsCount: number;
  /** Active sessions whose risk score exceeds 70. */
  highRiskSessionsCount: number;
  /** Mean risk score across active sessions (0 when there are none). */
  averageRiskScore: number;
  /** Share of active sessions that used more than one factor. */
  mfaUsageRate: number;
  /** Accounts currently present in the lock table. */
  lockedAccountsCount: number;
  /** Sum of all recorded failed-attempt counters. */
  recentFailedAttempts: number;
}
/** A single security-relevant action fed to the anomaly detector. */
interface SecurityEvent {
  // who and what
  userId: string;
  eventType: string;
  success: boolean;
  // where from
  ipAddress: string;
  userAgent: string;
  // when (epoch milliseconds, as accepted by new Date(...))
  timestamp: number;
  // optional pre-computed risk
  riskScore?: number;
}
/** An anomaly finding raised by SecurityAnomalyDetector. */
interface AnomalyAlert {
  /** Graded severity of the finding. */
  severity: 'LOW' | 'MEDIUM' | 'HIGH' | 'CRITICAL';
  /** Machine-readable category, e.g. 'GEOGRAPHIC_ANOMALY'. */
  type: string;
  /** Human-readable description. */
  message: string;
  /** Free-form supporting data; its shape depends on `type`. */
  details: any;
}
/** Learned "normal" behaviour for one user, used as the anomaly baseline. */
interface UserBaselineProfile {
  userId: string;
  /** Locations the user usually signs in from. */
  commonLocations: GeographicLocation[];
  /** Usual hours, compared against Date#getHours() (0-23). */
  commonHours: number[];
  /** Usual weekdays, compared against Date#getDay() (0 = Sunday). */
  commonDays: number[];
  /** Known device identifiers. */
  commonDevices: string[];
  /** Typical session length. NOTE(review): the unit is not visible here; confirm. */
  averageSessionDuration: number;
}
/** A resolved place on the map. */
interface GeographicLocation {
  /** Presumably decimal degrees; confirm against the IP-geolocation source. */
  latitude: number;
  longitude: number;
  /** Human-readable place names. */
  city: string;
  country: string;
}
/**
 * Authentication or authorization failure with a machine-readable `code`
 * (for example 'ACCOUNT_LOCKED' or 'INVALID_MFA_CODE').
 */
class AuthenticationError extends Error {
  constructor(message: string, public code: string) {
    super(message);
    this.name = 'AuthenticationError';
    // When compiled to ES5, extending Error breaks the prototype chain and
    // `instanceof AuthenticationError` returns false. Restoring it is harmless on ES2015+.
    Object.setPrototypeOf(this, new.target.prototype);
  }
}
// Usage example
// Fail fast at startup if the signing secret is missing. A non-null assertion would
// let `undefined` through, and the problem would only surface as a jwt.sign error
// on the first login.
const jwtSecret = process.env.JWT_SECRET;
if (!jwtSecret) {
  throw new Error('JWT_SECRET environment variable must be set');
}

const authenticator = new MultiFactorAuthenticator({
  requireMFA: true,
  minAuthMethods: 2,
  maxFailedAttempts: 5,
  sessionTimeout: 3600, // 1 hour, in seconds
  requireReauth: ['delete_account', 'change_password', 'transfer_funds'],
  geoRestrictions: {
    allowedCountries: ['CN', 'US', 'GB'],
    blockSuspiciousLocations: true
  }
}, jwtSecret);

// Start background security monitoring.
authenticator.startSecurityMonitoring();
export { MultiFactorAuthenticator, SecurityAnomalyDetector };
输入验证与清理:守护数据大门 🚿
全面的输入验证框架
python
# 🌟 输入验证与清理系统
import re
import html
import json
import urllib.parse
from typing import Any, Dict, List, Optional, Union, Callable
from dataclasses import dataclass
from enum import Enum
import bleach
from datetime import datetime
import ipaddress
import email_validator
class ValidationLevel(Enum):
    """Strictness setting shared by the validators."""

    # Safest option; legitimate input may occasionally be rejected.
    STRICT = "strict"
    # Middle ground between safety and usability.
    BALANCED = "balanced"
    # Most lenient; accepts a wider range of input.
    PERMISSIVE = "permissive"
@dataclass
class ValidationResult:
    """Outcome of validating a single value."""

    # True when no errors were recorded.
    is_valid: bool
    # The value after any truncation, sanitising or normalisation (or the raw input on failure).
    cleaned_value: Any
    # Messages that make the value invalid.
    errors: List[str]
    # Non-fatal notices, e.g. "value was truncated".
    warnings: List[str]
    # Heuristic risk estimate, nominally 0-100.
    risk_score: int
class InputValidator:
    """General-purpose, heuristic input validator.

    Every ``validate_*`` method returns a ValidationResult holding:
    - the (possibly cleaned) value,
    - error and warning messages,
    - a risk score capped at 100.

    Detection is regex based, so it can miss attacks and can also flag benign
    text. Treat it as one layer alongside parameterized queries and output encoding.

    NOTE(review): _is_disposable_email, _check_email_domain_reputation,
    _check_domain_reputation, _check_url_path, _check_ip_reputation,
    _get_json_depth, _count_json_keys and _scan_json_for_dangers are called
    below but not defined in this listing.
    """

    def __init__(self, level: ValidationLevel = ValidationLevel.BALANCED):
        """Create a validator.

        Args:
            level: strictness. STRICT turns over-length strings and disposable
                emails into errors, and enables the dangerous-character check.
        """
        self.level = level
        # Not read by any method of this class.
        self.custom_rules: Dict[str, List[Callable]] = {}

        # Suspicious-content regexes, matched case-insensitively.
        self.sql_injection_patterns = [
            r'(\bunion\b|\bselect\b|\binsert\b|\bupdate\b|\bdelete\b|\bdrop\b)',
            r'(\bexec\b|\bexecute\b|\bsp_\w+)',
            r'(\'.*?\'|\".*?\")\s*;\s*--',
            r'(\bor\b|\band\b)\s+[\w\'\"]+\s*=\s*[\w\'\"]+',
        ]

        self.xss_patterns = [
            r'<script[^>]*>.*?</script>',
            r'<iframe[^>]*>.*?</iframe>',
            r'javascript\s*:',
            r'on\w+\s*=',
            r'<object[^>]*>.*?</object>',
            r'<embed[^>]*>.*?</embed>',
        ]

        self.command_injection_patterns = [
            r'[;&|`$(){}[\]\\]',
            r'\b(rm|del|format|fdisk)\b',
            r'(nc|netcat|wget|curl)\s+',
        ]

        # Tag whitelist used by _sanitize_html (rich-text mode).
        self.allowed_html_tags = {
            'p', 'br', 'strong', 'b', 'em', 'i', 'u',
            'ol', 'ul', 'li', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
            'blockquote', 'code', 'pre'
        }

        # NOTE(review): 'a' and 'img' are not in allowed_html_tags. bleach strips those
        # tags, so their attribute entries below never apply. Confirm the intended policy.
        self.allowed_html_attributes = {
            '*': ['class', 'id'],
            'a': ['href', 'title', 'rel'],
            'img': ['src', 'alt', 'width', 'height'],
        }

    def validate_string(
        self,
        value: str,
        max_length: int = 1000,
        min_length: int = 0,
        pattern: Optional[str] = None,
        allow_html: bool = False,
        check_encoding: bool = True
    ) -> ValidationResult:
        """Validate and clean a free-text string.

        Args:
            value: text to check.
            max_length: upper bound. STRICT rejects longer input; other levels truncate it.
            min_length: lower bound. Shorter input is an error.
            pattern: optional regex that the *whole* (possibly truncated) value must
                match, case-insensitively.
            allow_html: when True, the value always goes through the HTML whitelist
                sanitizer. When False, content scoring above 50 for XSS is an error,
                and any XSS-looking content is HTML-escaped.
            check_encoding: reject strings that cannot be encoded as UTF-8
                (e.g. lone surrogates).

        Returns:
            ValidationResult whose cleaned_value is the truncated, sanitized or
            escaped text.
        """
        errors: List[str] = []
        warnings: List[str] = []
        risk_score = 0

        if len(value) < min_length:
            errors.append(f"字符串长度不能少于{min_length}个字符")

        if len(value) > max_length:
            if self.level == ValidationLevel.STRICT:
                errors.append(f"字符串长度不能超过{max_length}个字符")
            else:
                value = value[:max_length]
                warnings.append(f"字符串已截断到{max_length}个字符")

        if check_encoding:
            try:
                value.encode('utf-8')
            except UnicodeEncodeError:
                errors.append("包含不支持的字符编码")
                risk_score += 20

        sql_risk = self._detect_sql_injection(value)
        risk_score += sql_risk
        if sql_risk > 50:
            errors.append("检测到潜在的SQL注入尝试")

        xss_risk = self._detect_xss(value)
        risk_score += xss_risk
        if allow_html:
            # Always sanitize rich text. Gating on a score threshold would let through
            # markup scoring <= 50 (a single on*= handler scores exactly 50), as well as
            # any tag outside the whitelist.
            sanitized = self._sanitize_html(value)
            if sanitized != value:
                value = sanitized
                warnings.append("HTML内容已清理")
        elif xss_risk > 50:
            errors.append("检测到潜在的XSS攻击")

        cmd_risk = self._detect_command_injection(value)
        risk_score += cmd_risk
        if cmd_risk > 30:
            errors.append("检测到潜在的命令注入尝试")

        # fullmatch, not match. re.match only anchors at the start, and "$" also
        # matches before a trailing newline, so "abc\n" would pass ^[a-z]+$.
        if pattern and not re.fullmatch(pattern, value, re.IGNORECASE):
            errors.append(f"格式不符合要求:{pattern}")

        if self.level == ValidationLevel.STRICT:
            dangerous_chars = self._find_dangerous_characters(value)
            if dangerous_chars:
                errors.append(f"包含危险字符:{', '.join(dangerous_chars)}")
                risk_score += len(dangerous_chars) * 5

        # Plain-text fields: neutralise anything that looked like markup.
        if not allow_html and xss_risk > 0:
            value = html.escape(value, quote=True)

        return ValidationResult(
            is_valid=len(errors) == 0,
            cleaned_value=value,
            errors=errors,
            warnings=warnings,
            risk_score=min(100, risk_score)
        )

    def validate_email(self, email: str) -> ValidationResult:
        """Validate an email address.

        Syntax is checked with email_validator (no deliverability lookup).
        Disposable addresses are an error in STRICT mode and a warning otherwise.
        """
        errors: List[str] = []
        warnings: List[str] = []
        risk_score = 0

        try:
            validated_email = email_validator.validate_email(email, check_deliverability=False)
            cleaned_email = validated_email.email

            if self._is_disposable_email(cleaned_email):
                if self.level == ValidationLevel.STRICT:
                    errors.append("不允许使用一次性邮箱")
                else:
                    warnings.append("检测到一次性邮箱")
                risk_score += 30

            risk_score += self._check_email_domain_reputation(cleaned_email)

            return ValidationResult(
                is_valid=len(errors) == 0,
                cleaned_value=cleaned_email,
                errors=errors,
                warnings=warnings,
                risk_score=min(100, risk_score)
            )

        except email_validator.EmailNotValidError as e:
            return ValidationResult(
                is_valid=False,
                cleaned_value=email,
                errors=[f"邮箱格式无效:{str(e)}"],
                warnings=[],
                risk_score=0
            )

    def validate_url(self, url: str, allowed_schemes: Optional[List[str]] = None) -> ValidationResult:
        """Validate a URL's scheme, host reputation and path.

        Args:
            url: URL to check.
            allowed_schemes: permitted schemes; defaults to http and https.
        """
        errors: List[str] = []
        warnings: List[str] = []
        risk_score = 0

        if allowed_schemes is None:
            allowed_schemes = ['http', 'https']

        try:
            parsed = urllib.parse.urlparse(url)

            if parsed.scheme not in allowed_schemes:
                errors.append(f"不允许的URL协议:{parsed.scheme}")

            # Judge the bare hostname. netloc can also carry userinfo and a port
            # (e.g. "https://trusted.com@evil.com:8080"), which would disguise the real host.
            host = parsed.hostname
            if not host:
                errors.append("缺少主机名")
            else:
                domain_risk = self._check_domain_reputation(host)
                risk_score += domain_risk
                if domain_risk > 70:
                    errors.append("检测到恶意域名")
                elif domain_risk > 40:
                    warnings.append("域名可能存在风险")

            if parsed.path:
                risk_score += self._check_url_path(parsed.path)

            cleaned_url = urllib.parse.urlunparse(parsed)

            return ValidationResult(
                is_valid=len(errors) == 0,
                cleaned_value=cleaned_url,
                errors=errors,
                warnings=warnings,
                risk_score=min(100, risk_score)
            )

        except Exception as e:
            return ValidationResult(
                is_valid=False,
                cleaned_value=url,
                errors=[f"URL格式错误:{str(e)}"],
                warnings=[],
                risk_score=0
            )

    def validate_ip_address(self, ip: str, allow_private: bool = True) -> ValidationResult:
        """Validate an IPv4/IPv6 address and return it in canonical form.

        Args:
            ip: address text.
            allow_private: when False, private addresses are an error.
        """
        errors: List[str] = []
        warnings: List[str] = []
        risk_score = 0

        try:
            ip_obj = ipaddress.ip_address(ip)

            if ip_obj.is_private and not allow_private:
                errors.append("不允许私有IP地址")

            if ip_obj.is_loopback:
                warnings.append("环回地址")
                risk_score += 10

            if ip_obj.is_multicast:
                warnings.append("组播地址")
                risk_score += 15

            reputation_risk = self._check_ip_reputation(str(ip_obj))
            risk_score += reputation_risk
            if reputation_risk > 70:
                errors.append("IP地址被标记为恶意")
            elif reputation_risk > 40:
                warnings.append("IP地址可能存在风险")

            return ValidationResult(
                is_valid=len(errors) == 0,
                cleaned_value=str(ip_obj),
                errors=errors,
                warnings=warnings,
                risk_score=min(100, risk_score)
            )

        except ValueError:
            return ValidationResult(
                is_valid=False,
                cleaned_value=ip,
                errors=["IP地址格式无效"],
                warnings=[],
                risk_score=0
            )

    def validate_json(self, json_str: str, max_depth: int = 10, max_keys: int = 100) -> ValidationResult:
        """Parse JSON, enforce depth and key limits, and re-serialise it compactly.

        Args:
            json_str: JSON text.
            max_depth: maximum nesting depth allowed.
            max_keys: maximum total number of keys allowed.
        """
        errors: List[str] = []
        warnings: List[str] = []
        risk_score = 0

        try:
            data = json.loads(json_str)

            actual_depth = self._get_json_depth(data)
            if actual_depth > max_depth:
                errors.append(f"JSON嵌套深度过深:{actual_depth} > {max_depth}")
                risk_score += 30

            key_count = self._count_json_keys(data)
            if key_count > max_keys:
                errors.append(f"JSON键数量过多:{key_count} > {max_keys}")
                risk_score += 20

            dangerous_content = self._scan_json_for_dangers(data)
            if dangerous_content:
                risk_score += len(dangerous_content) * 15
                warnings.extend(dangerous_content)

            cleaned_json = json.dumps(data, ensure_ascii=False, separators=(',', ':'))

            return ValidationResult(
                is_valid=len(errors) == 0,
                cleaned_value=cleaned_json,
                errors=errors,
                warnings=warnings,
                risk_score=min(100, risk_score)
            )

        except json.JSONDecodeError as e:
            return ValidationResult(
                is_valid=False,
                cleaned_value=json_str,
                errors=[f"JSON格式错误:{str(e)}"],
                warnings=[],
                risk_score=0
            )
        except RecursionError:
            # Pathologically nested input (e.g. "[[[[...]]]]") can exhaust the recursion
            # limit inside json.loads. Report it as an over-deep document instead of crashing.
            return ValidationResult(
                is_valid=False,
                cleaned_value=json_str,
                errors=["JSON嵌套深度过深:超过解析器递归上限"],
                warnings=[],
                risk_score=30
            )

    def _detect_sql_injection(self, value: str) -> int:
        """Score SQL-injection indicators: 25 per pattern match, plus 10 per keyword
        when more than two distinct keywords appear. Capped at 100."""
        risk_score = 0
        value_lower = value.lower()

        for pattern in self.sql_injection_patterns:
            matches = re.findall(pattern, value_lower, re.IGNORECASE)
            if matches:
                risk_score += len(matches) * 25

        # Substring test, so e.g. "executive" also counts as "exec".
        sql_keywords = ['select', 'insert', 'update', 'delete', 'union', 'drop', 'exec']
        keyword_count = sum(1 for keyword in sql_keywords if keyword in value_lower)
        if keyword_count > 2:
            risk_score += keyword_count * 10

        return min(100, risk_score)

    def _detect_xss(self, value: str) -> int:
        """Score XSS indicators. Event handlers and javascript: URLs are counted both via
        xss_patterns and by the dedicated checks below. Capped at 100."""
        risk_score = 0

        for pattern in self.xss_patterns:
            matches = re.findall(pattern, value, re.IGNORECASE | re.DOTALL)
            if matches:
                risk_score += len(matches) * 30

        event_handlers = re.findall(r'on\w+\s*=', value, re.IGNORECASE)
        risk_score += len(event_handlers) * 20

        js_urls = re.findall(r'javascript\s*:', value, re.IGNORECASE)
        risk_score += len(js_urls) * 25

        return min(100, risk_score)

    def _detect_command_injection(self, value: str) -> int:
        """Score shell metacharacters and risky command names: 20 per match, capped at 100."""
        risk_score = 0

        for pattern in self.command_injection_patterns:
            matches = re.findall(pattern, value, re.IGNORECASE)
            if matches:
                risk_score += len(matches) * 20

        return min(100, risk_score)

    def _sanitize_html(self, html_content: str) -> str:
        """Strip non-whitelisted tags, attributes and comments with bleach."""
        return bleach.clean(
            html_content,
            tags=self.allowed_html_tags,
            attributes=self.allowed_html_attributes,
            strip=True,
            strip_comments=True
        )

    def _find_dangerous_characters(self, value: str) -> List[str]:
        """List dangerous characters in first-seen order (deterministic, no duplicates).

        Control characters other than tab, newline and carriage return (including NUL)
        are reported as "控制字符(n)". HTML-significant characters are reported verbatim.
        """
        found: Dict[str, None] = {}

        for char in value:
            if ord(char) < 32 and char not in ('\t', '\n', '\r'):
                found.setdefault(f"控制字符({ord(char)})", None)

        for char in ('<', '>', '"', "'", '&'):
            if char in value:
                found.setdefault(char, None)

        return list(found)
# 高级验证装饰器
def validate_input(validator_config: Dict[str, Any]):
    """Decorator that validates and cleans selected arguments before the call.

    Args:
        validator_config: maps a parameter name to a config dict whose ``type`` is
            one of the following:
            - 'string': also honours ``max_length``, ``min_length``, ``pattern``
              and ``allow_html``.
            - 'email'.
            - 'url': also honours ``allowed_schemes``.
            - 'ip': also honours ``allow_private``.
            - 'json': also honours ``max_depth`` and ``max_keys``; non-str values
              are serialised with json.dumps first.
            Parameters with a missing or unknown ``type`` are passed through unchanged.

    The wrapped function receives the cleaned values. Arguments whose value is None
    (e.g. an optional parameter left at its None default) are not validated. They
    would otherwise be checked as the literal string "None" and rejected.

    Raises:
        ValueError: if a configured argument fails validation.
        TypeError: if the call does not match the function's signature, or a
            'json' argument is not JSON-serialisable.
    """
    import functools
    import inspect

    def _run_check(validator, value, config):
        """Run the validator selected by config['type']; return None for unknown types."""
        kind = config.get('type')
        if kind == 'string':
            return validator.validate_string(
                str(value),
                max_length=config.get('max_length', 1000),
                min_length=config.get('min_length', 0),
                pattern=config.get('pattern'),
                allow_html=config.get('allow_html', False)
            )
        if kind == 'email':
            return validator.validate_email(str(value))
        if kind == 'url':
            return validator.validate_url(str(value), allowed_schemes=config.get('allowed_schemes'))
        if kind == 'ip':
            return validator.validate_ip_address(str(value), allow_private=config.get('allow_private', True))
        if kind == 'json':
            raw = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
            return validator.validate_json(
                raw,
                max_depth=config.get('max_depth', 10),
                max_keys=config.get('max_keys', 100)
            )
        return None

    def decorator(func):
        # Computed once per decorated function rather than on every call.
        sig = inspect.signature(func)
        validator = InputValidator()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()

            for param_name, value in list(bound_args.arguments.items()):
                config = validator_config.get(param_name)
                if config is None or value is None:
                    continue

                result = _run_check(validator, value, config)
                if result is None:
                    continue

                if not result.is_valid:
                    raise ValueError(f"参数 {param_name} 验证失败: {', '.join(result.errors)}")

                bound_args.arguments[param_name] = result.cleaned_value

            return func(*bound_args.args, **bound_args.kwargs)

        return wrapper

    return decorator
# 使用示例
@validate_input({
    'username': {
        'type': 'string',
        'min_length': 3,
        'max_length': 20,
        'pattern': r'^[a-zA-Z0-9_]+$'
    },
    'email': {
        'type': 'email'
    },
    'website': {
        'type': 'url'
    }
})
def create_user_profile(username: str, email: str, website: Optional[str] = None):
    """Create a user profile from validated input.

    The decorator validates and cleans ``username``, ``email`` and (when given)
    ``website`` before this body runs.

    Returns:
        dict with keys 'username', 'email' and 'website'.
    """
    print(f"创建用户:{username}, 邮箱:{email}, 网站:{website}")
    return {"username": username, "email": email, "website": website}
# 批量验证工具
class BatchValidator:
    """Validate several fields of a record against a per-field schema."""

    def __init__(self):
        self.validator = InputValidator()
        # NOTE(review): no method of this class writes to this list; kept for compatibility.
        self.results: List[ValidationResult] = []

    def validate_batch(self, data: Dict[str, Any], schema: Dict[str, Dict]) -> Dict[str, ValidationResult]:
        """Validate ``data`` against ``schema``.

        Each schema entry needs a ``type`` ('string', 'email', 'url', 'ip', 'json', or
        anything else to accept the value unchanged) and may set ``required``, plus
        type-specific options:
        - string: max_length, min_length, pattern, allow_html
        - url: allowed_schemes
        - ip: allow_private
        - json: max_depth, max_keys

        A field that is absent or None counts as missing. It gets a result only when
        it is required, in which case the result is an error.

        Returns:
            Mapping of field name to ValidationResult for every validated or missing
            required field.
        """
        results: Dict[str, ValidationResult] = {}

        for field_name, field_schema in schema.items():
            value = data.get(field_name)

            # None is treated as missing. Otherwise it would be validated as the
            # string "None" and satisfy a required string field.
            if value is None:
                if field_schema.get('required', False):
                    results[field_name] = ValidationResult(
                        is_valid=False,
                        cleaned_value=None,
                        errors=[f"必填字段 {field_name} 缺失"],
                        warnings=[],
                        risk_score=0
                    )
                continue

            results[field_name] = self._validate_field(value, field_schema)

        return results

    def _validate_field(self, value: Any, field_schema: Dict) -> ValidationResult:
        """Dispatch one present, non-None value to the validator for its schema type."""
        field_type = field_schema['type']

        if field_type == 'string':
            return self.validator.validate_string(
                str(value),
                max_length=field_schema.get('max_length', 1000),
                min_length=field_schema.get('min_length', 0),
                pattern=field_schema.get('pattern'),
                allow_html=field_schema.get('allow_html', False)
            )
        if field_type == 'email':
            return self.validator.validate_email(str(value))
        if field_type == 'url':
            return self.validator.validate_url(
                str(value), allowed_schemes=field_schema.get('allowed_schemes')
            )
        if field_type == 'ip':
            return self.validator.validate_ip_address(
                str(value), allow_private=field_schema.get('allow_private', True)
            )
        if field_type == 'json':
            # Already-parsed structures must be serialised as JSON. str(dict) yields a
            # Python repr (single quotes, True/None), which is not valid JSON.
            if isinstance(value, str):
                raw = value
            else:
                try:
                    raw = json.dumps(value, ensure_ascii=False)
                except (TypeError, ValueError) as exc:
                    return ValidationResult(
                        is_valid=False,
                        cleaned_value=value,
                        errors=[f"JSON格式错误:{exc}"],
                        warnings=[],
                        risk_score=0
                    )
            return self.validator.validate_json(
                raw,
                max_depth=field_schema.get('max_depth', 10),
                max_keys=field_schema.get('max_keys', 100)
            )

        # Unknown types are accepted unchanged.
        return ValidationResult(
            is_valid=True,
            cleaned_value=value,
            errors=[],
            warnings=[],
            risk_score=0
        )

    def get_validation_summary(self, results: Dict[str, ValidationResult]) -> Dict[str, Any]:
        """Aggregate counts, rates and mean risk over a validate_batch result.

        Fields with risk_score > 70 are listed under 'high_risk_fields'.
        """
        total_fields = len(results)
        valid_fields = sum(1 for r in results.values() if r.is_valid)
        total_errors = sum(len(r.errors) for r in results.values())
        total_warnings = sum(len(r.warnings) for r in results.values())
        avg_risk_score = sum(r.risk_score for r in results.values()) / max(1, total_fields)

        return {
            'total_fields': total_fields,
            'valid_fields': valid_fields,
            'invalid_fields': total_fields - valid_fields,
            'validation_rate': valid_fields / total_fields if total_fields > 0 else 0,
            'total_errors': total_errors,
            'total_warnings': total_warnings,
            'average_risk_score': avg_risk_score,
            'high_risk_fields': [
                name for name, result in results.items()
                if result.risk_score > 70
            ]
        }
# 测试示例
if __name__ == "__main__":
    # Demo: run sample inputs through InputValidator.validate_string and print each outcome.
    validator = InputValidator(ValidationLevel.BALANCED)

    # (label, raw input). Every sample, including the email/URL/IP ones, goes through
    # validate_string.
    samples = [
        ("正常文本", "这是一个正常的文本输入"),
        ("SQL注入尝试", "'; DROP TABLE users; --"),
        ("XSS攻击", "<script>alert('XSS')</script>"),
        ("正常邮箱", "user@example.com"),
        ("正常URL", "https://www.example.com/path?param=value"),
        ("IP地址", "192.168.1.1"),
    ]

    def report(label, raw):
        """Validate one sample (max_length=100) and print the result block."""
        print(f"📋 测试:{label}")
        print(f"输入:{raw}")

        outcome = validator.validate_string(raw, max_length=100)
        verdict = "✅ 通过" if outcome.is_valid else "❌ 失败"
        print(f"结果:{verdict}")

        for heading, messages in (("错误", outcome.errors), ("警告", outcome.warnings)):
            if messages:
                print(f"{heading}:{', '.join(messages)}")

        print(f"风险评分:{outcome.risk_score}/100")
        print(f"清理后:{outcome.cleaned_value}")
        print("-" * 50)

    print("🧪 输入验证测试结果:\n")
    for label, raw in samples:
        report(label, raw)
小结
安全实现的核心防护措施:
🛡️ 多重防护体系
- 身份认证 - 多因素认证确保用户身份真实性
- 权限控制 - 基于角色的访问控制系统
- 输入验证 - 全面的输入清理和验证机制
- 异常监控 - 实时检测和响应安全威胁
- 会话管理 - 安全的会话生命周期管理
💡 安全要点
- 采用深度防御策略,多层安全控制
- 实施零信任架构,验证每个请求
- 建立完善的日志和监控体系
- 定期进行安全评估和渗透测试
- 保持安全组件的及时更新
🔒 安全哲学：安全是一个持续的过程，而不是一次性的任务。我们需要在便利性和安全性之间找到合适的平衡点。
下一节：测试策略完全指南——构建坚如磐石的测试体系