Skip to content

8.3 工具设计最佳实践 🔧

“好的工具应该像筷子一样——简单易用，却功能强大。”

想象一下，你手中有一把瑞士军刀和一把专业菜刀。当你需要切菜时，哪个更好用？答案显而易见——专业菜刀！MCP 工具设计也是同样的道理：专注、高效、易用。

单一职责原则 🎯

为什么单一职责如此重要?

单一职责原则就像是"术业有专攻"的编程版本。一个工具只做一件事,但要把这件事做到极致。

python
# ❌ 反面教材:万能工具
class SuperTool:
    """Anti-pattern example: a single "do-everything" tool.

    One ``execute`` entry point multiplexes unrelated features (weather,
    calculator, translation, file reading, ...) on an ``operation`` string.
    It is shown only to illustrate what the single-responsibility principle
    warns against.
    """

    def execute(self, operation, **kwargs):
        """Route *operation* to its handler method.

        The handler methods (``get_weather``, ``calculate``, ``translate``,
        ``read_file``) are not defined in this snippet. Any other operation
        yields the string "不支持的操作" ("unsupported operation").
        """
        # Weather lookup.
        if operation == "weather":
            return self.get_weather(kwargs.get('location'))
        # Calculator.
        if operation == "calculator":
            return self.calculate(kwargs.get('expression'))
        # Translation.
        if operation == "translate":
            return self.translate(kwargs.get('text'), kwargs.get('target_lang'))
        # File reading.
        if operation == "file_read":
            return self.read_file(kwargs.get('path'))
        # ... about 20 more operations would be chained here.
        return "不支持的操作"

    # Problems with this design:
    # 1. The code becomes tangled and hard to maintain.
    # 2. It is hard to test.
    # 3. A bug in one feature can break every feature.
    # 4. Parameters are a grab-bag, which invites mistakes.
python
# ✅ 正确示例:专注的工具
class WeatherQueryTool:
    """Weather-lookup tool: it does one thing (weather queries) and does it well.

    Collaborators such as ``WeatherService``, ``WeatherCache``, ``WeatherData``
    and the error types (``ParameterError``, ``ToolExecutionError``, ...) are
    defined elsewhere in the project. Type hints that name them are written as
    strings so this class can be defined without importing them first.
    """

    # Allowed values, kept in sync with the "enum"/"maximum" entries that
    # get_schema() advertises, so validation and the published schema agree.
    _SUPPORTED_UNITS = ("celsius", "fahrenheit")
    _SUPPORTED_LANGUAGES = ("zh", "en")
    _MAX_FORECAST_DAYS = 7

    def __init__(self, weather_service: "WeatherService"):
        self.weather_service = weather_service
        self.cache = WeatherCache(ttl=300)  # results are cached for 5 minutes
        self.logger = logging.getLogger(__name__)

    def get_name(self) -> str:
        return "weather_query"

    def get_description(self) -> str:
        return "查询指定地区的天气信息,支持当前天气和未来7天预报"

    def get_schema(self) -> dict:
        """Return the JSON Schema describing this tool's input parameters."""
        return {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "城市名称或坐标,支持中英文",
                    "examples": ["北京", "New York", "40.7128,-74.0060"]
                },
                "forecast_days": {
                    "type": "integer",
                    "description": "预报天数(0为当前天气,1-7为未来天数)",
                    "minimum": 0,
                    "maximum": 7,
                    "default": 0
                },
                "units": {
                    "type": "string",
                    "description": "温度单位",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius"
                },
                "language": {
                    "type": "string",
                    "description": "返回结果的语言",
                    "enum": ["zh", "en"],
                    "default": "zh"
                }
            },
            "required": ["location"]
        }

    async def execute(self, parameters: dict) -> dict:
        """Run a weather query.

        Steps: validate parameters, consult the cache, call the weather
        service, format the result, then store it in the cache.

        Args:
            parameters: raw tool arguments (see get_schema()).

        Returns:
            The formatted dict produced by _format_weather_data().

        Raises:
            ToolExecutionError: the arguments are invalid (the message is the
                validation message), the location is unknown, or the weather
                service failed.
        """
        try:
            # 1. Validate and normalise the arguments.
            validated_params = self._validate_parameters(parameters)

            # 2. Serve from cache when possible.
            cache_key = self._generate_cache_key(validated_params)
            cached_result = await self.cache.get(cache_key)
            if cached_result is not None:
                self.logger.info("缓存命中:%s", cache_key)
                return cached_result

            # 3. Ask the weather service.
            weather_data = await self.weather_service.get_weather(
                location=validated_params['location'],
                forecast_days=validated_params['forecast_days'],
                units=validated_params['units']
            )

            # 4. Shape the response for the requested language.
            formatted_result = self._format_weather_data(
                weather_data, validated_params['language']
            )

            # 5. Remember it for the cache TTL.
            await self.cache.set(cache_key, formatted_result)

            return formatted_result

        except ParameterError as e:
            # Bad input is the caller's problem, not a service outage: surface
            # the validation message instead of letting the catch-all below
            # turn it into "服务暂时不可用" and log it as an unexpected error.
            raise ToolExecutionError(str(e)) from e
        except LocationNotFoundError as e:
            raise ToolExecutionError(f"未找到位置:{e.location}") from e
        except WeatherServiceError as e:
            raise ToolExecutionError(f"天气服务错误:{str(e)}") from e
        except Exception as e:
            # logger.exception records the traceback, which error() did not.
            self.logger.exception("天气查询意外错误:%s", e)
            raise ToolExecutionError("天气查询服务暂时不可用") from e

    def _validate_parameters(self, params: dict) -> dict:
        """Validate raw arguments and fill in schema defaults.

        Returns:
            dict with keys ``location`` (whitespace-stripped), ``forecast_days``,
            ``units`` and ``language``.

        Raises:
            ParameterError: a value is missing or outside what get_schema() allows.
        """
        params = params or {}

        location = params.get('location')
        # None or a non-string used to crash on .strip() (AttributeError);
        # report it as a parameter problem instead.
        if not isinstance(location, str) or not location.strip():
            raise ParameterError("location 参数不能为空")
        location = location.strip()

        forecast_days = params.get('forecast_days', 0)
        # bool is a subclass of int, so True/False must be rejected explicitly.
        if (isinstance(forecast_days, bool)
                or not isinstance(forecast_days, int)
                or not 0 <= forecast_days <= self._MAX_FORECAST_DAYS):
            raise ParameterError("forecast_days 必须是0-7之间的整数")

        units = params.get('units', 'celsius')
        if units not in self._SUPPORTED_UNITS:
            raise ParameterError("units 必须是 celsius 或 fahrenheit")

        language = params.get('language', 'zh')
        if language not in self._SUPPORTED_LANGUAGES:
            raise ParameterError("language 必须是 zh 或 en")

        return {
            'location': location,
            'forecast_days': forecast_days,
            'units': units,
            'language': language
        }

    def _generate_cache_key(self, params: dict) -> str:
        """Build the cache key for a validated parameter dict.

        Every field that changes the returned payload is part of the key: the
        formatted result differs by units *and* language, so leaving either out
        would serve a cached answer in the wrong format. ``repr`` of a tuple
        keeps the fields unambiguous even if a location contains separators.
        """
        return "weather:" + repr((
            params['location'],
            params['forecast_days'],
            params['units'],
            params['language'],
        ))

    @staticmethod
    def _temperature_unit(units: str) -> str:
        """Return the degree suffix for *units* ("celsius" -> °C, anything else -> °F)."""
        return "°C" if units == 'celsius' else "°F"

    def _format_weather_data(self, data: "WeatherData", language: str) -> dict:
        """Render *data* as a dict with Chinese keys (``language == 'zh'``) or English keys.

        Temperatures carry the unit reported by ``data.units`` in both
        languages (the Chinese branch used to print °C unconditionally).
        ``预报``/``forecast`` is None when there are no forecast entries.
        """
        deg = self._temperature_unit(data.units)
        current = data.current
        if language == 'zh':
            return {
                "位置": data.location,
                "当前温度": f"{current.temperature}{deg}",
                "天气描述": current.description_zh,
                "湿度": f"{current.humidity}%",
                "风速": f"{current.wind_speed} km/h",
                "更新时间": data.last_updated.isoformat(),
                "预报": [
                    {
                        "日期": forecast.date.strftime("%Y-%m-%d"),
                        "最高温": f"{forecast.high}{deg}",
                        "最低温": f"{forecast.low}{deg}",
                        "天气": forecast.description_zh
                    }
                    for forecast in data.forecasts
                ] if data.forecasts else None
            }
        return {
            "location": data.location,
            "current_temperature": f"{current.temperature}{deg}",
            "description": current.description_en,
            "humidity": f"{current.humidity}%",
            "wind_speed": f"{current.wind_speed} km/h",
            "last_updated": data.last_updated.isoformat(),
            "forecast": [
                {
                    "date": forecast.date.strftime("%Y-%m-%d"),
                    "high": f"{forecast.high}{deg}",
                    "low": f"{forecast.low}{deg}",
                    "description": forecast.description_en
                }
                for forecast in data.forecasts
            ] if data.forecasts else None
        }

