8.3 工具设计最佳实践 🔧
"好的工具应该像筷子一样 - 简单易用,但功能强大。"
想象一下,你手中有一把瑞士军刀和一把专业菜刀。当你需要切菜时,哪个更好用?答案显而易见 - 专业菜刀!MCP工具设计也是同样的道理:专注、高效、易用。
单一职责原则 🎯
为什么单一职责如此重要?
单一职责原则就像是"术业有专攻"的编程版本。一个工具只做一件事,但要把这件事做到极致。
python
# ❌ 反面教材:万能工具
class SuperTool:
    """Anti-pattern: a single tool that multiplexes many unrelated operations."""

    def execute(self, operation, **kwargs):
        """Route ``operation`` to its handler, reading arguments from ``kwargs``.

        Returns the handler's result, or a fixed "unsupported operation"
        string when the operation name is not recognised.
        """
        # Weather lookup
        if operation == "weather":
            return self.get_weather(kwargs.get('location'))
        # Calculator
        if operation == "calculator":
            return self.calculate(kwargs.get('expression'))
        # Translation
        if operation == "translate":
            return self.translate(kwargs.get('text'), kwargs.get('target_lang'))
        # File reading
        if operation == "file_read":
            return self.read_file(kwargs.get('path'))
        # ... plus some 20 further operations
        return "不支持的操作"
# 问题:
# 1. 代码混乱,难以维护
# 2. 测试困难
# 3. 单个功能出错影响所有功能
# 4. 参数混乱,容易出错
python
# ✅ 正确示例:专注的工具
class WeatherQueryTool:
    """Single-purpose tool that answers weather queries and nothing else.

    The tool validates its input, consults a short-lived cache, delegates the
    lookup to the injected weather service and formats the answer in the
    requested language.
    """

    # Temperature units accepted by the schema.
    _UNITS = ("celsius", "fahrenheit")
    # Output languages accepted by the schema.
    _LANGUAGES = ("zh", "en")

    def __init__(self, weather_service: "WeatherService"):
        self.weather_service = weather_service
        self.cache = WeatherCache(ttl=300)  # 5-minute cache
        self.logger = logging.getLogger(__name__)

    def get_name(self) -> str:
        """Return the tool's registered name."""
        return "weather_query"

    def get_description(self) -> str:
        """Return the human-readable tool description."""
        return "查询指定地区的天气信息,支持当前天气和未来7天预报"

    def get_schema(self) -> dict:
        """Return the JSON-schema style description of the accepted parameters."""
        return {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "城市名称或坐标,支持中英文",
                    "examples": ["北京", "New York", "40.7128,-74.0060"]
                },
                "forecast_days": {
                    "type": "integer",
                    "description": "预报天数(0为当前天气,1-7为未来天数)",
                    "minimum": 0,
                    "maximum": 7,
                    "default": 0
                },
                "units": {
                    "type": "string",
                    "description": "温度单位",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius"
                },
                "language": {
                    "type": "string",
                    "description": "返回结果的语言",
                    "enum": ["zh", "en"],
                    "default": "zh"
                }
            },
            "required": ["location"]
        }

    async def execute(self, parameters: dict) -> dict:
        """Run a weather query.

        Steps: validate parameters, check the cache, call the weather
        service, format the result, cache it.

        Raises:
            ToolExecutionError: for invalid parameters, unknown locations,
                weather-service failures and any unexpected error.
        """
        # Validation happens outside the broad handler below so that a bad
        # parameter is reported as such instead of "service unavailable".
        try:
            validated_params = self._validate_parameters(parameters)
        except ParameterError as e:
            raise ToolExecutionError(f"参数错误:{e}") from e

        try:
            cache_key = self._generate_cache_key(validated_params)
            cached_result = await self.cache.get(cache_key)
            if cached_result is not None:
                self.logger.info(f"缓存命中:{cache_key}")
                return cached_result

            weather_data = await self.weather_service.get_weather(
                location=validated_params['location'],
                forecast_days=validated_params['forecast_days'],
                units=validated_params['units']
            )
            formatted_result = self._format_weather_data(
                weather_data, validated_params['language']
            )
            await self.cache.set(cache_key, formatted_result)
            return formatted_result
        except LocationNotFoundError as e:
            raise ToolExecutionError(f"未找到位置:{e.location}") from e
        except WeatherServiceError as e:
            raise ToolExecutionError(f"天气服务错误:{str(e)}") from e
        except Exception as e:
            # Boundary handler: keep the traceback in the log, hide details
            # from the caller.
            self.logger.exception(f"天气查询意外错误:{str(e)}")
            raise ToolExecutionError("天气查询服务暂时不可用") from e

    def _generate_cache_key(self, params: dict) -> str:
        """Build the cache key from every parameter that affects the result."""
        return "weather:{}:{}:{}:{}".format(
            params['location'].casefold(),
            params['forecast_days'],
            params['units'],
            params['language'],
        )

    def _validate_parameters(self, params: dict) -> dict:
        """Validate raw parameters and return them with defaults applied.

        Raises:
            ParameterError: when a parameter is missing, of the wrong type or
                outside the range / enum declared in the schema.
        """
        location = params.get('location')
        if not isinstance(location, str) or not location.strip():
            raise ParameterError("location 参数不能为空")

        forecast_days = params.get('forecast_days', 0)
        # bool is a subclass of int, so it has to be rejected explicitly.
        if (isinstance(forecast_days, bool)
                or not isinstance(forecast_days, int)
                or not 0 <= forecast_days <= 7):
            raise ParameterError("forecast_days 必须是0-7之间的整数")

        units = params.get('units', 'celsius')
        if units not in self._UNITS:
            raise ParameterError("units 必须是 celsius 或 fahrenheit")

        language = params.get('language', 'zh')
        if language not in self._LANGUAGES:
            raise ParameterError("language 必须是 zh 或 en")

        return {
            'location': location.strip(),
            'forecast_days': forecast_days,
            'units': units,
            'language': language
        }

    def _format_weather_data(self, data: "WeatherData", language: str) -> dict:
        """Format service data as a dict with Chinese or English keys.

        The temperature suffix follows ``data.units`` in both languages. The
        forecast entry is ``None`` when ``data.forecasts`` is empty.
        """
        degree = f"°{'C' if data.units == 'celsius' else 'F'}"

        if language == 'zh':
            forecasts = [
                {
                    "日期": forecast.date.strftime("%Y-%m-%d"),
                    "最高温": f"{forecast.high}{degree}",
                    "最低温": f"{forecast.low}{degree}",
                    "天气": forecast.description_zh
                }
                for forecast in data.forecasts
            ] if data.forecasts else None
            return {
                "位置": data.location,
                "当前温度": f"{data.current.temperature}{degree}",
                "天气描述": data.current.description_zh,
                "湿度": f"{data.current.humidity}%",
                "风速": f"{data.current.wind_speed} km/h",
                "更新时间": data.last_updated.isoformat(),
                "预报": forecasts
            }

        forecasts = [
            {
                "date": forecast.date.strftime("%Y-%m-%d"),
                "high": f"{forecast.high}{degree}",
                "low": f"{forecast.low}{degree}",
                "description": forecast.description_en
            }
            for forecast in data.forecasts
        ] if data.forecasts else None
        return {
            "location": data.location,
            "current_temperature": f"{data.current.temperature}{degree}",
            "description": data.current.description_en,
            "humidity": f"{data.current.humidity}%",
            "wind_speed": f"{data.current.wind_speed} km/h",
            "last_updated": data.last_updated.isoformat(),
            "forecast": forecasts
        }