单一职责的好处

  1. 易于理解 📖——代码逻辑清晰，新人也能快速上手
  2. 易于测试 🧪——测试场景明确，更容易达到高覆盖率
  3. 易于维护 🔧——修改某个功能不会影响其他功能
  4. 易于复用 ♻️——其他地方可以直接使用
  5. 易于扩展 📈——需要新功能时可以独立开发

一致的错误处理 🚨

建立错误处理体系

好的错误处理就像系统的“免疫系统”，能够优雅地处理各种异常情况。

typescript
// 🌟 错误处理框架
/**
 * Error categories shared by all tools.
 *
 * Each member's value is the same string as its name, so the category stays
 * human-readable once an error is serialized to JSON.
 */
enum ErrorType {
    VALIDATION_ERROR = "VALIDATION_ERROR",                 // used by ValidationError (bad input parameters)
    AUTHENTICATION_ERROR = "AUTHENTICATION_ERROR",         // not raised by the code in this section
    AUTHORIZATION_ERROR = "AUTHORIZATION_ERROR",           // DB access denied, or a blocked dangerous query
    RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND",             // e.g. the requested table does not exist
    EXTERNAL_SERVICE_ERROR = "EXTERNAL_SERVICE_ERROR",     // used by ExternalServiceError (downstream failure)
    RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED",           // not raised by the code in this section
    INTERNAL_ERROR = "INTERNAL_ERROR"                      // fallback for errors that match no other category
}

/**
 * Base class for every error a tool raises.
 *
 * Carries a user-safe message (`userMessage`), log-oriented details
 * (`technicalDetails`), a category, an error code, a retryability flag,
 * free-form context and the creation timestamp.
 */
class ToolError extends Error {
    public readonly errorType: ErrorType;
    public readonly userMessage: string;
    public readonly technicalDetails: string;
    public readonly errorCode: string;
    public readonly retryable: boolean;
    public readonly context: Record<string, any>;
    /** ISO-8601 time the error was created (not the time it was serialized). */
    public readonly timestamp: string;

    /**
     * @param errorType        category of the failure
     * @param userMessage      message safe to show to end users (also Error.message)
     * @param technicalDetails diagnostic text intended for logs
     * @param options          optional code (auto-generated when omitted),
     *                         retryable flag (default false), context and cause
     */
    constructor(
        errorType: ErrorType,
        userMessage: string,
        technicalDetails: string,
        options: {
            errorCode?: string;
            retryable?: boolean;
            context?: Record<string, any>;
            cause?: Error;
        } = {}
    ) {
        super(userMessage);
        // When compiling to ES5, subclassing built-ins such as Error breaks the
        // prototype chain and `instanceof ToolError` becomes false. Callers
        // rely on that check to avoid re-wrapping errors, so restore the chain.
        Object.setPrototypeOf(this, new.target.prototype);
        this.name = 'ToolError';
        this.errorType = errorType;
        this.userMessage = userMessage;
        this.technicalDetails = technicalDetails;
        this.errorCode = options.errorCode || this.generateErrorCode();
        this.retryable = options.retryable || false;
        this.context = options.context || {};
        // Captured once here so repeated toJSON() calls report when the
        // failure happened, not when it was logged.
        this.timestamp = new Date().toISOString();

        if (options.cause) {
            this.cause = options.cause;
        }
    }

    /**
     * Builds `<errorType>_<time>_<random>`. The random suffix keeps codes
     * distinct for errors raised within the same millisecond.
     */
    private generateErrorCode(): string {
        const time = Date.now().toString(36).toUpperCase();
        const nonce = Math.random().toString(36).slice(2, 6).toUpperCase();
        return `${this.errorType}_${time}_${nonce}`;
    }

    toJSON() {
        return {
            errorType: this.errorType,
            userMessage: this.userMessage,
            technicalDetails: this.technicalDetails,
            errorCode: this.errorCode,
            retryable: this.retryable,
            context: this.context,
            timestamp: this.timestamp
        };
    }
}

// 具体的错误类型
/**
 * Raised when a tool parameter fails validation. Never retryable.
 *
 * @param message user-facing explanation of what is wrong
 * @param field   name of the offending parameter, when known
 * @param value   the rejected value, kept in `context` for diagnostics
 */
class ValidationError extends ToolError {
    constructor(message: string, field?: string, value?: any) {
        // Human-readable label for the technical details ("unknown field" when absent).
        const fieldLabel = field ? `字段 ${field}` : '未知字段';
        super(ErrorType.VALIDATION_ERROR, message, `参数验证失败:${fieldLabel}`, {
            retryable: false,
            context: { field, value },
        });
    }
}

/**
 * Raised when a downstream service call fails.
 *
 * @param serviceName   display name of the failing service, used in the user message
 * @param originalError the underlying error; kept as `cause`, and its message goes into `context`
 * @param retryable     whether retrying may help (defaults to true)
 */
class ExternalServiceError extends ToolError {
    constructor(serviceName: string, originalError: Error, retryable: boolean = true) {
        const reason = originalError.message;
        const userMessage = `${serviceName}服务暂时不可用,请稍后重试`;
        const details = `外部服务调用失败:${reason}`;
        super(ErrorType.EXTERNAL_SERVICE_ERROR, userMessage, details, {
            cause: originalError,
            retryable,
            context: { serviceName, originalError: reason },
        });
    }
}