单一职责的好处
- 易于理解 📖 - 代码逻辑清晰,新人也能快速上手
- 易于测试 🧪 - 测试场景明确,覆盖率容易达到
- 易于维护 🔧 - 修改某个功能不会影响其他功能
- 易于复用 ♻️ - 其他地方可以直接使用
- 易于扩展 📈 - 需要新功能时独立开发
一致的错误处理 🚨
建立错误处理体系
好的错误处理就像是系统的"免疫系统",能够优雅地处理各种异常情况。
typescript
// 🌟 错误处理框架
// Closed set of error categories carried by ToolError.errorType.
// Every member's value is its own name, so the category survives
// serialization as a readable string.
enum ErrorType {
VALIDATION_ERROR = "VALIDATION_ERROR",
AUTHENTICATION_ERROR = "AUTHENTICATION_ERROR",
AUTHORIZATION_ERROR = "AUTHORIZATION_ERROR",
RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND",
EXTERNAL_SERVICE_ERROR = "EXTERNAL_SERVICE_ERROR",
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED",
INTERNAL_ERROR = "INTERNAL_ERROR"
}
/**
 * Base error for all tool failures.
 *
 * Carries a category, a message safe to show to users, technical details for
 * operators, an error code, a retry hint and free-form context.
 */
class ToolError extends Error {
  public readonly errorType: ErrorType;
  public readonly userMessage: string;
  public readonly technicalDetails: string;
  public readonly errorCode: string;
  public readonly retryable: boolean;
  public readonly context: Record<string, any>;
  // Moment the error was created, as an ISO-8601 string.
  public readonly timestamp: string;

  /**
   * @param errorType category of the failure
   * @param userMessage message exposed to the caller; also used as Error.message
   * @param technicalDetails diagnostic text intended for logs
   * @param options optional error code, retry hint, context and underlying cause
   */
  constructor(
    errorType: ErrorType,
    userMessage: string,
    technicalDetails: string,
    options: {
      errorCode?: string;
      retryable?: boolean;
      context?: Record<string, any>;
      cause?: Error;
    } = {}
  ) {
    super(userMessage);
    // When compiled to ES5, subclasses of Error lose their prototype chain,
    // which breaks `instanceof ToolError`; restoring it is harmless elsewhere.
    Object.setPrototypeOf(this, new.target.prototype);

    this.name = 'ToolError';
    this.errorType = errorType;
    this.userMessage = userMessage;
    this.technicalDetails = technicalDetails;
    // errorType must be assigned before the code is generated from it.
    this.errorCode = options.errorCode || this.generateErrorCode();
    this.retryable = options.retryable || false;
    this.context = options.context || {};
    // Captured once here so that repeated serialization reports when the
    // error happened, not when it was serialized.
    this.timestamp = new Date().toISOString();

    if (options.cause) {
      this.cause = options.cause;
    }
  }

  /** Builds a code of the form `<errorType>_<base-36 current time in ms>`. */
  private generateErrorCode(): string {
    return `${this.errorType}_${Date.now().toString(36).toUpperCase()}`;
  }

  /** Returns the plain-object form used for logging and transport. */
  toJSON() {
    return {
      errorType: this.errorType,
      userMessage: this.userMessage,
      technicalDetails: this.technicalDetails,
      errorCode: this.errorCode,
      retryable: this.retryable,
      context: this.context,
      timestamp: this.timestamp
    };
  }
}
// 具体的错误类型
/** Non-retryable ToolError raised when an input parameter is rejected. */
class ValidationError extends ToolError {
  /**
   * @param message user-facing description of the problem
   * @param field name of the offending field, when known
   * @param value rejected value, stored in the error context
   */
  constructor(message: string, field?: string, value?: any) {
    const fieldLabel = field ? `字段 ${field}` : '未知字段';
    super(ErrorType.VALIDATION_ERROR, message, `参数验证失败:${fieldLabel}`, {
      retryable: false,
      context: { field, value }
    });
  }
}
/** ToolError raised when a call to an external service fails. */
class ExternalServiceError extends ToolError {
  /**
   * @param serviceName name shown to the user in the message
   * @param originalError underlying failure, kept as the cause
   * @param retryable whether the caller may retry; defaults to true
   */
  constructor(serviceName: string, originalError: Error, retryable: boolean = true) {
    const reason = originalError.message;
    const userMessage = `${serviceName}服务暂时不可用,请稍后重试`;
    super(ErrorType.EXTERNAL_SERVICE_ERROR, userMessage, `外部服务调用失败:${reason}`, {
      context: { serviceName, originalError: reason },
      retryable,
      cause: originalError
    });
  }
}
// 实际工具中的错误处理
/**
 * Example tool showing consistent error handling around a database query.
 *
 * generateExecutionId() and executeQuery() are called below but are not part
 * of this excerpt.
 */
class DatabaseQueryTool {
  // Patterns rejected by performSecurityCheck. `[\s\S]*` is used instead of
  // `.*` so that a newline inside the statement cannot bypass the check.
  // NOTE(review): a pattern blacklist is only a safety net; it does not
  // replace a least-privilege database account.
  private static readonly DANGEROUS_PATTERNS: RegExp[] = [
    /drop\s+table/i,
    /delete\s+from[\s\S]*where\s+1\s*=\s*1/i,
    /truncate\s+table/i,
    /alter\s+table/i
  ];

  // Parameter names whose values must never reach the logs.
  private static readonly SENSITIVE_KEYS: string[] = ['password', 'apiKey', 'token'];

  private logger = new Logger('DatabaseQueryTool');

  /**
   * Validates the parameters, runs the security check and executes the query.
   *
   * @throws ToolError every failure is converted to a ToolError before being rethrown
   */
  async execute(parameters: any): Promise<any> {
    const startTime = Date.now();
    const executionId = this.generateExecutionId();

    try {
      this.logger.info('开始执行数据库查询', {
        executionId,
        parameters: this.sanitizeParametersForLogging(parameters)
      });

      const validatedParams = this.validateParameters(parameters);
      this.performSecurityCheck(validatedParams);
      const result = await this.executeQuery(validatedParams);

      const executionTime = Date.now() - startTime;
      this.logger.info('数据库查询执行成功', {
        executionId,
        executionTime,
        resultCount: Array.isArray(result) ? result.length : 1
      });
      return result;
    } catch (error) {
      const executionTime = Date.now() - startTime;
      const handledError = this.handleError(error, executionId, executionTime);
      this.logger.error('数据库查询执行失败', {
        executionId,
        executionTime,
        error: handledError.toJSON()
      });
      throw handledError;
    }
  }

  /**
   * Maps an arbitrary thrown value to a ToolError.
   *
   * JavaScript allows throwing anything (strings, null, plain objects), so
   * the value is normalised first; otherwise classifying it could itself
   * throw and hide the original failure.
   */
  private handleError(error: any, executionId: string, executionTime: number): ToolError {
    // Already one of ours: pass through untouched.
    if (error instanceof ToolError) {
      return error;
    }

    const source: any = error !== null && typeof error === 'object' ? error : {};
    const message: string = typeof source.message === 'string' ? source.message : String(error);

    if (source.name === 'ValidationError' || message.includes('validation')) {
      return new ValidationError(message || '参数验证失败', source.field, source.value);
    }

    if (source.code === 'ECONNREFUSED' || source.code === 'ETIMEDOUT') {
      // Network errors are usually worth retrying.
      return new ExternalServiceError('数据库', error, true);
    }

    if (source.code === 'ER_ACCESS_DENIED_ERROR') {
      return new ToolError(
        ErrorType.AUTHORIZATION_ERROR,
        '数据库访问权限不足',
        `数据库认证失败:${message}`,
        { retryable: false }
      );
    }

    if (source.code === 'ER_NO_SUCH_TABLE') {
      return new ToolError(
        ErrorType.RESOURCE_NOT_FOUND,
        '请求的数据表不存在',
        `表不存在:${message}`,
        { retryable: false }
      );
    }

    // Unknown failure: wrap as an internal error.
    return new ToolError(
      ErrorType.INTERNAL_ERROR,
      '系统内部错误,请联系管理员',
      `意外错误:${message}`,
      {
        context: { executionId, executionTime },
        retryable: false,
        cause: error instanceof Error ? error : undefined
      }
    );
  }

  /**
   * Checks that `params.query` is a non-blank string of at most 10000 characters.
   *
   * @returns the parameters unchanged
   * @throws ValidationError when the check fails, including when `params` itself is missing
   */
  private validateParameters(params: any): any {
    const query = params ? params.query : undefined;

    if (typeof query !== 'string' || query.trim().length === 0) {
      throw new ValidationError(
        'query参数是必需的,且必须是字符串类型',
        'query',
        query
      );
    }

    if (query.length > 10000) {
      throw new ValidationError(
        'SQL查询长度不能超过10000个字符',
        'query',
        `长度:${query.length}`
      );
    }

    return params;
  }

  /**
   * Rejects queries matching one of the dangerous SQL patterns.
   *
   * @throws ToolError with AUTHORIZATION_ERROR when a pattern matches
   */
  private performSecurityCheck(params: any): void {
    // The patterns are case-insensitive, so the query is tested as-is.
    const query: string = params.query;

    for (const pattern of DatabaseQueryTool.DANGEROUS_PATTERNS) {
      if (pattern.test(query)) {
        throw new ToolError(
          ErrorType.AUTHORIZATION_ERROR,
          '查询包含危险操作,已被阻止',
          `检测到危险SQL模式:${pattern}`,
          { retryable: false }
        );
      }
    }
  }

  /** Returns a shallow copy of the parameters with sensitive values masked. */
  private sanitizeParametersForLogging(params: any): any {
    const sanitized = { ...params };
    for (const key of DatabaseQueryTool.SENSITIVE_KEYS) {
      if (sanitized[key]) {
        sanitized[key] = '***';
      }
    }
    return sanitized;
  }
}
参数验证最佳实践 ✅
全面的参数验证体系
参数验证就像是系统的"质检员",确保进入系统的每个参数都符合要求。
java
// 🌟 Java参数验证框架
/**
 * Fluent parameter validation: rules are declared per field through
 * {@link ValidatorBuilder} and applied to a parameter map by {@link #validate}.
 */
public class ParameterValidator {
    private static final Logger logger = LoggerFactory.getLogger(ParameterValidator.class);

    // Directories that file operations are allowed to touch.
    private static final List<String> ALLOWED_DIRECTORIES = Arrays.asList(
        "/tmp/mcp/",
        "/var/mcp/data/",
        "C:\\temp\\mcp\\",
        "C:\\data\\mcp\\"
    );

    // Rules per field name, in declaration order.
    private final Map<String, List<ValidationRule<?>>> rules;

    private ParameterValidator(Map<String, List<ValidationRule<?>>> rules) {
        this.rules = rules;
    }

    /** A single check on one field. */
    public static class ValidationRule<T> {
        private final String fieldName;
        private final Predicate<T> validator;
        private final String errorMessage;
        private final boolean required;