// 实际工具中的错误处理
/**
 * Example SQL-query tool demonstrating consistent error handling: every
 * failure is classified into a ToolError, logged once with an execution id
 * and timing, then re-thrown.
 *
 * NOTE: `executeQuery` (the actual database call) is not defined in this
 * snippet; a real implementation must supply it.
 */
class DatabaseQueryTool {
    private logger = new Logger('DatabaseQueryTool');

    /** Parameter names (compared lower-cased) whose values are masked in logs. */
    private static readonly SENSITIVE_KEYS = new Set<string>(['password', 'apikey', 'token']);

    /**
     * Deny-list of destructive SQL. This is a last line of defence, not a
     * security boundary: regexes cannot fully parse SQL, so the database
     * account behind this tool should itself have least privilege.
     */
    private static readonly DANGEROUS_PATTERNS: RegExp[] = [
        /\bdrop\s+(table|database|schema)\b/i,
        /\btruncate\s+table\b/i,
        /\balter\s+table\b/i,
        // DELETE whose WHERE clause is always true ([\s\S] also spans newlines) ...
        /\bdelete\s+from\b[\s\S]*\bwhere\s+1\s*=\s*1\b/i,
        // ... or DELETE with no WHERE clause at all, which wipes the whole table.
        /\bdelete\s+from\s+[^\s;]+\s*(;|$)/i
    ];

    async execute(parameters: any): Promise<any> {
        const startTime = Date.now();
        const executionId = this.generateExecutionId();

        try {
            // Log the start (with secrets masked).
            this.logger.info('开始执行数据库查询', {
                executionId,
                parameters: this.sanitizeParametersForLogging(parameters)
            });

            const validatedParams = this.validateParameters(parameters);
            this.performSecurityCheck(validatedParams);
            const result = await this.executeQuery(validatedParams);

            const executionTime = Date.now() - startTime;
            this.logger.info('数据库查询执行成功', {
                executionId,
                executionTime,
                resultCount: Array.isArray(result) ? result.length : 1
            });

            return result;

        } catch (error) {
            const executionTime = Date.now() - startTime;

            // Classify whatever was thrown into a ToolError.
            const handledError = this.handleError(error, executionId, executionTime);

            this.logger.error('数据库查询执行失败', {
                executionId,
                executionTime,
                error: handledError.toJSON()
            });

            throw handledError;
        }
    }

    /** Short, practically unique id that correlates the log lines of one execution. */
    private generateExecutionId(): string {
        return `exec_${Date.now().toString(36)}_${Math.random().toString(36).slice(2, 8)}`;
    }

    /**
     * Maps any thrown value to a ToolError. ToolErrors pass through unchanged;
     * known driver codes get a specific category; everything else becomes
     * INTERNAL_ERROR with the execution id/time as context.
     */
    private handleError(error: any, executionId: string, executionTime: number): ToolError {
        if (error instanceof ToolError) {
            return error;
        }

        // JavaScript can throw anything (strings, null, plain objects). Read
        // fields defensively so the classifier itself can never throw and hide
        // the original failure.
        const message: string = typeof error?.message === 'string' ? error.message : String(error);
        const code = error?.code;
        const cause: Error = error instanceof Error ? error : new Error(message);

        if (error?.name === 'ValidationError' || message.includes('validation')) {
            return new ValidationError(
                message || '参数验证失败',
                error?.field,
                error?.value
            );
        }

        if (code === 'ECONNREFUSED' || code === 'ETIMEDOUT') {
            // Network-level failures are usually transient, so mark them retryable.
            return new ExternalServiceError('数据库', cause, true);
        }

        if (code === 'ER_ACCESS_DENIED_ERROR') {
            return new ToolError(
                ErrorType.AUTHORIZATION_ERROR,
                '数据库访问权限不足',
                `数据库认证失败:${message}`,
                { retryable: false }
            );
        }

        if (code === 'ER_NO_SUCH_TABLE') {
            return new ToolError(
                ErrorType.RESOURCE_NOT_FOUND,
                '请求的数据表不存在',
                `表不存在:${message}`,
                { retryable: false }
            );
        }

        return new ToolError(
            ErrorType.INTERNAL_ERROR,
            '系统内部错误,请联系管理员',
            `意外错误:${message}`,
            {
                context: { executionId, executionTime },
                retryable: false,
                cause
            }
        );
    }

    /** Requires a non-empty string `query` of at most 10000 characters. */
    private validateParameters(params: any): any {
        // Guard `params` itself: `params.query` on null/undefined would throw a
        // TypeError that ends up reported as an internal error.
        if (!params || typeof params.query !== 'string' || !params.query) {
            throw new ValidationError(
                'query参数是必需的,且必须是字符串类型',
                'query',
                params?.query
            );
        }

        if (params.query.length > 10000) {
            throw new ValidationError(
                'SQL查询长度不能超过10000个字符',
                'query',
                `长度:${params.query.length}`
            );
        }

        return params;
    }