        public ValidationRule(String fieldName, Predicate<T> validator,
                              String errorMessage, boolean required) {
            this.fieldName = fieldName;
            this.validator = validator;
            this.errorMessage = errorMessage;
            this.required = required;
        }

        public String getFieldName() {
            return fieldName;
        }

        public Predicate<T> getValidator() {
            return validator;
        }

        public String getErrorMessage() {
            return errorMessage;
        }

        public boolean isRequired() {
            return required;
        }
    }

    /** Outcome of a validation run. */
    public static class ValidationResult {
        private final boolean isValid;
        private final List<String> errors;
        private final Map<String, Object> sanitizedValues;

        public ValidationResult(boolean isValid, List<String> errors,
                                Map<String, Object> sanitizedValues) {
            this.isValid = isValid;
            this.errors = errors;
            this.sanitizedValues = sanitizedValues;
        }

        public boolean isValid() {
            return isValid;
        }

        public List<String> getErrors() {
            return errors;
        }

        /** Values of the declared fields that were present and passed every rule. */
        public Map<String, Object> getSanitizedValues() {
            return sanitizedValues;
        }
    }

    /** Fluent builder; call {@link #field} first, then chain rules for that field. */
    public static class ValidatorBuilder {
        // LinkedHashMap keeps fields in declaration order so error lists are stable.
        private final Map<String, List<ValidationRule<?>>> rules = new java.util.LinkedHashMap<>();
        // Field that subsequent rule methods apply to.
        private String currentField;

        public ValidatorBuilder field(String fieldName) {
            this.currentField = fieldName;
            return this;
        }

        /** Returns the rule list of the current field, creating it on first use. */
        private List<ValidationRule<?>> getCurrentRules() {
            if (currentField == null) {
                throw new IllegalStateException("field() must be called before adding rules");
            }
            return rules.computeIfAbsent(currentField, key -> new java.util.ArrayList<>());
        }

        public ValidatorBuilder required() {
            getCurrentRules().add(new ValidationRule<>(
                currentField,
                Objects::nonNull,
                "字段 " + currentField + " 是必需的",
                true
            ));
            return this;
        }

        public ValidatorBuilder string() {
            getCurrentRules().add(new ValidationRule<>(
                currentField,
                value -> value instanceof String,
                "字段 " + currentField + " 必须是字符串类型",
                false
            ));
            return this;
        }

        public ValidatorBuilder minLength(int minLength) {
            getCurrentRules().add(new ValidationRule<String>(
                currentField,
                value -> value.length() >= minLength,
                String.format("字段 %s 长度不能少于 %d 个字符", currentField, minLength),
                false
            ));
            return this;
        }

        public ValidatorBuilder maxLength(int maxLength) {
            getCurrentRules().add(new ValidationRule<String>(
                currentField,
                value -> value.length() <= maxLength,
                String.format("字段 %s 长度不能超过 %d 个字符", currentField, maxLength),
                false
            ));
            return this;
        }

        public ValidatorBuilder pattern(String regex, String message) {
            Pattern pattern = Pattern.compile(regex);
            getCurrentRules().add(new ValidationRule<String>(
                currentField,
                value -> pattern.matcher(value).matches(),
                message != null ? message :
                    String.format("字段 %s 格式不正确", currentField),
                false
            ));
            return this;
        }

        public ValidatorBuilder email() {
            return pattern(
                "^[A-Za-z0-9+_.-]+@([A-Za-z0-9.-]+\\.[A-Za-z]{2,})$",
                "请输入有效的邮箱地址"
            );
        }

        public ValidatorBuilder integer() {
            getCurrentRules().add(new ValidationRule<>(
                currentField,
                value -> value instanceof Integer ||
                    (value instanceof String && value.toString().matches("-?\\d+")),
                "字段 " + currentField + " 必须是整数",
                false
            ));
            return this;
        }

        /**
         * Accepts the same representations as {@link #integer()} (Integer or a
         * numeric String), so the two rules can be chained on one field.
         */
        public ValidatorBuilder range(int min, int max) {
            getCurrentRules().add(new ValidationRule<Object>(
                currentField,
                value -> {
                    int number;
                    if (value instanceof Integer) {
                        number = (Integer) value;
                    } else if (value instanceof String) {
                        try {
                            number = Integer.parseInt((String) value);
                        } catch (NumberFormatException e) {
                            return false;
                        }
                    } else {
                        return false;
                    }
                    return number >= min && number <= max;
                },
                String.format("字段 %s 必须在 %d 到 %d 之间", currentField, min, max),
                false
            ));
            return this;
        }

        public ValidatorBuilder custom(Predicate<Object> validator, String errorMessage) {
            getCurrentRules().add(new ValidationRule<>(
                currentField,
                validator,
                errorMessage,
                false
            ));
            return this;
        }

        public ParameterValidator build() {
            return new ParameterValidator(this.rules);
        }
    }

    /**
     * Applies every declared rule to the given parameters.
     *
     * A missing (null) value fails only rules marked as required; other rules
     * are skipped for it. Checking of a field stops at its first failed rule,
     * and a value of the wrong type for a rule counts as a failure of that rule.
     *
     * @param parameters raw parameters; null is treated as an empty map
     * @return the result with collected error messages and accepted values
     */
    @SuppressWarnings("unchecked")
    public ValidationResult validate(Map<String, Object> parameters) {
        Map<String, Object> input = parameters != null ? parameters : new HashMap<>();
        List<String> errors = new java.util.ArrayList<>();
        Map<String, Object> sanitized = new HashMap<>();

        for (Map.Entry<String, List<ValidationRule<?>>> entry : rules.entrySet()) {
            String field = entry.getKey();
            Object value = input.get(field);
            boolean fieldValid = true;

            for (ValidationRule<?> rule : entry.getValue()) {
                if (value == null) {
                    if (rule.isRequired()) {
                        errors.add(rule.getErrorMessage());
                        fieldValid = false;
                        break;
                    }
                    continue;
                }

                boolean passed;
                try {
                    passed = ((Predicate<Object>) rule.getValidator()).test(value);
                } catch (ClassCastException e) {
                    // Typed rule (e.g. minLength) applied to a value of another type.
                    passed = false;
                }
                if (!passed) {
                    errors.add(rule.getErrorMessage());
                    fieldValid = false;
                    break;
                }
            }

            if (fieldValid && value != null) {
                sanitized.put(field, value);
            }
        }

        if (!errors.isEmpty()) {
            logger.debug("参数验证失败:{}", errors);
        }
        return new ValidationResult(errors.isEmpty(), errors, sanitized);
    }