    /** Rejects queries that match DANGEROUS_PATTERNS. */
    private performSecurityCheck(params: any): void {
        const query: string = params.query;
        // Also check a copy with SQL comments blanked out, so "DROP/**/TABLE" or
        // "DROP -- x\nTABLE" are caught. The raw text is checked as well,
        // because blanking a "--" inside a string literal could hide a
        // statement that follows it on the same line.
        const withoutComments = query
            .replace(/\/\*[\s\S]*?\*\//g, ' ')
            .replace(/--[^\n]*/g, ' ')
            .replace(/#[^\n]*/g, ' ');

        for (const pattern of DatabaseQueryTool.DANGEROUS_PATTERNS) {
            if (pattern.test(query) || pattern.test(withoutComments)) {
                throw new ToolError(
                    ErrorType.AUTHORIZATION_ERROR,
                    '查询包含危险操作,已被阻止',
                    `检测到危险SQL模式:${pattern}`,
                    { retryable: false }
                );
            }
        }
    }

    /** Returns a shallow copy of `params` with truthy sensitive values replaced by '***'. */
    private sanitizeParametersForLogging(params: any): any {
        if (params === null || typeof params !== 'object') {
            return params;
        }
        const sanitized: Record<string, any> = { ...params };
        for (const key of Object.keys(sanitized)) {
            // Case-insensitive so "apiKey", "APIKEY" and "Password" are all masked.
            if (DatabaseQueryTool.SENSITIVE_KEYS.has(key.toLowerCase()) && sanitized[key]) {
                sanitized[key] = '***';
            }
        }
        return sanitized;
    }
}

参数验证最佳实践 ✅

全面的参数验证体系

参数验证就像系统的“质检员”，确保进入系统的每个参数都符合要求。

java
// 🌟 Java参数验证框架
/**
 * Declarative validator for tool parameters.
 *
 * <p>Rules are declared per field with {@link ValidatorBuilder}
 * ({@code field("x").required().string()...}). {@link #validate(Map)} applies them and
 * returns every error message together with the values that passed.
 */
public class ParameterValidator {
    private static final Logger logger = LoggerFactory.getLogger(ParameterValidator.class);

    /**
     * Roots that file paths must stay inside. An entry only takes effect on a platform
     * where it parses as an absolute path (POSIX entries on Unix, drive entries on Windows).
     */
    private static final List<String> ALLOWED_DIRECTORIES = Arrays.asList(
        "/tmp/mcp/",
        "/var/mcp/data/",
        "C:\\temp\\mcp\\",
        "C:\\data\\mcp\\"
    );

    /** Field name -> rules for that field, in declaration order. */
    private final Map<String, List<ValidationRule<?>>> rules;

    private ParameterValidator(Map<String, List<ValidationRule<?>>> rules) {
        // Deep-enough copy so later builder calls cannot alter a built validator.
        Map<String, List<ValidationRule<?>>> copy = new java.util.LinkedHashMap<>();
        rules.forEach((field, fieldRules) -> copy.put(field, new java.util.ArrayList<>(fieldRules)));
        this.rules = copy;
    }

    /**
     * Validates {@code params} against the registered rules.
     *
     * <p>An absent (null) field only triggers its {@code required()} rule. For a present
     * field, rules run in declaration order and the first failure is reported; only fields
     * that pass every rule are copied into the sanitized values. Undeclared keys are dropped.
     */
    public ValidationResult validate(Map<String, ?> params) {
        List<String> errors = new java.util.ArrayList<>();
        Map<String, Object> sanitized = new java.util.LinkedHashMap<>();

        for (Map.Entry<String, List<ValidationRule<?>>> entry : rules.entrySet()) {
            String field = entry.getKey();
            Object value = params == null ? null : params.get(field);

            if (value == null) {
                for (ValidationRule<?> rule : entry.getValue()) {
                    if (rule.isRequired()) {
                        errors.add(rule.getErrorMessage());
                    }
                }
                continue;
            }

            String firstError = null;
            for (ValidationRule<?> rule : entry.getValue()) {
                if (!rule.test(value)) {
                    // Stop at the first failure: later rules usually assume earlier ones
                    // held (e.g. maxLength assumes string()).
                    firstError = rule.getErrorMessage();
                    break;
                }
            }
            if (firstError == null) {
                sanitized.put(field, value);
            } else {
                errors.add(firstError);
            }
        }
        return new ValidationResult(errors.isEmpty(), errors, sanitized);
    }

    /** One check applied to one field. */
    public static class ValidationRule<T> {
        private final String fieldName;
        private final Predicate<T> validator;
        private final String errorMessage;
        private final boolean required;

        public ValidationRule(String fieldName, Predicate<T> validator,
                            String errorMessage, boolean required) {
            this.fieldName = fieldName;
            this.validator = validator;
            this.errorMessage = errorMessage;
            this.required = required;
        }

        public String getFieldName() { return fieldName; }
        public Predicate<T> getValidator() { return validator; }
        public String getErrorMessage() { return errorMessage; }
        public boolean isRequired() { return required; }

        /**
         * Applies this rule to an arbitrary value. A value of the wrong runtime type
         * (e.g. an Integer given to a String-length rule) fails the rule instead of
         * throwing ClassCastException.
         */
        @SuppressWarnings("unchecked")
        boolean test(Object value) {
            try {
                return ((Predicate<Object>) (Predicate<?>) validator).test(value);
            } catch (ClassCastException e) {
                return false;
            }
        }
    }

    /** Outcome of {@link #validate(Map)}; collections are unmodifiable. */
    public static class ValidationResult {
        private final boolean isValid;
        private final List<String> errors;
        private final Map<String, Object> sanitizedValues;

        public ValidationResult(boolean isValid, List<String> errors,
                                Map<String, Object> sanitizedValues) {
            this.isValid = isValid;
            this.errors = java.util.Collections.unmodifiableList(new java.util.ArrayList<>(errors));
            this.sanitizedValues = java.util.Collections.unmodifiableMap(
                new java.util.LinkedHashMap<>(sanitizedValues));
        }

        public boolean isValid() { return isValid; }
        public List<String> getErrors() { return errors; }
        public Map<String, Object> getSanitizedValues() { return sanitizedValues; }
    }

    /** Fluent builder: call {@link #field(String)}, then add rules for that field. */
    public static class ValidatorBuilder {
        // LinkedHashMap keeps error messages in declaration order.
        private final Map<String, List<ValidationRule<?>>> rules = new java.util.LinkedHashMap<>();
        /** Field that subsequent rule methods attach to; set by field(). */
        private String currentField;

        public ValidatorBuilder field(String fieldName) {
            this.currentField = Objects.requireNonNull(fieldName, "fieldName");
            // Register the field even if it gets no rules, so validate() still copies it.
            rules.computeIfAbsent(fieldName, k -> new java.util.ArrayList<>());
            return this;
        }

        private List<ValidationRule<?>> getCurrentRules() {
            if (currentField == null) {
                throw new IllegalStateException("field(name) must be called before adding rules");
            }
            return rules.computeIfAbsent(currentField, k -> new java.util.ArrayList<>());
        }

        public ValidatorBuilder required() {
            getCurrentRules().add(new ValidationRule<>(
                currentField,
                Objects::nonNull,
                "字段 " + currentField + " 是必需的",
                true
            ));
            return this;
        }

        public ValidatorBuilder string() {
            getCurrentRules().add(new ValidationRule<>(
                currentField,
                value -> value instanceof String,
                "字段 " + currentField + " 必须是字符串类型",
                false
            ));
            return this;
        }

        public ValidatorBuilder minLength(int minLength) {
            getCurrentRules().add(new ValidationRule<String>(
                currentField,
                value -> value.length() >= minLength,
                String.format("字段 %s 长度不能少于 %d 个字符", currentField, minLength),
                false
            ));
            return this;
        }

        public ValidatorBuilder maxLength(int maxLength) {
            getCurrentRules().add(new ValidationRule<String>(
                currentField,
                value -> value.length() <= maxLength,
                String.format("字段 %s 长度不能超过 %d 个字符", currentField, maxLength),
                false
            ));
            return this;
        }

        public ValidatorBuilder pattern(String regex, String message) {
            Pattern pattern = Pattern.compile(regex);
            getCurrentRules().add(new ValidationRule<String>(
                currentField,
                value -> pattern.matcher(value).matches(),
                message != null ? message :
                    String.format("字段 %s 格式不正确", currentField),
                false
            ));
            return this;
        }

        public ValidatorBuilder email() {
            return pattern(
                "^[A-Za-z0-9+_.-]+@([A-Za-z0-9.-]+\\.[A-Za-z]{2,})$",
                "请输入有效的邮箱地址"
            );
        }

        public ValidatorBuilder integer() {
            getCurrentRules().add(new ValidationRule<>(
                currentField,
                value -> value instanceof Integer ||
                        (value instanceof String && value.toString().matches("-?\\d+")),
                "字段 " + currentField + " 必须是整数",
                false
            ));
            return this;
        }

        public ValidatorBuilder range(int min, int max) {
            getCurrentRules().add(new ValidationRule<Integer>(
                currentField,
                value -> value >= min && value <= max,
                String.format("字段 %s 必须在 %d 到 %d 之间", currentField, min, max),
                false
            ));
            return this;
        }

        public ValidatorBuilder custom(Predicate<Object> validator, String errorMessage) {
            getCurrentRules().add(new ValidationRule<>(
                currentField,
                validator,
                errorMessage,
                false
            ));
            return this;
        }

        public ParameterValidator build() {
            return new ParameterValidator(this.rules);
        }
    }

    /** Validator for the file-operation tool's parameters. */
    public static ParameterValidator createFileOperationValidator() {
        return new ValidatorBuilder()
            .field("operation")
                .required()
                .string()
                .custom(value -> Arrays.asList("read", "write", "delete", "list").contains(value),
                       "operation 必须是 read, write, delete, list 中的一个")
            .field("path")
                .required()
                .string()
                .minLength(1)
                .maxLength(500)
                .custom(ParameterValidator::isValidFilePath, "文件路径包含非法字符")
                .custom(ParameterValidator::isPathWithinAllowedDirectories, "文件路径超出允许范围")
            .field("content")
                .string()
                .maxLength(1024 * 1024) // 1 MB limit
                .custom(value -> {
                    // Placeholder: "content is required for write" is a cross-field rule
                    // that per-field validation cannot express; the caller, which knows
                    // the operation, must enforce it.
                    return true;
                }, "写操作需要提供content参数")
            .field("encoding")
                .string()
                // Predicate<Object>: check the type before calling String methods
                // (value.toLowerCase() on Object does not compile). Locale.ROOT avoids
                // locale-specific case mapping such as Turkish dotless i.
                .custom(value -> value instanceof String
                        && Arrays.asList("utf-8", "gbk", "ascii")
                            .contains(((String) value).toLowerCase(java.util.Locale.ROOT)),
                       "encoding 必须是 utf-8, gbk, ascii 中的一个")
            .build();
    }

    /** Rejects paths containing forbidden characters or any ".." segment. */
    private static boolean isValidFilePath(Object path) {
        if (!(path instanceof String)) return false;
        String pathStr = (String) path;

        // Characters illegal in Windows paths or used for glob/shell tricks, plus NUL,
        // which truncates paths in native APIs.
        for (String forbidden : new String[] {"<", ">", "|", "?", "*", "\0"}) {
            if (pathStr.contains(forbidden)) return false;
        }
        // Reject every ".." segment, including a trailing one ("a/..") that a plain
        // "../" substring check misses. Both separators are handled.
        for (String segment : pathStr.split("[/\\\\]")) {
            if (segment.equals("..")) return false;
        }
        return true;
    }

    /**
     * True when the normalized absolute path lies inside one of ALLOWED_DIRECTORIES.
     * Symbolic links are not resolved here (toRealPath would require the file to exist).
     */
    private static boolean isPathWithinAllowedDirectories(Object path) {
        if (!(path instanceof String)) return false;

        try {
            Path candidate = Paths.get((String) path).normalize();
            if (!candidate.isAbsolute()) return false;

            for (String allowedDir : ALLOWED_DIRECTORIES) {
                Path root = Paths.get(allowedDir).normalize();
                // Path.startsWith compares whole name elements, so "/tmp/mcp" does not
                // admit "/tmp/mcpevil/x" the way a string prefix test could.
                if (root.isAbsolute() && candidate.startsWith(root)) {
                    return true;
                }
            }
            return false;
        } catch (Exception e) {
            // Unparseable path (InvalidPathException etc.): treat as not allowed.
            return false;
        }
    }
}

// 使用示例
public class FileOperationTool implements Tool {
    private final ParameterValidator validator;
    
    public FileOperationTool() {
        this.validator = ParameterValidator.createFileOperationValidator();
    }
    
    @Override
    public ToolResponse execute(ToolRequest request) {
        // 验证参数
        ValidationResult validation = validator.validate(request.getParameters());
        
        if (!validation.isValid()) {
            throw new ToolExecutionException(
                "参数验证失败:" + String.join(", ", validation.getErrors())
            );
        }
        
        // 使用验证和清理后的参数
        Map<String, Object> safeParams = validation.getSanitizedValues();
        
        String operation = (String) safeParams.get("operation");
        String path = (String) safeParams.get("path");
        
        // 继续执行工具逻辑...
        switch (operation) {
            case "read":
                return executeRead(path);
            case "write":
                String content = (String) safeParams.get("content");
                return executeWrite(path, content);
            // ... 其他操作
        }
    }
}

依赖注入与可测试性 🧪

让工具变得可测试

可测试的代码就像透明的玻璃房子，里面的每个部分都能被清楚地看到和验证。

csharp
// 🌟 可测试的工具设计
/// <summary>
/// Abstraction over the weather data source. Tools receive it by injection,
/// so tests can substitute a mock.
/// </summary>
public interface IWeatherService
{
    /// <summary>Fetches weather data for <paramref name="location"/>.</summary>
    /// <param name="location">Location string as passed in by the tool.</param>
    /// <param name="days">Number of forecast days; defaults to 1.</param>
    Task<WeatherData> GetWeatherAsync(string location, int days = 1);
}

/// <summary>
/// Minimal asynchronous key/value cache abstraction, injected so tests can mock it.
/// </summary>
public interface ICacheService  
{
    /// <summary>
    /// Looks up <paramref name="key"/>. WeatherForecastTool treats a null result
    /// as a cache miss.
    /// </summary>
    Task<T> GetAsync<T>(string key);

    /// <summary>Stores <paramref name="value"/> under <paramref name="key"/> for <paramref name="expiry"/>.</summary>
    Task SetAsync<T>(string key, T value, TimeSpan expiry);
}

/// <summary>
/// Logging abstraction declared locally for this example; messages use named
/// placeholders (e.g. "{Location}") filled from <c>args</c>.
/// NOTE(review): this shares its name with Microsoft.Extensions.Logging.ILogger&lt;T&gt;;
/// confirm which one the project resolves, since their method sets differ.
/// </summary>
public interface ILogger<T>
{
    void LogInformation(string message, params object[] args);
    void LogError(Exception exception, string message, params object[] args);
}

// 主要工具类 - 注意所有依赖都是通过接口注入的
/// <summary>
/// Weather-forecast tool. Every collaborator is injected through the
/// constructor as an interface, which is what makes the class unit-testable
/// with mocks.
/// </summary>
public class WeatherForecastTool : ITool
{
    private readonly IWeatherService _weatherService;
    private readonly ICacheService _cacheService;
    private readonly ILogger<WeatherForecastTool> _logger;

    /// <summary>Creates the tool; all dependencies are required.</summary>
    /// <exception cref="ArgumentNullException">Any dependency is null.</exception>
    public WeatherForecastTool(
        IWeatherService weatherService,
        ICacheService cacheService,
        ILogger<WeatherForecastTool> logger)
    {
        _weatherService = weatherService ?? throw new ArgumentNullException(nameof(weatherService));
        _cacheService = cacheService ?? throw new ArgumentNullException(nameof(cacheService));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    public string Name => "weather_forecast";
    public string Description => "获取天气预报信息";

    /// <summary>
    /// Returns the forecast for the "location" parameter over "days" days
    /// (default 1). Results are cached for 30 minutes.
    /// </summary>
    /// <exception cref="ToolExecutionException">
    /// Parameters are missing or invalid, the location is unknown, or the service failed.
    /// </exception>
    public async Task<ToolResponse> ExecuteAsync(ToolRequest request)
    {
        // Parse outside the try block so a bad argument is reported as such (not as
        // "service unavailable"), and as ToolExecutionException rather than a raw
        // KeyNotFoundException / FormatException / NullReferenceException.
        var location = ReadLocation(request);
        var days = ReadDays(request);

        _logger.LogInformation("开始获取天气预报:{Location}, {Days}天", location, days);

        try
        {
            var cacheKey = $"weather:{location}:{days}";
            var cachedResult = await _cacheService.GetAsync<WeatherData>(cacheKey);

            if (cachedResult != null)
            {
                _logger.LogInformation("使用缓存的天气数据:{Location}", location);
                return CreateResponse(cachedResult);
            }

            var weatherData = await _weatherService.GetWeatherAsync(location, days);

            await _cacheService.SetAsync(cacheKey, weatherData, TimeSpan.FromMinutes(30));

            _logger.LogInformation("成功获取天气预报:{Location}", location);
            return CreateResponse(weatherData);
        }
        catch (LocationNotFoundException ex)
        {
            _logger.LogError(ex, "位置未找到:{Location}", location);
            throw new ToolExecutionException($"未找到位置:{location}");
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "获取天气预报时发生错误:{Location}", location);
            throw new ToolExecutionException("天气服务暂时不可用");
        }
    }

    /// <summary>Reads the required, non-blank "location" parameter (trimmed).</summary>
    private static string ReadLocation(ToolRequest request)
    {
        var parameters = request?.Parameters;
        var raw = parameters != null && parameters.ContainsKey("location")
            ? parameters["location"]?.ToString()
            : null;

        if (string.IsNullOrWhiteSpace(raw))
        {
            throw new ToolExecutionException("缺少必需参数:location");
        }
        return raw.Trim();
    }

    /// <summary>
    /// Reads the optional "days" parameter. It defaults to 1 when absent;
    /// a null value converts to 0, as Convert.ToInt32 does.
    /// </summary>
    private static int ReadDays(ToolRequest request)
    {
        if (!request.Parameters.ContainsKey("days"))
        {
            return 1;
        }
        try
        {
            return Convert.ToInt32(request.Parameters["days"]);
        }
        catch (Exception ex) when (ex is FormatException || ex is InvalidCastException || ex is OverflowException)
        {
            throw new ToolExecutionException("参数 days 必须是整数");
        }
    }

    /// <summary>Wraps the data as a single JSON text content item.</summary>
    private ToolResponse CreateResponse(WeatherData data)
    {
        return new ToolResponse
        {
            Content = new List<ContentItem>
            {
                new TextContent(JsonSerializer.Serialize(data))
            }
        };
    }
}

// 🧪 单元测试 - 现在变得非常简单
/// <summary>
/// Unit tests for WeatherForecastTool. Every dependency is a Moq mock, so each
/// test controls cache and service behavior directly and can verify the calls made.
/// </summary>
[TestClass]
public class WeatherForecastToolTests
{
    private Mock<IWeatherService> _mockWeatherService;
    private Mock<ICacheService> _mockCacheService;
    private Mock<ILogger<WeatherForecastTool>> _mockLogger;
    private WeatherForecastTool _tool;
    
    // Runs before each test: fresh mocks and a fresh tool instance.
    [TestInitialize]
    public void Setup()
    {
        _mockWeatherService = new Mock<IWeatherService>();
        _mockCacheService = new Mock<ICacheService>(); 
        _mockLogger = new Mock<ILogger<WeatherForecastTool>>();
        
        _tool = new WeatherForecastTool(
            _mockWeatherService.Object,
            _mockCacheService.Object,
            _mockLogger.Object
        );
    }
    
    // Cache miss: the tool must call the service once and store its result.
    [TestMethod]
    public async Task ExecuteAsync_ValidLocation_ReturnsWeatherData()
    {
        // Arrange
        var location = "北京";
        var expectedWeatherData = new WeatherData
        {
            Location = location,
            Temperature = 25,
            Description = "晴朗"
        };
        
        _mockCacheService
            .Setup(c => c.GetAsync<WeatherData>(It.IsAny<string>()))
            .ReturnsAsync((WeatherData)null); // simulate a cache miss
            
        _mockWeatherService
            .Setup(w => w.GetWeatherAsync(location, 1))
            .ReturnsAsync(expectedWeatherData);
        
        var request = new ToolRequest
        {
            Parameters = new Dictionary<string, object>
            {
                ["location"] = location
            }
        };
        
        // Act
        var response = await _tool.ExecuteAsync(request);
        
        // Assert
        Assert.IsNotNull(response);
        Assert.AreEqual(1, response.Content.Count);
        
        // NOTE(review): this deserializes ContentItem.ToString(). Unless TextContent
        // overrides ToString() to return its text, this is not the serialized JSON;
        // confirm, or read the text property of TextContent directly.
        var responseData = JsonSerializer.Deserialize<WeatherData>(
            response.Content[0].ToString()
        );
        Assert.AreEqual(location, responseData.Location);
        Assert.AreEqual(25, responseData.Temperature);
        
        // Verify the interactions with the dependencies.
        _mockWeatherService.Verify(
            w => w.GetWeatherAsync(location, 1), 
            Times.Once
        );
        
        _mockCacheService.Verify(
            c => c.SetAsync(It.IsAny<string>(), expectedWeatherData, It.IsAny<TimeSpan>()),
            Times.Once
        );
    }
    
    // Cache hit: the cached data is returned and the service is never called.
    [TestMethod]
    public async Task ExecuteAsync_CacheHit_ReturnsCachedData()
    {
        // Arrange
        var location = "上海";
        var cachedWeatherData = new WeatherData
        {
            Location = location,
            Temperature = 20,
            Description = "多云"
        };
        
        _mockCacheService
            .Setup(c => c.GetAsync<WeatherData>(It.IsAny<string>()))
            .ReturnsAsync(cachedWeatherData); // simulate a cache hit
        
        var request = new ToolRequest
        {
            Parameters = new Dictionary<string, object>
            {
                ["location"] = location
            }
        };
        
        // Act
        var response = await _tool.ExecuteAsync(request);
        
        // Assert
        Assert.IsNotNull(response);
        
        // NOTE(review): same ToString() caveat as in the previous test.
        var responseData = JsonSerializer.Deserialize<WeatherData>(
            response.Content[0].ToString()
        );
        Assert.AreEqual(location, responseData.Location);
        Assert.AreEqual(20, responseData.Temperature);
        
        // The weather service must not be called when the cache hits.
        _mockWeatherService.Verify(
            w => w.GetWeatherAsync(It.IsAny<string>(), It.IsAny<int>()), 
            Times.Never
        );
    }
    
    // The service throws LocationNotFoundException: the tool must rethrow it as
    // ToolExecutionException mentioning the location, and log it exactly once.
    // NOTE(review): despite the generic test name, this only covers the
    // LocationNotFoundException branch, not the catch-all branch.
    [TestMethod]
    public async Task ExecuteAsync_ServiceThrowsException_ThrowsToolExecutionException()
    {
        // Arrange
        var location = "不存在的城市";
        
        _mockCacheService
            .Setup(c => c.GetAsync<WeatherData>(It.IsAny<string>()))
            .ReturnsAsync((WeatherData)null);
            
        _mockWeatherService
            .Setup(w => w.GetWeatherAsync(location, It.IsAny<int>()))
            .ThrowsAsync(new LocationNotFoundException(location));
        
        var request = new ToolRequest
        {
            Parameters = new Dictionary<string, object>
            {
                ["location"] = location
            }
        };
        
        // Act & Assert
        var exception = await Assert.ThrowsExceptionAsync<ToolExecutionException>(
            () => _tool.ExecuteAsync(request)
        );
        
        Assert.IsTrue(exception.Message.Contains(location));
        
        // Verify the error was logged once with the location as the placeholder argument.
        _mockLogger.Verify(
            l => l.LogError(
                It.IsAny<Exception>(), 
                It.Is<string>(s => s.Contains("位置未找到")), 
                location
            ),
            Times.Once
        );
    }
}

可组合工具设计 🧩

设计可以组合的工具

可组合的工具就像乐高积木，可以灵活地组合成更复杂的功能。

python
# 🌟 可组合工具设计示例
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import asyncio

class ComposableTool(ABC):
    """Abstract base for tools that can be chained into pipelines.

    Subclasses must provide a name, a description and an async ``execute``.
    They may override :meth:`get_output_schema` so that downstream tools know
    the shape of the dict that ``execute`` returns.
    """

    @abstractmethod
    def get_name(self) -> str:
        """Return the tool's unique identifier."""

    @abstractmethod
    def get_description(self) -> str:
        """Return a human-readable summary of what the tool does."""

    @abstractmethod
    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Run the tool on *parameters* and return its result dict."""

    def get_output_schema(self) -> Dict[str, Any]:
        """Describe the result structure; the default is a generic JSON object."""
        return dict(type="object")

# 基础工具:数据获取
class DataFetchTool(ComposableTool):
    """Mock data source that always returns the same three-product catalogue.

    ``source`` and ``query`` are only echoed back in ``metadata``; they do not
    affect which rows are returned. ``fetch_time`` is a fixed sample timestamp.
    """

    def get_name(self) -> str:
        return "data_fetch"

    def get_description(self) -> str:
        return "从指定数据源获取数据"

    def get_output_schema(self) -> Dict[str, Any]:
        props = {
            "data": {"type": "array"},
            "metadata": {"type": "object"},
            "fetch_time": {"type": "string"},
        }
        return {"type": "object", "properties": props}

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Simulate a fetch (about 0.1 s of latency) and return rows plus metadata."""
        origin = parameters.get('source')
        filters = parameters.get('query', {})

        # Simulated I/O latency.
        await asyncio.sleep(0.1)

        # Built fresh on every call so callers may mutate the result safely.
        rows = [
            {"id": 1, "name": "产品A", "price": 100, "category": "电子产品"},
            {"id": 2, "name": "产品B", "price": 200, "category": "服装"},
            {"id": 3, "name": "产品C", "price": 150, "category": "电子产品"},
        ]
        metadata = {
            "source": origin,
            "total_count": len(rows),
            "query": filters,
        }
        return {
            "data": rows,
            "metadata": metadata,
            "fetch_time": "2024-03-15T10:30:00Z",
        }

# Basic tool: data filtering
class DataFilterTool(ComposableTool):
    """Keep only the records that satisfy every given condition."""

    def get_name(self) -> str:
        return "data_filter"

    def get_description(self) -> str:
        return "根据条件过滤数据"

    def get_output_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "filtered_data": {"type": "array"},
                "filter_conditions": {"type": "object"},
                "original_count": {"type": "integer"},
                "filtered_count": {"type": "integer"}
            }
        }

    @staticmethod
    def _matches(item: Dict[str, Any], conditions: Dict[str, Any]) -> bool:
        """Return True when ``item`` satisfies every entry in ``conditions``.

        A dict-valued condition is a range check with optional inclusive
        ``min`` / ``max`` bounds; any other value must equal the item's field.
        """
        for key, expected in conditions.items():
            if key not in item:
                # A record without the field cannot satisfy the condition;
                # treating it as a match would let e.g. uncategorized items
                # through a category filter.
                return False
            actual = item[key]
            if isinstance(expected, dict):
                if 'min' in expected and actual < expected['min']:
                    return False
                if 'max' in expected and actual > expected['max']:
                    return False
            elif actual != expected:
                return False
        return True

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Filter ``parameters['data']`` by ``parameters['conditions']``.

        Args:
            parameters: ``data`` is a list of dict records and ``conditions``
                maps field names to an exact value or a ``{'min', 'max'}``
                range. Missing or ``None`` values are treated as empty.

        Returns:
            The kept records, the conditions used, and before/after counts.
        """
        # ``or`` also covers an explicit None (e.g. from a resolved reference).
        data = parameters.get('data') or []
        conditions = parameters.get('conditions') or {}

        filtered_data = [item for item in data if self._matches(item, conditions)]

        return {
            "filtered_data": filtered_data,
            "filter_conditions": conditions,
            "original_count": len(data),
            "filtered_count": len(filtered_data)
        }

# Basic tool: data analysis
class DataAnalysisTool(ComposableTool):
    """Produce summary statistics and short textual insights for records."""

    def get_name(self) -> str:
        return "data_analysis"

    def get_description(self) -> str:
        return "分析数据并生成统计信息"

    def get_output_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "statistics": {"type": "object"},
                "insights": {"type": "array"},
                "analysis_type": {"type": "string"}
            }
        }

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize ``parameters['data']`` (a list of dict records).

        ``analysis_type`` (default ``'basic'``) is only echoed back; the
        computation is the same for every value.

        Returns:
            ``statistics`` with the record count, price statistics over
            records whose ``price`` is not None, and per-category counts
            (records without ``category`` count as ``'Unknown'``);
            ``insights`` as human-readable strings; and ``analysis_type``.
        """
        data = parameters.get('data', [])
        analysis_type = parameters.get('analysis_type', 'basic')

        if not data:
            return {
                "statistics": {},
                "insights": ["数据为空,无法进行分析"],
                "analysis_type": analysis_type
            }

        total_count = len(data)

        # Price statistics. Records whose price is missing or None are
        # skipped; a None would otherwise make sum()/min()/max() raise.
        prices = [item['price'] for item in data if item.get('price') is not None]
        price_stats = {}

        if prices:
            total_value = sum(prices)
            price_stats = {
                "avg_price": total_value / len(prices),
                "min_price": min(prices),
                "max_price": max(prices),
                "total_value": total_value
            }

        # Category distribution.
        categories = {}
        for item in data:
            category = item.get('category', 'Unknown')
            categories[category] = categories.get(category, 0) + 1

        insights = []
        if price_stats:
            insights.append(f"平均价格为 {price_stats['avg_price']:.2f}")
            insights.append(f"价格范围:{price_stats['min_price']} - {price_stats['max_price']}")

        if categories:
            # On ties, max() keeps the category encountered first.
            most_common_category = max(categories, key=categories.get)
            # The opening parenthesis was missing, producing "类别2个)".
            insights.append(
                f"最常见的类别是:{most_common_category}({categories[most_common_category]}个)"
            )

        return {
            "statistics": {
                "total_count": total_count,
                "price_stats": price_stats,
                "category_distribution": categories
            },
            "insights": insights,
            "analysis_type": analysis_type
        }