    /** Example: validator for a file-operation tool. */
    public static ParameterValidator createFileOperationValidator() {
        return new ValidatorBuilder()
            .field("operation")
                .required()
                .string()
                .custom(value -> Arrays.asList("read", "write", "delete", "list").contains(value),
                        "operation 必须是 read, write, delete, list 中的一个")
            .field("path")
                .required()
                .string()
                .minLength(1)
                .maxLength(500)
                .custom(ParameterValidator::isValidFilePath, "文件路径包含非法字符")
                .custom(ParameterValidator::isPathWithinAllowedDirectories, "文件路径超出允许范围")
            .field("content")
                .string()
                .maxLength(1024 * 1024) // 1 MB limit
                .custom(value -> {
                    // Content is only needed for write operations; a single-field
                    // rule cannot see `operation`, so that check belongs to the caller.
                    return true;
                }, "写操作需要提供content参数")
            .field("encoding")
                .string()
                .custom(value -> value instanceof String
                            && Arrays.asList("utf-8", "gbk", "ascii")
                                .contains(((String) value).toLowerCase(java.util.Locale.ROOT)),
                        "encoding 必须是 utf-8, gbk, ascii 中的一个")
            .build();
    }

    /** Rejects paths containing traversal sequences or shell/wildcard characters. */
    private static boolean isValidFilePath(Object path) {
        if (!(path instanceof String)) return false;
        String pathStr = (String) path;

        String[] dangerousPatterns = {"../", "..\\", "<", ">", "|", "?", "*"};
        for (String pattern : dangerousPatterns) {
            if (pathStr.contains(pattern)) return false;
        }
        return true;
    }

    /**
     * Checks that the normalized absolute path lies inside an allowed directory.
     *
     * The comparison is done on path components rather than on raw strings.
     * NOTE(review): symbolic links are not resolved here.
     */
    private static boolean isPathWithinAllowedDirectories(Object path) {
        if (!(path instanceof String)) return false;
        String pathStr = (String) path;

        try {
            Path normalizedPath = Paths.get(pathStr).toAbsolutePath().normalize();
            return ALLOWED_DIRECTORIES.stream()
                .map(allowedDir -> Paths.get(allowedDir).normalize())
                .anyMatch(normalizedPath::startsWith);
        } catch (Exception e) {
            // Unparseable paths are treated as outside the allowed area.
            return false;
        }
    }
}
// 使用示例
public class FileOperationTool implements Tool {
private final ParameterValidator validator;
public FileOperationTool() {
this.validator = ParameterValidator.createFileOperationValidator();
}
@Override
public ToolResponse execute(ToolRequest request) {
// 验证参数
ValidationResult validation = validator.validate(request.getParameters());
if (!validation.isValid()) {
throw new ToolExecutionException(
"参数验证失败:" + String.join(", ", validation.getErrors())
);
}
// 使用验证和清理后的参数
Map<String, Object> safeParams = validation.getSanitizedValues();
String operation = (String) safeParams.get("operation");
String path = (String) safeParams.get("path");
// 继续执行工具逻辑...
switch (operation) {
case "read":
return executeRead(path);
case "write":
String content = (String) safeParams.get("content");
return executeWrite(path, content);
// ... 其他操作
}
}
}
依赖注入与可测试性 🧪
让工具变得可测试
可测试的代码就像是透明的玻璃房子,里面的每个部分都能清楚地看到和验证。
csharp
// 🌟 可测试的工具设计
/// <summary>
/// Source of weather data. Defined as an interface so the tool can receive
/// a substitute implementation in tests.
/// </summary>
public interface IWeatherService
{
/// <summary>Asynchronously fetches weather data for a location.</summary>
/// <param name="location">Location to look up.</param>
/// <param name="days">Number of days requested; defaults to 1.</param>
Task<WeatherData> GetWeatherAsync(string location, int days = 1);
}
/// <summary>Asynchronous key/value cache used by the tool.</summary>
public interface ICacheService
{
/// <summary>Reads the value stored under <paramref name="key"/>.</summary>
/// <remarks>The tool treats a null result as a cache miss.</remarks>
Task<T> GetAsync<T>(string key);
/// <summary>Stores <paramref name="value"/> under <paramref name="key"/> with the given expiry.</summary>
Task SetAsync<T>(string key, T value, TimeSpan expiry);
}
/// <summary>Minimal logging abstraction, parameterized by the class that logs.</summary>
public interface ILogger<T>
{
/// <summary>Writes an informational message; <paramref name="args"/> supply the values for the message's placeholders.</summary>
void LogInformation(string message, params object[] args);
/// <summary>Writes an error message together with the exception that caused it.</summary>
void LogError(Exception exception, string message, params object[] args);
}
// 主要工具类 - 注意所有依赖都是通过接口注入的
/// <summary>
/// Weather forecast tool whose collaborators (weather service, cache, logger)
/// are all injected through the constructor as interfaces.
/// </summary>
public class WeatherForecastTool : ITool
{
    private readonly IWeatherService _weatherService;
    private readonly ICacheService _cacheService;
    private readonly ILogger<WeatherForecastTool> _logger;

    /// <summary>Creates the tool with its dependencies.</summary>
    /// <exception cref="ArgumentNullException">Any dependency is null.</exception>
    public WeatherForecastTool(
        IWeatherService weatherService,
        ICacheService cacheService,
        ILogger<WeatherForecastTool> logger)
    {
        _weatherService = weatherService ?? throw new ArgumentNullException(nameof(weatherService));
        _cacheService = cacheService ?? throw new ArgumentNullException(nameof(cacheService));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    public string Name => "weather_forecast";

    public string Description => "获取天气预报信息";

    /// <summary>
    /// Returns the forecast for the requested location, served from the cache
    /// when available and otherwise fetched and cached for 30 minutes.
    /// </summary>
    /// <exception cref="ToolExecutionException">
    /// The parameters are invalid, the location is unknown, or the lookup fails.
    /// </exception>
    public async Task<ToolResponse> ExecuteAsync(ToolRequest request)
    {
        // Validate up front so a missing parameter is reported as a tool error
        // instead of escaping as KeyNotFoundException / NullReferenceException.
        if (request?.Parameters == null
            || !request.Parameters.TryGetValue("location", out var rawLocation)
            || string.IsNullOrWhiteSpace(rawLocation?.ToString()))
        {
            throw new ToolExecutionException("location 参数不能为空");
        }
        var location = rawLocation.ToString().Trim();

        var days = 1;
        if (request.Parameters.TryGetValue("days", out var rawDays) && rawDays != null)
        {
            // Parse via the string form: this also works for values that do not
            // implement IConvertible, where Convert.ToInt32 would throw.
            var daysText = Convert.ToString(rawDays, System.Globalization.CultureInfo.InvariantCulture);
            if (!int.TryParse(
                    daysText,
                    System.Globalization.NumberStyles.Integer,
                    System.Globalization.CultureInfo.InvariantCulture,
                    out days))
            {
                throw new ToolExecutionException("days 参数必须是整数");
            }
        }

        _logger.LogInformation("开始获取天气预报:{Location}, {Days}天", location, days);

        try
        {
            var cacheKey = $"weather:{location}:{days}";
            var cachedResult = await _cacheService.GetAsync<WeatherData>(cacheKey);
            if (cachedResult != null)
            {
                _logger.LogInformation("使用缓存的天气数据:{Location}", location);
                return CreateResponse(cachedResult);
            }

            var weatherData = await _weatherService.GetWeatherAsync(location, days);
            await _cacheService.SetAsync(cacheKey, weatherData, TimeSpan.FromMinutes(30));

            _logger.LogInformation("成功获取天气预报:{Location}", location);
            return CreateResponse(weatherData);
        }
        catch (LocationNotFoundException ex)
        {
            _logger.LogError(ex, "位置未找到:{Location}", location);
            throw new ToolExecutionException($"未找到位置:{location}");
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "获取天气预报时发生错误:{Location}", location);
            throw new ToolExecutionException("天气服务暂时不可用");
        }
    }

    /// <summary>Wraps the weather data, serialized as JSON, in a single text content item.</summary>
    private ToolResponse CreateResponse(WeatherData data)
    {
        return new ToolResponse
        {
            Content = new List<ContentItem>
            {
                new TextContent(JsonSerializer.Serialize(data))
            }
        };
    }
}
// 🧪 单元测试 - 现在变得非常简单
/// <summary>Unit tests for WeatherForecastTool using mocked dependencies.</summary>
[TestClass]
public class WeatherForecastToolTests
{
    private Mock<IWeatherService> _mockWeatherService;
    private Mock<ICacheService> _mockCacheService;
    private Mock<ILogger<WeatherForecastTool>> _mockLogger;
    private WeatherForecastTool _tool;

    [TestInitialize]
    public void Setup()
    {
        _mockWeatherService = new Mock<IWeatherService>();
        _mockCacheService = new Mock<ICacheService>();
        _mockLogger = new Mock<ILogger<WeatherForecastTool>>();
        _tool = new WeatherForecastTool(
            _mockWeatherService.Object,
            _mockCacheService.Object,
            _mockLogger.Object);
    }

    // Makes every cache read return the given entry; null simulates a miss.
    private void GivenCacheReturns(WeatherData entry)
    {
        _mockCacheService
            .Setup(c => c.GetAsync<WeatherData>(It.IsAny<string>()))
            .ReturnsAsync(entry);
    }

    // Builds a request that carries only the location parameter.
    private static ToolRequest RequestFor(string location)
    {
        return new ToolRequest
        {
            Parameters = new Dictionary<string, object>
            {
                ["location"] = location
            }
        };
    }

    // Deserializes the first content item of the response.
    private static WeatherData ReadPayload(ToolResponse response)
    {
        return JsonSerializer.Deserialize<WeatherData>(
            response.Content[0].ToString());
    }

    [TestMethod]
    public async Task ExecuteAsync_ValidLocation_ReturnsWeatherData()
    {
        // Arrange
        var location = "北京";
        var expectedWeatherData = new WeatherData
        {
            Location = location,
            Temperature = 25,
            Description = "晴朗"
        };
        GivenCacheReturns(null); // cache miss
        _mockWeatherService
            .Setup(w => w.GetWeatherAsync(location, 1))
            .ReturnsAsync(expectedWeatherData);
        var request = RequestFor(location);

        // Act
        var response = await _tool.ExecuteAsync(request);

        // Assert
        Assert.IsNotNull(response);
        Assert.AreEqual(1, response.Content.Count);
        var responseData = ReadPayload(response);
        Assert.AreEqual(location, responseData.Location);
        Assert.AreEqual(25, responseData.Temperature);

        // The service was queried once and its result was cached once.
        _mockWeatherService.Verify(
            w => w.GetWeatherAsync(location, 1),
            Times.Once);
        _mockCacheService.Verify(
            c => c.SetAsync(It.IsAny<string>(), expectedWeatherData, It.IsAny<TimeSpan>()),
            Times.Once);
    }

    [TestMethod]
    public async Task ExecuteAsync_CacheHit_ReturnsCachedData()
    {
        // Arrange
        var location = "上海";
        var cachedWeatherData = new WeatherData
        {
            Location = location,
            Temperature = 20,
            Description = "多云"
        };
        GivenCacheReturns(cachedWeatherData); // cache hit
        var request = RequestFor(location);

        // Act
        var response = await _tool.ExecuteAsync(request);

        // Assert
        Assert.IsNotNull(response);
        var responseData = ReadPayload(response);
        Assert.AreEqual(location, responseData.Location);
        Assert.AreEqual(20, responseData.Temperature);

        // A cache hit must not reach the weather service.
        _mockWeatherService.Verify(
            w => w.GetWeatherAsync(It.IsAny<string>(), It.IsAny<int>()),
            Times.Never);
    }