# Tool composer
class ToolComposer:
    """Chain registered tools into named, sequential workflows.

    Each workflow step names a tool, its parameters, and optionally the key
    under which its result is stored. Top-level string parameters of the form
    ``$name`` or ``$name.field.path`` are replaced with values taken from an
    earlier step's result, or failing that from the initial parameters.
    """

    def __init__(self):
        self.tools = {}      # tool name -> tool instance
        self.workflows = {}  # workflow name -> list of step dicts

    def register_tool(self, tool: "ComposableTool"):
        """Register ``tool`` under ``tool.get_name()``, replacing any previous one."""
        self.tools[tool.get_name()] = tool

    def register_workflow(self, name: str, workflow: List[Dict[str, Any]]):
        """Register ``workflow`` under ``name``, replacing any previous one.

        Workflow format::

            [
                {
                    "tool": "data_fetch",
                    "parameters": {"source": "products"},
                    "output_name": "raw_data"
                },
                {
                    "tool": "data_filter",
                    "parameters": {
                        "data": "$raw_data.data",  # output of an earlier step
                        "conditions": {"category": "电子产品"}
                    },
                    "output_name": "filtered_data"
                },
                {
                    "tool": "data_analysis",
                    "parameters": {
                        "data": "$filtered_data.filtered_data",
                        "analysis_type": "detailed"
                    },
                    "output_name": "analysis_result"
                }
            ]

        ``parameters`` may be omitted (treated as ``{}``); ``output_name``
        defaults to ``step_<index>``.
        """
        self.workflows[name] = workflow

    async def execute_workflow(self, workflow_name: str,
                               initial_parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Run the named workflow step by step.

        Args:
            workflow_name: Name given to ``register_workflow``.
            initial_parameters: Values that ``$`` references may use when no
                earlier step produced an output with that name.

        Returns:
            Mapping of each step's output name to that step's result.

        Raises:
            ValueError: unknown workflow, unknown tool, or an unresolvable
                reference. Exceptions raised by a tool are re-raised unchanged.
        """
        if workflow_name not in self.workflows:
            raise ValueError(f"工作流 '{workflow_name}' 未找到")

        workflow = self.workflows[workflow_name]

        # Check every step's tool before running anything, so a bad late
        # step is reported before earlier steps have had side effects.
        for step in workflow:
            tool_name = step['tool']
            if tool_name not in self.tools:
                raise ValueError(f"工具 '{tool_name}' 未找到")

        execution_context = initial_parameters or {}
        step_results = {}

        for step_index, step in enumerate(workflow):
            tool_name = step['tool']
            step_parameters = self._resolve_parameter_references(
                step.get('parameters', {}), step_results, execution_context
            )

            tool = self.tools[tool_name]
            try:
                result = await tool.execute(step_parameters)
            except Exception as e:
                print(f"❌ 步骤 {step_index + 1}: {tool_name} 执行失败: {str(e)}")
                raise

            output_name = step.get('output_name', f'step_{step_index}')
            step_results[output_name] = result
            print(f"✅ 步骤 {step_index + 1}: {tool_name} 执行成功")

        return step_results

    def _resolve_parameter_references(self, parameters: Dict[str, Any],
                                      step_results: Dict[str, Any],
                                      context: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``parameters`` with top-level ``$`` references resolved.

        Only top-level string values starting with ``$`` are references;
        nested containers are passed through unchanged. The name before the
        first ``.`` is looked up in ``step_results`` first, then ``context``;
        the remainder (if any) is a field path into that value.

        Raises:
            ValueError: if the name or the field path cannot be resolved.
        """
        resolved = {}

        for key, value in parameters.items():
            if not (isinstance(value, str) and value.startswith('$')):
                resolved[key] = value
                continue

            root, has_path, field_path = value[1:].partition('.')
            if root in step_results:
                source = step_results[root]
            elif root in context:
                source = context[root]
            else:
                raise ValueError(f"引用 '{root}' 未找到")

            resolved[key] = self._get_nested_value(source, field_path) if has_path else source

        return resolved

    def _get_nested_value(self, data: Any, field_path: str) -> Any:
        """Follow the dot-separated ``field_path`` through nested containers.

        Dict keys are matched as strings. When the current value is a list
        or tuple, a non-negative decimal segment selects that index, e.g.
        ``items.0.name``.

        Raises:
            ValueError: if any segment cannot be followed.
        """
        current = data

        for field in field_path.split('.'):
            if isinstance(current, dict) and field in current:
                current = current[field]
            elif (isinstance(current, (list, tuple)) and field.isdecimal()
                  and int(field) < len(current)):
                current = current[int(field)]
            else:
                raise ValueError(f"字段路径 '{field_path}' 无效")

        return current

# Usage example
async def main():
    """Demo: fetch products, keep the electronics, then analyze them."""
    composer = ToolComposer()

    # Register the basic tools.
    for tool in (DataFetchTool(), DataFilterTool(), DataAnalysisTool()):
        composer.register_tool(tool)

    # Product analysis workflow: fetch -> filter -> analyze.
    fetch_step = {
        "tool": "data_fetch",
        "parameters": {"source": "products_db"},
        "output_name": "raw_data",
    }
    filter_step = {
        "tool": "data_filter",
        "parameters": {
            "data": "$raw_data.data",
            "conditions": {"category": "电子产品"},
        },
        "output_name": "electronics_data",
    }
    analysis_step = {
        "tool": "data_analysis",
        "parameters": {
            "data": "$electronics_data.filtered_data",
            "analysis_type": "detailed",
        },
        "output_name": "electronics_analysis",
    }
    composer.register_workflow("product_analysis", [fetch_step, filter_step, analysis_step])

    print("🚀 开始执行产品分析工作流...")
    results = await composer.execute_workflow("product_analysis")

    # Report the final step's output.
    print("\n📊 分析结果:")
    final_analysis = results['electronics_analysis']
    for label, field in (("统计信息", "statistics"), ("洞察", "insights")):
        print(f"{label}:{final_analysis[field]}")

if __name__ == "__main__":
    asyncio.run(main())

小结

工具设计的最佳实践可以总结为几个关键点:

🎯 核心原则

  1. 单一职责 - 一个工具只做一件事,但要做好
  2. 依赖注入 - 通过接口注入依赖,提高可测试性
  3. 参数验证 - 全面验证输入,防止错误和攻击
  4. 错误处理 - 统一、清晰、用户友好的错误处理
  5. 可组合性 - 设计可以组合的工具,支持复杂工作流

🛠️ 实践要点

  • 使用清晰的接口设计
  • 编写充分的单元测试
  • 提供详细的文档和示例
  • 考虑性能和安全性
  • 支持可观测性(日志、监控)

🎪 记住:好的工具设计不是一蹴而就的,需要在实践中不断迭代和改进。从简单开始,逐步完善!


下一节:Schema 设计艺术 - 学习如何设计清晰易用的参数规范