    [TestMethod]
    public async Task ExecuteAsync_ServiceThrowsException_ThrowsToolExecutionException()
    {
        // Arrange
        var location = "不存在的城市";
        GivenCacheReturns(null);
        _mockWeatherService
            .Setup(w => w.GetWeatherAsync(location, It.IsAny<int>()))
            .ThrowsAsync(new LocationNotFoundException(location));
        var request = RequestFor(location);

        // Act & Assert
        var exception = await Assert.ThrowsExceptionAsync<ToolExecutionException>(
            () => _tool.ExecuteAsync(request));
        Assert.IsTrue(exception.Message.Contains(location));

        // The failure was logged exactly once.
        _mockLogger.Verify(
            l => l.LogError(
                It.IsAny<Exception>(),
                It.Is<string>(s => s.Contains("位置未找到")),
                location),
            Times.Once);
    }
}
可组合工具设计 🧩
设计可以组合的工具
可组合的工具就像乐高积木,可以灵活组合成更复杂的功能。
python
# 🌟 可组合工具设计示例
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import asyncio
class ComposableTool(ABC):
    """Abstract base for tools that can be chained into workflows.

    Subclasses must implement ``get_name``, ``get_description`` and
    ``execute``; ``get_output_schema`` has a generic default.
    """

    @abstractmethod
    def get_name(self) -> str:
        """Return the name under which the tool is registered."""

    @abstractmethod
    def get_description(self) -> str:
        """Return a human-readable description of the tool."""

    @abstractmethod
    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Run the tool with ``parameters`` and return its result dict."""

    def get_output_schema(self) -> Dict[str, Any]:
        """Describe the shape of the tool's output so other tools can consume it."""
        schema: Dict[str, Any] = {"type": "object"}
        return schema
# 基础工具:数据获取
class DataFetchTool(ComposableTool):
    """Building block that fetches records from a data source (simulated here)."""

    def get_name(self) -> str:
        return "data_fetch"

    def get_description(self) -> str:
        return "从指定数据源获取数据"

    def get_output_schema(self) -> Dict[str, Any]:
        """Output carries the records, query metadata and a fetch timestamp."""
        properties = {
            "data": {"type": "array"},
            "metadata": {"type": "object"},
            "fetch_time": {"type": "string"}
        }
        return {"type": "object", "properties": properties}

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Return a fixed sample data set, echoing ``source`` and ``query`` in the metadata."""
        query = parameters.get('query', {})
        source = parameters.get('source')

        # Stand-in for the latency of a real fetch.
        await asyncio.sleep(0.1)

        records = [
            {"id": 1, "name": "产品A", "price": 100, "category": "电子产品"},
            {"id": 2, "name": "产品B", "price": 200, "category": "服装"},
            {"id": 3, "name": "产品C", "price": 150, "category": "电子产品"}
        ]
        metadata = {
            "source": source,
            "total_count": 3,
            "query": query
        }
        return {
            "data": records,
            "metadata": metadata,
            "fetch_time": "2024-03-15T10:30:00Z"
        }
# 基础工具:数据过滤
class DataFilterTool(ComposableTool):
    """Building block that keeps only the records satisfying every condition."""

    def get_name(self) -> str:
        return "data_filter"

    def get_description(self) -> str:
        return "根据条件过滤数据"

    def get_output_schema(self) -> Dict[str, Any]:
        """Output carries the kept records, the conditions and both counts."""
        return {
            "type": "object",
            "properties": {
                "filtered_data": {"type": "array"},
                "filter_conditions": {"type": "object"},
                "original_count": {"type": "integer"},
                "filtered_count": {"type": "integer"}
            }
        }

    @staticmethod
    def _matches(item: Dict[str, Any], conditions: Dict[str, Any]) -> bool:
        """Return True when ``item`` satisfies every condition.

        A dict condition is a range with optional ``min`` / ``max`` bounds;
        any other value is compared for equality. An item that lacks a
        conditioned key does not match: it cannot satisfy a condition on a
        value it does not have.
        """
        for key, expected in conditions.items():
            if key not in item:
                return False
            actual = item[key]
            if isinstance(expected, dict):
                try:
                    if 'min' in expected and actual < expected['min']:
                        return False
                    if 'max' in expected and actual > expected['max']:
                        return False
                except TypeError:
                    # Value not comparable with the bound: outside the range.
                    return False
            elif actual != expected:
                return False
        return True

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Filter ``parameters['data']`` by ``parameters['conditions']``."""
        data = parameters.get('data', [])
        conditions = parameters.get('conditions', {})

        filtered_data = [item for item in data if self._matches(item, conditions)]

        return {
            "filtered_data": filtered_data,
            "filter_conditions": conditions,
            "original_count": len(data),
            "filtered_count": len(filtered_data)
        }
# 基础工具:数据分析
class DataAnalysisTool(ComposableTool):
    """Building block that summarises records into statistics and insights."""

    def get_name(self) -> str:
        return "data_analysis"

    def get_description(self) -> str:
        return "分析数据并生成统计信息"

    def get_output_schema(self) -> Dict[str, Any]:
        """Output carries statistics, textual insights and the analysis type."""
        properties = {
            "statistics": {"type": "object"},
            "insights": {"type": "array"},
            "analysis_type": {"type": "string"}
        }
        return {"type": "object", "properties": properties}

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Compute price statistics and the category distribution of ``data``."""
        records = parameters.get('data', [])
        analysis_type = parameters.get('analysis_type', 'basic')

        # Nothing to analyse: report that instead of empty statistics.
        if not records:
            return {
                "statistics": {},
                "insights": ["数据为空,无法进行分析"],
                "analysis_type": analysis_type
            }

        # Price statistics cover only the records that have a 'price' key.
        prices = [record['price'] for record in records if 'price' in record]
        price_stats = {}
        if prices:
            total_value = sum(prices)
            price_stats = {
                "avg_price": total_value / len(prices),
                "min_price": min(prices),
                "max_price": max(prices),
                "total_value": total_value
            }

        # Count records per category; records without one go under 'Unknown'.
        categories = {}
        for record in records:
            label = record.get('category', 'Unknown')
            if label in categories:
                categories[label] += 1
            else:
                categories[label] = 1

        insights = []
        if price_stats:
            insights += [
                f"平均价格为 {price_stats['avg_price']:.2f}",
                f"价格范围:{price_stats['min_price']} - {price_stats['max_price']}",
            ]
        if categories:
            most_common_category = max(categories, key=categories.get)
            insights.append(f"最常见的类别是:{most_common_category}({categories[most_common_category]}个)")

        statistics = {
            "total_count": len(records),
            "price_stats": price_stats,
            "category_distribution": categories
        }
        return {
            "statistics": statistics,
            "insights": insights,
            "analysis_type": analysis_type
        }
# 工具组合器
class ToolComposer:
    """Combines registered tools into named, sequential workflows."""

    def __init__(self):
        self.tools = {}
        self.workflows = {}

    def register_tool(self, tool: "ComposableTool"):
        """Register a tool under the name it reports via ``get_name()``."""
        self.tools[tool.get_name()] = tool

    def register_workflow(self, name: str, workflow: List[Dict[str, Any]]):
        """
        Register a workflow: an ordered list of step dicts.

        Example:
        [
            {
                "tool": "data_fetch",
                "parameters": {"source": "products"},
                "output_name": "raw_data"
            },
            {
                "tool": "data_filter",
                "parameters": {
                    "data": "$raw_data.data",  # reference to an earlier step's output
                    "conditions": {"category": "电子产品"}
                },
                "output_name": "filtered_data"
            },
            {
                "tool": "data_analysis",
                "parameters": {
                    "data": "$filtered_data.filtered_data",
                    "analysis_type": "detailed"
                },
                "output_name": "analysis_result"
            }
        ]
        """
        self.workflows[name] = workflow

    async def execute_workflow(self, workflow_name: str,
                               initial_parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Run every step of a workflow in order.

        Args:
            workflow_name: Name given to ``register_workflow``.
            initial_parameters: Values that step parameters may reference
                with ``$name`` in addition to earlier step outputs.

        Returns:
            Mapping of each step's output name (``step_<index>`` when the
            step declares none) to its result.

        Raises:
            ValueError: unknown workflow, unknown tool, or unresolvable
                parameter reference.
        """
        if workflow_name not in self.workflows:
            raise ValueError(f"工作流 '{workflow_name}' 未找到")

        workflow = self.workflows[workflow_name]
        execution_context = initial_parameters or {}
        step_results = {}

        # Fail fast: reject an unknown tool before any step has run, so a
        # broken workflow never executes only its first few steps.
        for step in workflow:
            if step['tool'] not in self.tools:
                raise ValueError(f"工具 '{step['tool']}' 未找到")

        for step_index, step in enumerate(workflow):
            tool_name = step['tool']
            # A step without parameters is valid and gets an empty dict.
            step_parameters = self._resolve_parameter_references(
                step.get('parameters', {}), step_results, execution_context
            )

            try:
                result = await self.tools[tool_name].execute(step_parameters)
            except Exception as e:
                print(f"❌ 步骤 {step_index + 1}: {tool_name} 执行失败: {str(e)}")
                raise

            output_name = step.get('output_name', f'step_{step_index}')
            step_results[output_name] = result
            print(f"✅ 步骤 {step_index + 1}: {tool_name} 执行成功")

        return step_results

    def _resolve_parameter_references(self, parameters: Dict[str, Any],
                                      step_results: Dict[str, Any],
                                      context: Dict[str, Any]) -> Dict[str, Any]:
        """Replace top-level ``$name`` / ``$name.field.path`` strings by their values.

        ``name`` is looked up in the step results first, then in the initial
        context. All other values are passed through unchanged.

        Raises:
            ValueError: when ``name`` or the field path cannot be resolved.
        """
        resolved = {}

        for key, value in parameters.items():
            if not (isinstance(value, str) and value.startswith('$')):
                resolved[key] = value
                continue

            # "$output_name.field.path" -> ("output_name", ".", "field.path")
            root_name, separator, field_path = value[1:].partition('.')

            if root_name in step_results:
                root = step_results[root_name]
            elif root_name in context:
                root = context[root_name]
            else:
                raise ValueError(f"引用 '{root_name}' 未找到")

            resolved[key] = self._get_nested_value(root, field_path) if separator else root

        return resolved

    def _get_nested_value(self, data: Any, field_path: str) -> Any:
        """Follow a dot-separated path through nested dicts.

        Raises:
            ValueError: when a path segment is missing or the current value
                is not a dict.
        """
        current = data
        for field in field_path.split('.'):
            if isinstance(current, dict) and field in current:
                current = current[field]
            else:
                raise ValueError(f"字段路径 '{field_path}' 无效")
        return current
# 使用示例
async def main():
    """Demo: register the three basic tools, chain them and print the analysis."""
    composer = ToolComposer()
    for tool in (DataFetchTool(), DataFilterTool(), DataAnalysisTool()):
        composer.register_tool(tool)

    # Product-analysis workflow: fetch -> keep electronics -> analyse.
    fetch_step = {
        "tool": "data_fetch",
        "parameters": {"source": "products_db"},
        "output_name": "raw_data"
    }
    filter_step = {
        "tool": "data_filter",
        "parameters": {
            "data": "$raw_data.data",
            "conditions": {"category": "电子产品"}
        },
        "output_name": "electronics_data"
    }
    analysis_step = {
        "tool": "data_analysis",
        "parameters": {
            "data": "$electronics_data.filtered_data",
            "analysis_type": "detailed"
        },
        "output_name": "electronics_analysis"
    }
    composer.register_workflow(
        "product_analysis", [fetch_step, filter_step, analysis_step]
    )

    print("🚀 开始执行产品分析工作流...")
    results = await composer.execute_workflow("product_analysis")

    print("\n📊 分析结果:")
    final_analysis = results['electronics_analysis']
    print(f"统计信息:{final_analysis['statistics']}")
    print(f"洞察:{final_analysis['insights']}")


if __name__ == "__main__":
    asyncio.run(main())
小结
工具设计的最佳实践可以总结为几个关键点:
🎯 核心原则
- 单一职责 - 一个工具只做一件事,但要做好
- 依赖注入 - 通过接口注入依赖,提高可测试性
- 参数验证 - 全面验证输入,防止错误和攻击
- 错误处理 - 统一、清晰、用户友好的错误处理
- 可组合性 - 设计可以组合的工具,支持复杂工作流
🛠️ 实践要点
- 使用清晰的接口设计
- 编写充分的单元测试
- 提供详细的文档和示例
- 考虑性能和安全性
- 支持可观测性(日志、监控)
🎪 记住:好的工具设计不是一蹴而就的,需要在实践中不断迭代和改进。从简单开始,逐步完善!
下一节:Schema设计艺术 - 学习如何设计清晰易用的参数规范