Skip to content

5.1 资源管理系统

🎯 学习目标:掌握MCP的资源管理机制,构建动态、高效的资源管理系统
⏱️ 预计时间:45分钟
📊 难度等级:⭐⭐⭐⭐

🔍 什么是MCP资源?

在MCP协议中,资源(Resources)是一个核心概念。它们是AI模型可以读取和分析的结构化数据源,包括文件、数据库记录、API响应等。

🎭 资源 vs 工具的区别

想象一下图书馆的场景:

📚 资源 = 图书馆里的书籍(静态信息源)
🔧 工具 = 图书管理员的服务(动态操作)

资源特点:

  • 📖 只读内容:向AI模型提供可读取的信息,读取本身不会修改任何状态
  • 🔗 URI标识:每个资源都有唯一标识符
  • 📊 结构化数据:有明确的格式和类型
  • 🔄 可更新:内容可以动态变化

工具特点:

  • ⚡ 动态操作:执行特定的任务
  • 📥 接受参数:根据输入产生输出
  • 🎯 执行动作:修改状态或产生副作用

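🏗️ 资源管理系统架构

整个系统分为三层:ResourceManager(对外入口,负责缓存与资源监控)→ ResourceRegistry(保存已注册的提供者、资源缓存和URI别名映射)→ 各个ResourceProvider(文件、目录、数据库、API等提供者,负责实际读取资源)。读取请求会先查缓存;未命中时,由注册表根据URI找到第一个支持该URI的提供者来读取。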

🎯 核心组件实现

📋 资源基类定义

首先,让我们定义资源管理系统的核心接口:

python
"""
resource_manager.py - MCP资源管理系统
"""

import asyncio
import fnmatch
import hashlib
import inspect
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Dict, List, Any, Optional, Union, AsyncIterator
from urllib.parse import unquote, urlparse

import aiofiles
import aiohttp
from loguru import logger

class ResourceType(Enum):
    """Category of content a resource holds."""
    # Plain text content.
    TEXT = "text"
    # Structured JSON content (the file provider parses it with json.loads).
    JSON = "json" 
    # Raw bytes; also the fallback when no other category applies.
    BINARY = "binary"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    # NOTE(review): DATABASE and API are not assigned by any provider shown
    # here (the database provider reports JSON) — confirm intended use.
    DATABASE = "database"
    API = "api"

class ResourceStatus(Enum):
    """Outcome / lifecycle state of a resource read."""
    # The content was read successfully.
    AVAILABLE = "available"
    # NOTE(review): LOADING, EXPIRED and UPDATING are not set by any code in
    # this file — presumably reserved for future use; confirm.
    LOADING = "loading"
    # The read failed; ResourceContent.error_message carries the reason.
    ERROR = "error"
    EXPIRED = "expired"
    UPDATING = "updating"

@dataclass
class ResourceMetadata:
    """Descriptive metadata attached to a resource.

    The first five fields are required; all others have defaults.
    """
    # Identifier of the resource (a URI such as ``file:///...`` or ``db://...``).
    uri: str
    # Short display name.
    name: str
    # Human-readable description.
    description: str
    # Content category.
    resource_type: ResourceType
    # MIME type, e.g. "text/plain" or "application/json".
    mime_type: str
    # Content size when known (file providers fill it from st_size, i.e. bytes).
    size: Optional[int] = None
    # Creation time; defaults to when this object was built (naive local time).
    created_at: datetime = field(default_factory=datetime.now)
    # Last-modification time; defaults to when this object was built.
    updated_at: datetime = field(default_factory=datetime.now)
    # Optional absolute expiry time; None means no explicit expiry.
    expires_at: Optional[datetime] = None
    # Version label.
    version: str = "1.0"
    # Free-form labels used for classification/search.
    tags: List[str] = field(default_factory=list)
    # Permission flags keyed by name — semantics not defined in this file (TODO confirm).
    permissions: Dict[str, bool] = field(default_factory=dict)
    # Content checksum if the provider computes one (the file provider uses an MD5 hex digest).
    checksum: Optional[str] = None

@dataclass 
class ResourceContent:
    """A resource payload together with its metadata and read status."""
    # Payload: str for text, bytes for binary content, or the parsed object for
    # JSON. json.loads can also produce a list, hence List[Any].
    data: Union[str, bytes, Dict[str, Any], List[Any]]
    # Metadata describing this payload.
    metadata: ResourceMetadata
    # AVAILABLE by default; ERROR when the read failed.
    status: ResourceStatus = ResourceStatus.AVAILABLE
    # Human-readable failure reason; None when the read succeeded.
    error_message: Optional[str] = None

class ResourceProvider(ABC):
    """Abstract base class for resource providers.

    A provider claims URIs via :meth:`supports_uri`, reads them with
    :meth:`read_resource`, enumerates them with :meth:`list_resources`, and can
    report changes through :meth:`watch_resource` (a polling default is
    provided; subclasses may override it with push-based watching).
    """

    def __init__(self, name: str):
        self.name = name
        # URI schemes this provider handles; subclasses fill this in.
        self.supported_schemes = set()

    @abstractmethod
    async def supports_uri(self, uri: str) -> bool:
        """Return True if this provider can serve ``uri``."""
        pass

    @abstractmethod
    async def read_resource(self, uri: str) -> ResourceContent:
        """Read ``uri`` and return its content together with metadata."""
        pass

    @abstractmethod
    async def list_resources(self, pattern: Optional[str] = None) -> List[ResourceMetadata]:
        """List available resources, optionally filtered by ``pattern``."""
        pass

    async def watch_resource(
        self,
        uri: str,
        poll_interval: float = 30,
        error_interval: float = 60,
    ) -> AsyncIterator[ResourceContent]:
        """Poll ``uri`` and yield its content whenever it changes.

        The first successful read is always yielded. Changes are detected via
        ``metadata.checksum`` when the provider sets one; otherwise a
        fingerprint of ``data`` is used, so providers that leave the checksum
        as ``None`` still report changes.

        Args:
            uri: Resource to watch.
            poll_interval: Seconds to wait between successful polls.
            error_interval: Seconds to wait after a failed read.

        Yields:
            The resource content each time it differs from the last yielded one.
        """
        # Sentinel: distinct from every fingerprint, including None, so the
        # first read is always reported.
        never_seen = object()
        last_fingerprint: Any = never_seen
        while True:
            try:
                content = await self.read_resource(uri)
                fingerprint = content.metadata.checksum
                if fingerprint is None:
                    fingerprint = self._fingerprint_data(content.data)

                if fingerprint != last_fingerprint:
                    last_fingerprint = fingerprint
                    yield content

                await asyncio.sleep(poll_interval)
            except Exception as e:
                logger.error(f"监控资源失败 {uri}: {e}")
                await asyncio.sleep(error_interval)

    @staticmethod
    def _fingerprint_data(data: Any) -> str:
        """Return an MD5 hex digest of ``data``, used only for change detection."""
        if isinstance(data, bytes):
            raw = data
        elif isinstance(data, str):
            raw = data.encode("utf-8", "surrogatepass")
        else:
            try:
                # default=str: rows may hold datetimes/decimals; only stability
                # between polls matters, not a faithful serialization.
                raw = json.dumps(data, sort_keys=True, default=str).encode("utf-8")
            except (TypeError, ValueError):
                raw = repr(data).encode("utf-8", "surrogatepass")
        return hashlib.md5(raw).hexdigest()

class ResourceRegistry:
    """Holds registered providers, the shared resource caches, and URI aliases."""

    def __init__(self):
        # Alias -> real URI; consulted by find_provider before matching.
        self.uri_mappings: Dict[str, str] = {}
        # Caches keyed by URI. This class only creates them; callers fill them.
        self.metadata_cache: Dict[str, ResourceMetadata] = {}
        self.resource_cache: Dict[str, ResourceContent] = {}
        # Provider name -> provider. Dict insertion order is the lookup order.
        self.providers: Dict[str, ResourceProvider] = {}

    def register_provider(self, provider: ResourceProvider):
        """Store ``provider`` under its name, replacing any same-named provider."""
        provider_name = provider.name
        self.providers[provider_name] = provider
        logger.info(f"注册资源提供者: {provider_name}")

    async def find_provider(self, uri: str) -> Optional[ResourceProvider]:
        """Return the first provider whose ``supports_uri`` accepts ``uri``.

        An alias is translated to its real URI before matching. Returns None
        when no provider accepts it.
        """
        target_uri = self.uri_mappings[uri] if uri in self.uri_mappings else uri

        for _, candidate in self.providers.items():
            accepted = await candidate.supports_uri(target_uri)
            if accepted:
                return candidate

        return None

    def add_uri_mapping(self, alias: str, actual_uri: str):
        """Register ``alias`` as another name for ``actual_uri``."""
        self.uri_mappings.update({alias: actual_uri})
        logger.debug(f"添加URI映射: {alias} -> {actual_uri}")

class ResourceManager:
    """Entry point for reading, listing, caching and watching resources.

    Providers are registered on ``self.registry``. Reads are served from the
    registry cache while fresh; otherwise they go to the first provider whose
    ``supports_uri`` accepts the (alias-resolved) URI.

    Args:
        cache_ttl: Seconds a cached entry stays fresh when its metadata has no
            explicit ``expires_at``. Counted from when the entry was cached.
    """

    def __init__(self, cache_ttl: int = 3600):
        self.registry = ResourceRegistry()
        self.cache_ttl = cache_ttl
        # watch_id -> task running that watch
        self.active_watchers: Dict[str, asyncio.Task] = {}
        # uri -> watch_id, so repeated watch requests for one URI share a task.
        self._watch_ids_by_uri: Dict[str, str] = {}
        # cache key -> time the entry was stored. The TTL is counted from here:
        # metadata.updated_at is the resource's own modification time (a
        # file's mtime, for instance), which would make any file untouched for
        # cache_ttl seconds expire the instant it is cached.
        self._cached_at: Dict[str, datetime] = {}

    def _resolve_alias(self, uri: str) -> str:
        """Return the real URI registered for alias ``uri``, or ``uri`` itself."""
        return self.registry.uri_mappings.get(uri, uri)

    async def read_resource(self, uri: str, use_cache: bool = True) -> ResourceContent:
        """Read ``uri``, serving it from cache when allowed and still fresh.

        Aliases are resolved before the provider is called; the cache is keyed
        by the URI the caller passed.

        Args:
            uri: Resource URI or a registered alias.
            use_cache: When False, bypass the cache for both lookup and store.

        Returns:
            The resource content. Failures are logged and returned as a
            ``ResourceContent`` with ``status=ERROR`` and ``error_message``
            set; this method does not raise.
        """
        try:
            if use_cache and uri in self.registry.resource_cache:
                cached = self.registry.resource_cache[uri]
                if not self._is_cache_expired(cached, uri):
                    logger.debug(f"资源缓存命中: {uri}")
                    return cached

            provider = await self.registry.find_provider(uri)
            if not provider:
                raise ValueError(f"未找到支持URI的提供者: {uri}")

            logger.debug(f"从提供者 {provider.name} 读取资源: {uri}")
            # find_provider matched against the alias target, so the provider
            # must receive that same resolved URI, not the alias.
            content = await provider.read_resource(self._resolve_alias(uri))

            if use_cache:
                self.registry.resource_cache[uri] = content
                self.registry.metadata_cache[uri] = content.metadata
                self._cached_at[uri] = datetime.now()

            return content

        except Exception as e:
            logger.error(f"读取资源失败 {uri}: {e}")
            return ResourceContent(
                data="",
                metadata=ResourceMetadata(
                    uri=uri,
                    name=Path(uri).name,
                    description=f"读取失败: {str(e)}",
                    resource_type=ResourceType.TEXT,
                    mime_type="text/plain"
                ),
                status=ResourceStatus.ERROR,
                error_message=str(e)
            )

    async def list_resources(self, pattern: Optional[str] = None) -> List[ResourceMetadata]:
        """List resources from every provider; a failing provider is logged and skipped."""
        all_resources = []

        for provider in self.registry.providers.values():
            try:
                resources = await provider.list_resources(pattern)
                all_resources.extend(resources)
            except Exception as e:
                logger.error(f"提供者 {provider.name} 列出资源失败: {e}")

        return all_resources

    async def watch_resource(self, uri: str, callback) -> str:
        """Start watching ``uri`` and call ``callback(content)`` on each change.

        ``callback`` may be a coroutine function or a plain function. An
        exception raised by it is logged and the watch keeps running.

        Returns:
            A watch id accepted by :meth:`stop_watching`. If ``uri`` already
            has a running watch, that watch's id is returned and no new task
            is started.

        Raises:
            ValueError: if no registered provider supports ``uri``.
        """
        existing = self._running_watch_id(uri)
        if existing is not None:
            return existing

        provider = await self.registry.find_provider(uri)
        if not provider:
            raise ValueError(f"未找到支持URI的提供者: {uri}")

        # Another caller may have started a watch while we were awaiting.
        existing = self._running_watch_id(uri)
        if existing is not None:
            return existing

        # Drop bookkeeping for a previous watch on this URI that has finished.
        stale_id = self._watch_ids_by_uri.pop(uri, None)
        if stale_id is not None:
            self.active_watchers.pop(stale_id, None)

        actual_uri = self._resolve_alias(uri)

        async def watch_loop():
            async for content in provider.watch_resource(actual_uri):
                try:
                    result = callback(content)
                    if inspect.isawaitable(result):
                        await result
                except Exception as e:
                    # One failing notification must not end the watch.
                    logger.error(f"监控回调执行失败 {uri}: {e}")

        task = asyncio.create_task(watch_loop())
        watch_id = f"watch_{uri}_{id(task)}"
        self.active_watchers[watch_id] = task
        self._watch_ids_by_uri[uri] = watch_id

        return watch_id

    def _running_watch_id(self, uri: str) -> Optional[str]:
        """Return the id of a still-running watch on ``uri``, or None."""
        watch_id = self._watch_ids_by_uri.get(uri)
        if watch_id is None:
            return None
        task = self.active_watchers.get(watch_id)
        if task is None or task.done():
            return None
        return watch_id

    def stop_watching(self, watch_id: str):
        """Cancel the watch ``watch_id``; unknown ids are ignored."""
        task = self.active_watchers.pop(watch_id, None)
        if task is None:
            return
        task.cancel()
        for watched_uri, known_id in list(self._watch_ids_by_uri.items()):
            if known_id == watch_id:
                del self._watch_ids_by_uri[watched_uri]
        logger.info(f"停止监控: {watch_id}")

    def _is_cache_expired(self, content: ResourceContent, uri: Optional[str] = None) -> bool:
        """Return True if a cached entry is stale.

        An explicit ``metadata.expires_at`` wins. Otherwise the entry is stale
        once ``cache_ttl`` seconds have passed since it was cached. Entries
        cached without a recorded time fall back to ``metadata.updated_at``.

        Args:
            content: The cached content.
            uri: Cache key the entry is stored under (defaults to
                ``content.metadata.uri``).
        """
        now = datetime.now()
        if content.metadata.expires_at:
            return now > content.metadata.expires_at

        key = uri if uri is not None else content.metadata.uri
        reference = self._cached_at.get(key, content.metadata.updated_at)
        return (now - reference).total_seconds() > self.cache_ttl

    def clear_cache(self, uri: Optional[str] = None):
        """Drop the cached entry for ``uri``, or every entry when ``uri`` is None."""
        if uri:
            self.registry.resource_cache.pop(uri, None)
            self.registry.metadata_cache.pop(uri, None)
            self._cached_at.pop(uri, None)
        else:
            self.registry.resource_cache.clear()
            self.registry.metadata_cache.clear()
            self._cached_at.clear()

        logger.info(f"缓存已清理: {uri or '全部'}")

📁 文件资源提供者

🗂️ 本地文件系统提供者

python
"""
file_provider.py - 文件资源提供者
"""

import os
import mimetypes
from pathlib import Path
from typing import List, Optional, Dict, Any
import aiofiles
import hashlib
from datetime import datetime

from .resource_manager import (
    ResourceProvider, ResourceContent, ResourceMetadata, 
    ResourceType, ResourceStatus
)

class FileResourceProvider(ResourceProvider):
    """Local-filesystem resource provider confined to a set of base directories.

    Accepted URI forms:

    * ``file:///abs/path`` or ``file://localhost/abs/path`` (percent-decoded)
    * a bare absolute path such as ``/abs/path``
    * a bare relative path such as ``docs/readme.md``, tried against each
      base path in order
    * the legacy ``file://docs/readme.md`` form, treated as relative

    Any path that resolves outside every base path is rejected.
    """

    def __init__(self, base_paths: List[str], name: str = "file_provider"):
        """
        Args:
            base_paths: Directories resources may be served from; resolved to
                absolute paths here.
            name: Provider name used as the registry key.
        """
        super().__init__(name)
        self.base_paths = [Path(p).resolve() for p in base_paths]
        self.supported_schemes = {"file"}

        # Extension -> ResourceType value. Extensions not listed fall back to
        # the MIME-type rules in _determine_resource_type.
        self.type_mappings = {
            'text': ['.txt', '.md', '.py', '.js', '.html', '.css', '.json', '.xml'],
            'image': ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.svg'],
            'video': ['.mp4', '.avi', '.mov', '.wmv', '.flv'],
            'audio': ['.mp3', '.wav', '.flac', '.aac'],
            'binary': []  # default category
        }

        logger.info(f"文件提供者初始化,基础路径: {self.base_paths}")

    async def supports_uri(self, uri: str) -> bool:
        """Return True if ``uri`` resolves to a location inside a base path.

        Uses the same resolution rules as :meth:`read_resource`, so an accepted
        URI is one read_resource can locate. File existence is not checked.
        """
        try:
            self._resolve_path(uri)
            return True
        except ValueError:
            # Unsupported scheme, empty path, or outside every base path.
            return False
        except Exception as e:
            logger.debug(f"URI支持检查失败 {uri}: {e}")
            return False

    async def read_resource(self, uri: str) -> ResourceContent:
        """Read a file resource.

        Text and JSON files are decoded as UTF-8 with universal newlines (the
        same result as opening in text mode). JSON is then parsed, falling
        back to TEXT on a parse error. A text file that is not valid UTF-8
        falls back to BINARY (raw bytes) instead of failing.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the path is not a regular file or cannot be resolved.
        """
        try:
            file_path = self._resolve_path(uri)

            if not file_path.exists():
                raise FileNotFoundError(f"文件不存在: {file_path}")

            if not file_path.is_file():
                raise ValueError(f"路径不是文件: {file_path}")

            stat_info = file_path.stat()
            mime_type, _ = mimetypes.guess_type(str(file_path))
            mime_type = mime_type or 'application/octet-stream'
            resource_type = self._determine_resource_type(file_path, mime_type)

            # Read the bytes once, so the checksum and the returned content
            # describe the same snapshot of the file.
            async with aiofiles.open(file_path, 'rb') as f:
                raw = await f.read()
            checksum = hashlib.md5(raw).hexdigest()

            content_data = raw
            if resource_type in (ResourceType.TEXT, ResourceType.JSON):
                try:
                    # Mirror text-mode reading: UTF-8 + universal newlines.
                    content_data = (
                        raw.decode('utf-8').replace('\r\n', '\n').replace('\r', '\n')
                    )
                except UnicodeDecodeError:
                    logger.warning(f"文本解码失败,回退到二进制: {file_path}")
                    resource_type = ResourceType.BINARY

            if resource_type == ResourceType.JSON:
                try:
                    content_data = json.loads(content_data)
                except json.JSONDecodeError:
                    logger.warning(f"JSON解析失败,回退到文本: {file_path}")
                    resource_type = ResourceType.TEXT

            metadata = ResourceMetadata(
                uri=uri,
                name=file_path.name,
                description=f"文件: {file_path}",
                resource_type=resource_type,
                mime_type=mime_type,
                size=stat_info.st_size,
                created_at=datetime.fromtimestamp(stat_info.st_ctime),
                updated_at=datetime.fromtimestamp(stat_info.st_mtime),
                checksum=checksum,
                tags=self._extract_tags(file_path)
            )

            return ResourceContent(
                data=content_data,
                metadata=metadata,
                status=ResourceStatus.AVAILABLE
            )

        except Exception as e:
            logger.error(f"读取文件资源失败 {uri}: {e}")
            raise

    async def list_resources(self, pattern: Optional[str] = None) -> List[ResourceMetadata]:
        """List every regular file under the base paths.

        Args:
            pattern: Optional glob, matched with ``Path.match`` (i.e. against
                the end of the path).

        Returns:
            One metadata entry per file. ``uri`` is an absolute ``file://``
            URI (``Path.as_uri()``), so it can be passed back to
            :meth:`read_resource`. Files whose stat fails are skipped with a
            warning.
        """
        resources = []

        for base_path in self.base_paths:
            if not base_path.exists():
                continue

            for file_path in base_path.rglob('*'):
                if not file_path.is_file():
                    continue

                if pattern and not file_path.match(pattern):
                    continue

                try:
                    stat_info = file_path.stat()
                    mime_type, _ = mimetypes.guess_type(str(file_path))
                    mime_type = mime_type or 'application/octet-stream'
                    resource_type = self._determine_resource_type(file_path, mime_type)
                    relative_path = file_path.relative_to(base_path)

                    resources.append(ResourceMetadata(
                        # "file://<relative>" would put the first path segment
                        # into the URI's host part; emit an absolute URI.
                        uri=file_path.as_uri(),
                        name=file_path.name,
                        description=f"文件: {relative_path}",
                        resource_type=resource_type,
                        mime_type=mime_type,
                        size=stat_info.st_size,
                        created_at=datetime.fromtimestamp(stat_info.st_ctime),
                        updated_at=datetime.fromtimestamp(stat_info.st_mtime),
                        tags=self._extract_tags(file_path)
                    ))

                except Exception as e:
                    logger.warning(f"处理文件失败 {file_path}: {e}")

        return resources

    def _uri_to_path_str(self, uri: str) -> str:
        """Extract the filesystem-path part of ``uri``.

        Raises:
            ValueError: if ``uri`` has a scheme other than ``file``.
        """
        parsed = urlparse(uri)
        if not parsed.scheme:
            # Bare path: used verbatim (no percent-decoding).
            return uri
        if parsed.scheme not in self.supported_schemes:
            raise ValueError(f"不支持的URI协议: {parsed.scheme}")

        path = unquote(parsed.path)
        if parsed.netloc and parsed.netloc != "localhost":
            # Legacy "file://dir/file.txt": urlparse moved "dir" into netloc.
            return unquote(parsed.netloc) + path
        return path

    @staticmethod
    def _is_inside(candidate: Path, base_path: Path) -> bool:
        """Return True if resolved ``candidate`` lies under ``base_path``."""
        try:
            candidate.relative_to(base_path)
            return True
        except ValueError:
            return False

    def _resolve_path(self, uri: str) -> Path:
        """Resolve ``uri`` to an absolute path inside one of the base paths.

        Relative paths are tried against each base path. The first candidate
        that exists wins; otherwise the first in-bounds candidate is returned.

        Raises:
            ValueError: if the URI cannot be mapped to a path inside a base path.
        """
        path_str = self._uri_to_path_str(uri)
        if not path_str:
            raise ValueError(f"无法解析路径: {uri}")

        if path_str.startswith('/'):
            absolute = Path(path_str).resolve()
            pairs = [(base_path, absolute) for base_path in self.base_paths]
        else:
            pairs = [(base_path, (base_path / path_str).resolve())
                     for base_path in self.base_paths]

        # resolve() collapses "..", so this also blocks traversal out of a base.
        in_bounds = [cand for base_path, cand in pairs if self._is_inside(cand, base_path)]
        if not in_bounds:
            raise ValueError(f"无法解析路径: {uri}")

        for cand in in_bounds:
            if cand.exists():
                return cand
        return in_bounds[0]

    async def _calculate_checksum(self, file_path: Path) -> str:
        """Return the MD5 hex digest of the file, read in 8 KiB chunks."""
        hash_md5 = hashlib.md5()

        async with aiofiles.open(file_path, 'rb') as f:
            while chunk := await f.read(8192):
                hash_md5.update(chunk)

        return hash_md5.hexdigest()

    def _determine_resource_type(self, file_path: Path, mime_type: str) -> ResourceType:
        """Classify a file: ``.json`` first, then extension table, then MIME prefix."""
        suffix = file_path.suffix.lower()

        if suffix == '.json':
            return ResourceType.JSON

        for type_name, suffixes in self.type_mappings.items():
            if suffix in suffixes:
                return ResourceType(type_name)

        if mime_type.startswith('text/'):
            return ResourceType.TEXT
        elif mime_type.startswith('image/'):
            return ResourceType.IMAGE
        elif mime_type.startswith('video/'):
            return ResourceType.VIDEO
        elif mime_type.startswith('audio/'):
            return ResourceType.AUDIO
        else:
            return ResourceType.BINARY

    def _extract_tags(self, file_path: Path) -> List[str]:
        """Derive tags from the extension, parent directory and file name."""
        tags = []

        # Extension without the dot.
        if file_path.suffix:
            tags.append(file_path.suffix[1:])

        # Parent directory name; empty at the filesystem root.
        parent_name = file_path.parent.name
        if parent_name and parent_name not in tags:
            tags.append(parent_name)

        # Keyword-based tags from the file stem.
        name_lower = file_path.stem.lower()
        if 'config' in name_lower:
            tags.append('configuration')
        if 'test' in name_lower:
            tags.append('testing')
        if 'doc' in name_lower or 'readme' in name_lower:
            tags.append('documentation')

        return tags

class DirectoryResourceProvider(FileResourceProvider):
    """File provider that returns a directory's tree as a JSON structure.

    Directory URIs yield nested ``name``/``type``/``path``/``children`` dicts.
    URIs naming a regular file are delegated to
    :meth:`FileResourceProvider.read_resource`.
    """

    async def read_resource(self, uri: str) -> ResourceContent:
        """Read a directory tree, or the file itself for a file URI.

        Raises:
            FileNotFoundError: if the path does not exist.
        """
        try:
            dir_path = self._resolve_path(uri)

            if not dir_path.exists():
                raise FileNotFoundError(f"目录不存在: {dir_path}")

            if not dir_path.is_dir():
                # Not a directory: let the plain file provider handle it.
                return await super().read_resource(uri)

            structure = await self._build_directory_structure(dir_path)

            metadata = ResourceMetadata(
                uri=uri,
                name=dir_path.name,
                description=f"目录结构: {dir_path}",
                resource_type=ResourceType.JSON,
                mime_type="application/json",
                tags=['directory', 'structure']
            )

            return ResourceContent(
                data=structure,
                metadata=metadata,
                status=ResourceStatus.AVAILABLE
            )

        except Exception as e:
            logger.error(f"读取目录资源失败 {uri}: {e}")
            raise

    async def _build_directory_structure(
        self, dir_path: Path, _ancestors: Optional[set] = None
    ) -> Dict[str, Any]:
        """Recursively describe ``dir_path`` as nested dicts.

        A symlinked directory that points back at one of its ancestors is
        reported with an ``error`` key instead of being descended into, so
        symlink cycles cannot recurse forever. An entry that cannot be
        inspected (e.g. a dangling symlink) becomes a child with
        ``"type": "unknown"`` and an ``error`` message rather than aborting
        the whole listing.

        Args:
            dir_path: Directory to describe.
            _ancestors: Internal; resolved paths on the current branch.
        """
        ancestors = set() if _ancestors is None else _ancestors
        structure = {
            "name": dir_path.name,
            "type": "directory",
            "path": str(dir_path),
            "children": []
        }

        real_path = dir_path.resolve()
        if real_path in ancestors:
            structure["error"] = "符号链接循环"
            return structure

        ancestors.add(real_path)
        try:
            try:
                entries = list(dir_path.iterdir())
            except PermissionError:
                structure["error"] = "权限不足"
                return structure

            for item in entries:
                try:
                    if item.is_dir():
                        structure["children"].append(
                            await self._build_directory_structure(item, ancestors)
                        )
                    else:
                        stat_info = item.stat()
                        structure["children"].append({
                            "name": item.name,
                            "type": "file",
                            "path": str(item),
                            "size": stat_info.st_size,
                            "modified": datetime.fromtimestamp(stat_info.st_mtime).isoformat()
                        })
                except (OSError, RuntimeError) as e:
                    # RuntimeError: Path.resolve() reports symlink loops this
                    # way on older Python versions.
                    structure["children"].append({
                        "name": item.name,
                        "type": "unknown",
                        "path": str(item),
                        "error": str(e)
                    })
        finally:
            # Only ancestors are tracked, so a directory reachable by two
            # non-cyclic paths is still listed under both.
            ancestors.discard(real_path)

        return structure

🗄️ 数据库资源提供者

python
"""
database_provider.py - 数据库资源提供者
"""

import asyncpg
import aiomysql
from typing import List, Optional, Dict, Any, Union
import json
from datetime import datetime

from .resource_manager import (
    ResourceProvider, ResourceContent, ResourceMetadata,
    ResourceType, ResourceStatus
)

class DatabaseResourceProvider(ResourceProvider):
    """Expose database tables and predefined queries as JSON resources.

    Supported URI forms (schemes ``db``, ``database``, ``sql``):

    * ``db://<table>?limit=&offset=&order_by=&where=``: rows from a table
    * ``db://query/<name>?<param>=<value>&...``: a query from
      :meth:`_get_custom_queries`. Parameter values are bound positionally,
      in the order their keys first appear in the URI.

    The connection pool is created lazily on first use.
    """

    # Table names and ORDER BY columns are spliced into SQL text, so they must
    # be plain (optionally schema-qualified) identifiers. Anything else is
    # rejected rather than escaped.
    _IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?")
    _ORDER_TERM_RE = re.compile(
        r"\s*[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?(?:\s+(?:ASC|DESC))?\s*",
        re.IGNORECASE,
    )
    # Tokens that enable statement stacking or comment-out truncation.
    _FORBIDDEN_WHERE_TOKENS = (";", "--", "/*", "*/")

    def __init__(self,
                 connection_config: Dict[str, Any],
                 db_type: str = "postgresql",
                 name: str = "database_provider"):
        """
        Args:
            connection_config: Keyword arguments for the driver's
                ``create_pool`` (asyncpg or aiomysql).
            db_type: ``"postgresql"`` or ``"mysql"`` (case-insensitive).
            name: Provider name used as the registry key.
        """
        super().__init__(name)
        self.connection_config = connection_config
        self.db_type = db_type.lower()
        self.supported_schemes = {"db", "database", "sql"}
        self.connection_pool = None

        logger.info(f"数据库提供者初始化: {db_type}")

    async def initialize(self):
        """Create the connection pool for the configured database type.

        Raises:
            ValueError: for an unsupported ``db_type``; driver errors are
                logged and re-raised.
        """
        try:
            if self.db_type == "postgresql":
                self.connection_pool = await asyncpg.create_pool(**self.connection_config)
            elif self.db_type == "mysql":
                self.connection_pool = await aiomysql.create_pool(**self.connection_config)
            else:
                raise ValueError(f"不支持的数据库类型: {self.db_type}")

            logger.info("数据库连接池初始化成功")

        except Exception as e:
            logger.error(f"数据库连接池初始化失败: {e}")
            raise

    async def supports_uri(self, uri: str) -> bool:
        """Return True if ``uri`` uses one of the database schemes."""
        try:
            return urlparse(uri).scheme in self.supported_schemes
        except ValueError:
            # urlparse rejects e.g. malformed IPv6 netlocs.
            return False

    async def read_resource(self, uri: str) -> ResourceContent:
        """Read a table or a predefined query and return the rows as JSON data.

        Raises:
            ValueError: if the URI names no table/query, or contains invalid
                identifiers or parameters.
        """
        if not self.connection_pool:
            await self.initialize()

        try:
            parsed = urlparse(uri)
            # urlparse puts the first segment of "db://table" into netloc, so
            # the name must be rebuilt from netloc + path.
            path_parts = [part for part in f"{parsed.netloc}{parsed.path}".split('/') if part]

            if not path_parts:
                raise ValueError("URI必须指定表名或查询")

            if path_parts[0] == "query" and len(path_parts) > 1:
                query_name = path_parts[1]
                data = await self._execute_custom_query(query_name, parsed.query)
                description = f"自定义查询: {query_name}"
                resource_name = query_name
            else:
                table_name = path_parts[0]
                data = await self._query_table(table_name, parsed.query)
                description = f"数据表: {table_name}"
                resource_name = table_name

            metadata = ResourceMetadata(
                uri=uri,
                name=resource_name,
                description=description,
                resource_type=ResourceType.JSON,
                mime_type="application/json",
                tags=['database', self.db_type, 'query']
            )

            return ResourceContent(
                data=data,
                metadata=metadata,
                status=ResourceStatus.AVAILABLE
            )

        except Exception as e:
            logger.error(f"读取数据库资源失败 {uri}: {e}")
            raise

    async def list_resources(self, pattern: Optional[str] = None) -> List[ResourceMetadata]:
        """List tables and predefined queries, filtered by a glob ``pattern``.

        Errors are logged and whatever was collected so far is returned.
        """
        if not self.connection_pool:
            await self.initialize()

        resources = []

        try:
            tables = await self._get_tables()

            for table_info in tables:
                table_name = table_info['table_name']

                # Glob-style match (fnmatch), consistent with the file provider.
                if pattern and not fnmatch.fnmatch(table_name, pattern):
                    continue

                resources.append(ResourceMetadata(
                    uri=f"db://{table_name}",
                    name=table_name,
                    description=f"数据表: {table_name}",
                    resource_type=ResourceType.JSON,
                    mime_type="application/json",
                    tags=['database', self.db_type, 'table']
                ))

            custom_queries = await self._get_custom_queries()
            for query_name, query_info in custom_queries.items():
                if pattern and not fnmatch.fnmatch(query_name, pattern):
                    continue
                resources.append(ResourceMetadata(
                    uri=f"db://query/{query_name}",
                    name=query_name,
                    description=query_info.get('description', f"自定义查询: {query_name}"),
                    resource_type=ResourceType.JSON,
                    mime_type="application/json",
                    tags=['database', self.db_type, 'query', 'custom']
                ))

        except Exception as e:
            logger.error(f"列出数据库资源失败: {e}")

        return resources

    @staticmethod
    def _parse_query_string(query_string: Optional[str]) -> Dict[str, str]:
        """Split ``a=1&b=2`` into a percent-decoded dict.

        Pairs without ``=`` are ignored; a repeated key keeps its last value.
        """
        params: Dict[str, str] = {}
        for pair in (query_string or "").split('&'):
            if '=' in pair:
                key, value = pair.split('=', 1)
                params[unquote(key)] = unquote(value)
        return params

    async def _query_table(self, table_name: str, query_params: str) -> Dict[str, Any]:
        """Select rows from ``table_name`` with optional paging/order/filter.

        Raises:
            ValueError: for a non-identifier table name or ORDER BY term, a
                negative or non-integer limit/offset, or a WHERE clause
                containing statement separators or comment markers.
        """
        params = self._parse_query_string(query_params)

        limit = int(params.get('limit', 100))
        offset = int(params.get('offset', 0))
        if limit < 0 or offset < 0:
            raise ValueError(f"limit/offset 不能为负数: limit={limit}, offset={offset}")
        order_by = params.get('order_by', '')
        where_clause = params.get('where', '')

        if not self._IDENTIFIER_RE.fullmatch(table_name):
            raise ValueError(f"非法的表名: {table_name}")

        query = f"SELECT * FROM {table_name}"

        if where_clause:
            # NOTE(security): WHERE is still raw SQL from the URI. This check
            # only blocks statement stacking and comment truncation; it is NOT
            # full injection protection (e.g. UNION remains possible). Expose
            # it only to trusted callers or replace it with a column whitelist.
            if any(token in where_clause for token in self._FORBIDDEN_WHERE_TOKENS):
                raise ValueError(f"WHERE子句包含不允许的内容: {where_clause}")
            query += f" WHERE {where_clause}"

        if order_by:
            if not all(self._ORDER_TERM_RE.fullmatch(term) for term in order_by.split(',')):
                raise ValueError(f"非法的排序字段: {order_by}")
            query += f" ORDER BY {order_by}"

        query += f" LIMIT {limit} OFFSET {offset}"

        async with self.connection_pool.acquire() as conn:
            if self.db_type == "postgresql":
                rows = await conn.fetch(query)
                data = [dict(row) for row in rows]
            else:  # mysql
                async with conn.cursor() as cursor:
                    await cursor.execute(query)
                    rows = await cursor.fetchall()
                    columns = [desc[0] for desc in cursor.description]
                    data = [dict(zip(columns, row)) for row in rows]

        return {
            "table": table_name,
            "rows": data,
            "count": len(data),
            "limit": limit,
            "offset": offset,
            "query": query
        }

    async def _execute_custom_query(self, query_name: str, params: str) -> Dict[str, Any]:
        """Run a predefined query with URI parameters bound positionally.

        Raises:
            ValueError: if ``query_name`` is not a predefined query.
        """
        custom_queries = await self._get_custom_queries()

        if query_name not in custom_queries:
            raise ValueError(f"未找到自定义查询: {query_name}")

        query_info = custom_queries[query_name]
        query = query_info['query']

        query_params = self._parse_query_string(params)
        values = list(query_params.values())

        async with self.connection_pool.acquire() as conn:
            if self.db_type == "postgresql":
                rows = await conn.fetch(query, *values)
                data = [dict(row) for row in rows]
            else:  # mysql
                async with conn.cursor() as cursor:
                    # aiomysql %-formats the SQL whenever args is not None;
                    # pass None without params so a literal '%' survives.
                    await cursor.execute(query, values or None)
                    rows = await cursor.fetchall()
                    columns = [desc[0] for desc in cursor.description]
                    data = [dict(zip(columns, row)) for row in rows]

        return {
            "query_name": query_name,
            "description": query_info.get('description', ''),
            "rows": data,
            "count": len(data),
            "parameters": query_params
        }

    async def _get_tables(self) -> List[Dict[str, Any]]:
        """Return ``{'table_name', 'table_type'}`` dicts from information_schema."""
        if self.db_type == "postgresql":
            query = """
                SELECT table_name, table_type 
                FROM information_schema.tables 
                WHERE table_schema = 'public'
            """
        else:  # mysql
            query = """
                SELECT table_name, table_type 
                FROM information_schema.tables 
                WHERE table_schema = DATABASE()
            """

        async with self.connection_pool.acquire() as conn:
            if self.db_type == "postgresql":
                rows = await conn.fetch(query)
                return [dict(row) for row in rows]
            else:  # mysql
                async with conn.cursor() as cursor:
                    await cursor.execute(query)
                    rows = await cursor.fetchall()
                    return [{'table_name': row[0], 'table_type': row[1]} for row in rows]

    async def _get_custom_queries(self) -> Dict[str, Dict[str, Any]]:
        """Return the predefined queries (name -> {'query', 'description'}).

        Hard-coded example configuration; could be loaded from a config file
        or table instead.
        """
        # Interval literal syntax differs: MySQL rejects INTERVAL '7 days'.
        last_week = (
            "NOW() - INTERVAL 7 DAY" if self.db_type == "mysql"
            else "NOW() - INTERVAL '7 days'"
        )
        return {
            "user_summary": {
                "query": "SELECT COUNT(*) as total_users, AVG(age) as avg_age FROM users",
                "description": "用户统计摘要"
            },
            "recent_orders": {
                "query": f"SELECT * FROM orders WHERE created_at > {last_week} ORDER BY created_at DESC",
                "description": "最近7天的订单"
            }
        }

    async def close(self):
        """Close the pool; a later read re-initializes it lazily."""
        if self.connection_pool:
            if self.db_type == "postgresql":
                await self.connection_pool.close()
            else:  # mysql
                self.connection_pool.close()
                await self.connection_pool.wait_closed()

            # Reset so read_resource/list_resources don't reuse a closed pool.
            self.connection_pool = None
            logger.info("数据库连接池已关闭")

🌐 API资源提供者

python
"""
api_provider.py - API资源提供者
"""

import aiohttp
import asyncio
from typing import List, Optional, Dict, Any, Union
import json
from datetime import datetime, timedelta
from urllib.parse import urlparse, urlencode

from .resource_manager import (
    ResourceProvider, ResourceContent, ResourceMetadata,
    ResourceType, ResourceStatus
)

class APIResourceProvider(ResourceProvider):
    """Resource provider that reads resources over HTTP(S).

    Two URI forms are handled:

    * ``http://...`` / ``https://...`` -- fetched directly with GET.
    * ``api://<api_name>/<endpoint_name>?k=v`` -- resolved through
      ``api_configs[api_name]["endpoints"][endpoint_name]``.

    Keys of ``api_configs`` read by this class::

        {
            "<api_name>": {
                "base_url": "https://api.example.com",
                "headers": {...},                      # optional
                "auth": {"type": "bearer", "token": "..."}
                     or {"type": "api_key", "key": "...",
                         "key_name": "api_key",        # optional
                         "in": "header"},              # anything else -> query
                "endpoints": {
                    "<endpoint_name>": {
                        "path": "/v1/items",
                        "method": "GET",               # optional, default GET
                        "default_params": {...},       # optional
                        "description": "...",          # optional
                    },
                },
            },
        }
    """

    # NOTE(review): this module's import block does not import ``logger``,
    # although every method below logs through it; add
    # ``from loguru import logger`` next to the other imports, as
    # resource_manager.py does.

    # (remaining-header, reset-header) pairs checked by _update_rate_limit,
    # in order; the first remaining-header present wins.
    _RATE_LIMIT_HEADERS = (
        ('x-ratelimit-remaining', 'x-ratelimit-reset'),    # GitHub style
        ('x-rate-limit-remaining', 'x-rate-limit-reset'),  # Twitter style
    )

    def __init__(self,
                 api_configs: Dict[str, Dict[str, Any]],
                 name: str = "api_provider"):
        super().__init__(name)
        self.api_configs = api_configs
        self.supported_schemes = {"http", "https", "api"}
        # Shared aiohttp session; created lazily by initialize()
        self.session: Optional[aiohttp.ClientSession] = None
        # api_name -> {'remaining', 'reset_time', 'updated_at'} taken from the
        # most recent response of that API
        self.rate_limits: Dict[str, Dict[str, Any]] = {}

        logger.info(f"API提供者初始化,配置的API数量: {len(api_configs)}")

    async def initialize(self):
        """Create the shared HTTP session used by every request.

        Does nothing if an open session already exists, so calling it more than
        once does not leak a second session/connector.
        """
        if self.session is not None and not self.session.closed:
            return

        connector = aiohttp.TCPConnector(
            limit=100,            # total connection pool size
            limit_per_host=30,    # connections per host
            ttl_dns_cache=300,    # DNS cache TTL (seconds)
            use_dns_cache=True
        )

        timeout = aiohttp.ClientTimeout(total=30, connect=10)  # seconds

        self.session = aiohttp.ClientSession(
            connector=connector,
            timeout=timeout,
            headers={'User-Agent': 'MCP-ResourceProvider/1.0'}
        )

        logger.info("HTTP会话初始化完成")

    async def supports_uri(self, uri: str) -> bool:
        """Return True for http/https URIs and for ``api://`` URIs whose API name is configured."""
        try:
            parsed = urlparse(uri)
        except (TypeError, ValueError):
            # urlparse rejects some malformed input (e.g. unbalanced IPv6
            # brackets); a predicate should answer False rather than raise.
            return False

        if parsed.scheme in ('http', 'https'):
            return True

        # Named resource from api_configs: api://<api_name>/...
        if parsed.scheme == 'api':
            return parsed.netloc in self.api_configs

        return False

    async def read_resource(self, uri: str) -> ResourceContent:
        """Fetch *uri* and wrap the response body in a ResourceContent.

        ``api://`` URIs go through the configured endpoint; any other URI is
        fetched directly with GET. HTTP error statuses (>= 400) do not raise:
        they are returned with ``status=ResourceStatus.ERROR`` and an
        ``"HTTP <code>"`` error message.

        Raises:
            ValueError: the ``api://`` API name or endpoint is not configured.
            Any other exception (network, body decoding) is logged and re-raised.
        """
        # (Re)create the session on first use or after close().
        if self.session is None or self.session.closed:
            await self.initialize()

        try:
            if urlparse(uri).scheme == 'api':
                return await self._read_configured_api(uri)
            return await self._read_http_resource(uri)
        except Exception as e:
            logger.error(f"读取API资源失败 {uri}: {e}")
            raise

    async def list_resources(self, pattern: Optional[str] = None) -> List[ResourceMetadata]:
        """List one JSON resource per configured endpoint.

        Args:
            pattern: optional substring filter applied to each resource URI
                (``api://<api_name>/<endpoint_name>``), the same rule
                MemoryResourceProvider.list_resources applies. ``None`` or
                ``""`` lists everything.

        Returns:
            Metadata for each matching endpoint; no request is made.
        """
        resources = []

        for api_name, config in self.api_configs.items():
            for endpoint_name, endpoint_config in config.get('endpoints', {}).items():
                uri = f"api://{api_name}/{endpoint_name}"
                if pattern and pattern not in uri:
                    continue

                resources.append(ResourceMetadata(
                    uri=uri,
                    name=f"{api_name}.{endpoint_name}",
                    description=endpoint_config.get('description', f"API端点: {endpoint_name}"),
                    resource_type=ResourceType.JSON,
                    mime_type="application/json",
                    tags=['api', api_name, endpoint_name]
                ))

        return resources

    async def _read_configured_api(self, uri: str) -> ResourceContent:
        """Call the endpoint configured for an ``api://<api_name>/<endpoint>`` URI.

        The request URL is ``base_url`` joined with the endpoint's ``path``.
        Query parameters are the endpoint's ``default_params`` overridden by
        the URI's own query string; an ``api_key`` configured for the query
        string is applied last. Auth headers are added to a per-request copy
        of the API's ``headers``.

        Raises:
            ValueError: unknown API name or endpoint.
        """
        # Local import: this module's import block only brings in urlparse/urlencode.
        from urllib.parse import parse_qsl

        parsed = urlparse(uri)
        api_name = parsed.netloc
        endpoint_path = parsed.path.strip('/')

        if api_name not in self.api_configs:
            raise ValueError(f"未找到API配置: {api_name}")
        api_config = self.api_configs[api_name]

        endpoints = api_config.get('endpoints', {})
        if endpoint_path not in endpoints:
            raise ValueError(f"未找到端点配置: {endpoint_path}")
        endpoint_config = endpoints[endpoint_path]

        # Wait out a rate limit reported by an earlier response, if any
        await self._check_rate_limit(api_name)

        base_url = api_config['base_url']
        url = f"{base_url.rstrip('/')}/{endpoint_config['path'].lstrip('/')}"

        # Defaults first, then the URI query string, so explicitly requested
        # values win. parse_qsl percent-decodes the values, so an already
        # encoded value (e.g. %20) is not encoded a second time when aiohttp
        # builds the query string.
        params = dict(endpoint_config.get('default_params', {}))
        params.update(parse_qsl(parsed.query, keep_blank_values=True))

        # Copy: auth headers must not be written back into the shared config.
        headers = dict(api_config.get('headers', {}))
        auth_config = api_config.get('auth', {})
        auth_type = auth_config.get('type')

        if auth_type == 'bearer':
            headers['Authorization'] = f"Bearer {auth_config['token']}"
        elif auth_type == 'api_key':
            key_name = auth_config.get('key_name', 'api_key')
            if auth_config.get('in') == 'header':
                headers[key_name] = auth_config['key']
            else:
                params[key_name] = auth_config['key']

        method = endpoint_config.get('method', 'GET').upper()

        async with self.session.request(
            method=method,
            url=url,
            params=params,
            headers=headers
        ) as response:
            # Record rate-limit headers before decoding, so they are kept even
            # if the body turns out to be undecodable.
            self._update_rate_limit(api_name, response.headers)

            content_type = response.headers.get('content-type', '').lower()

            if 'application/json' in content_type:
                data = await response.json()
                resource_type = ResourceType.JSON
            else:
                data = await response.text()
                resource_type = ResourceType.TEXT

            ok = response.status < 400

            metadata = ResourceMetadata(
                uri=uri,
                name=f"{api_name}.{endpoint_path}",
                description=endpoint_config.get('description', f"API响应: {endpoint_path}"),
                resource_type=resource_type,
                mime_type=content_type,
                tags=['api', api_name, endpoint_path, f'status_{response.status}']
            )

            return ResourceContent(
                data=data,
                metadata=metadata,
                status=ResourceStatus.AVAILABLE if ok else ResourceStatus.ERROR,
                error_message=None if ok else f"HTTP {response.status}"
            )

    async def _read_http_resource(self, uri: str) -> ResourceContent:
        """GET a plain http(s) URL.

        The body is decoded as JSON for ``application/json``, as text for
        ``text/*``, and returned as raw bytes otherwise.
        """
        async with self.session.get(uri) as response:
            content_type = response.headers.get('content-type', '').lower()

            if 'application/json' in content_type:
                data = await response.json()
                resource_type = ResourceType.JSON
            elif content_type.startswith('text/'):
                data = await response.text()
                resource_type = ResourceType.TEXT
            else:
                data = await response.read()
                resource_type = ResourceType.BINARY

            parsed = urlparse(uri)
            ok = response.status < 400

            metadata = ResourceMetadata(
                uri=uri,
                # Last path segment, or the host when the path is empty or ends in '/'
                name=parsed.path.split('/')[-1] or parsed.netloc,
                description=f"HTTP资源: {uri}",
                resource_type=resource_type,
                mime_type=content_type,
                tags=['http', parsed.netloc, f'status_{response.status}']
            )

            return ResourceContent(
                data=data,
                metadata=metadata,
                status=ResourceStatus.AVAILABLE if ok else ResourceStatus.ERROR,
                error_message=None if ok else f"HTTP {response.status}"
            )

    async def _check_rate_limit(self, api_name: str):
        """Sleep until the reset time if the last response reported no remaining requests."""
        info = self.rate_limits.get(api_name)
        if not info or info.get('remaining', 1) > 0:
            return

        reset_time = info.get('reset_time')
        if reset_time is None:
            return

        wait_time = (reset_time - datetime.now()).total_seconds()
        if wait_time > 0:
            logger.warning(f"API {api_name} 速率限制,等待 {wait_time:.1f} 秒")
            await asyncio.sleep(wait_time)

    def _update_rate_limit(self, api_name: str, headers: Dict[str, str]):
        """Record remaining-request count and reset time from response headers.

        Header lookups rely on *headers* matching the lower-case names in
        ``_RATE_LIMIT_HEADERS`` (aiohttp response headers are
        case-insensitive). Malformed values are ignored: rate limiting here is
        best-effort and must not fail an otherwise successful read.
        """
        for remaining_key, reset_key in self._RATE_LIMIT_HEADERS:
            if remaining_key not in headers:
                continue

            try:
                remaining = int(headers[remaining_key])
            except (TypeError, ValueError):
                return

            # Assumes the reset header is a Unix timestamp (as GitHub sends);
            # APIs that send "seconds until reset" are not handled.
            reset_time = None
            reset_raw = headers.get(reset_key)
            if reset_raw is not None:
                try:
                    reset_time = datetime.fromtimestamp(int(reset_raw))
                except (TypeError, ValueError, OverflowError, OSError):
                    reset_time = None

            self.rate_limits[api_name] = {
                'remaining': remaining,
                'reset_time': reset_time,
                'updated_at': datetime.now()
            }
            return

    async def close(self):
        """Close the HTTP session; a later read_resource() creates a new one."""
        if self.session is not None:
            await self.session.close()
            self.session = None
            logger.info("HTTP会话已关闭")

🧠 内存资源提供者

python
"""
memory_provider.py - 内存资源提供者
"""

import asyncio
import json
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from urllib.parse import urlparse

from loguru import logger

from .resource_manager import (
    ResourceProvider, ResourceContent, ResourceMetadata,
    ResourceType, ResourceStatus
)

class MemoryResourceProvider(ResourceProvider):
    """In-process resource store for ``memory://``, ``mem://`` and ``cache://`` URIs.

    Content is kept in a plain dict keyed by the full URI string. An entry may
    carry an expiry time (``metadata.expires_at``, set from ``ttl`` in
    ``store_resource``). Expired entries are hidden from reads and listings,
    and a periodic background sweep started by ``initialize`` removes them.
    """

    def __init__(self, name: str = "memory_provider", cleanup_interval: float = 60.0):
        """
        Args:
            name: provider name passed to ``ResourceProvider``.
            cleanup_interval: seconds between background expiry sweeps.
        """
        super().__init__(name)
        self.supported_schemes = {"memory", "mem", "cache"}
        # uri -> stored content
        self.resources: Dict[str, ResourceContent] = {}
        # uri -> number of successful read_resource() calls
        self.access_counts: Dict[str, int] = defaultdict(int)
        # Background expiry sweeper; created by initialize()
        self.cleanup_task: Optional[asyncio.Task] = None
        self.cleanup_interval = cleanup_interval

        logger.info("内存资源提供者初始化")

    @staticmethod
    def _is_expired(content: ResourceContent, now: datetime) -> bool:
        """Return True if *content* has an expiry time earlier than *now*."""
        expires_at = content.metadata.expires_at
        return expires_at is not None and now > expires_at

    async def initialize(self):
        """Start the background expiry sweep unless one is already running."""
        # done() also covers a task that finished, so the sweep can be restarted
        if self.cleanup_task is None or self.cleanup_task.done():
            self.cleanup_task = asyncio.create_task(self._cleanup_expired_resources())
            logger.info("内存资源清理任务已启动")

    async def supports_uri(self, uri: str) -> bool:
        """Return True when the URI scheme is one of ``supported_schemes``."""
        try:
            return urlparse(uri).scheme in self.supported_schemes
        except (TypeError, ValueError):
            # A predicate answers False for unparsable input instead of raising.
            return False

    async def read_resource(self, uri: str) -> ResourceContent:
        """Return the stored content for *uri* and bump its access counter.

        Raises:
            ValueError: no such resource, or it has expired (an expired entry
                is deleted before raising).
        """
        resource = self.resources.get(uri)
        if resource is None:
            raise ValueError(f"内存中未找到资源: {uri}")

        if self._is_expired(resource, datetime.now()):
            # Evict through delete_resource so the access counter goes too.
            await self.delete_resource(uri)
            raise ValueError(f"资源已过期: {uri}")

        self.access_counts[uri] += 1
        return resource

    async def list_resources(self, pattern: Optional[str] = None) -> List[ResourceMetadata]:
        """List metadata of non-expired resources whose URI contains *pattern*.

        ``None`` or ``""`` as *pattern* lists everything.
        """
        now = datetime.now()
        return [
            content.metadata
            for uri, content in self.resources.items()
            if not self._is_expired(content, now) and (not pattern or pattern in uri)
        ]

    async def store_resource(self,
                           uri: str,
                           data: Union[str, bytes, Dict[str, Any]],
                           metadata: Optional[ResourceMetadata] = None,
                           ttl: Optional[int] = None) -> ResourceContent:
        """Store *data* under *uri*, replacing any existing entry.

        Args:
            uri: key of the resource.
            data: payload; dict -> JSON, bytes -> binary, anything else -> text
                when metadata has to be generated.
            metadata: metadata to attach; generated from *data* when omitted.
                Note that a supplied object is modified in place when *ttl*
                is given.
            ttl: lifetime in seconds; ``None`` or ``0`` means no expiry.

        Returns:
            The stored ResourceContent.
        """
        if metadata is None:
            # Infer the resource type from the Python type of the payload
            if isinstance(data, dict):
                resource_type = ResourceType.JSON
                mime_type = "application/json"
            elif isinstance(data, bytes):
                resource_type = ResourceType.BINARY
                mime_type = "application/octet-stream"
            else:
                resource_type = ResourceType.TEXT
                mime_type = "text/plain"

            metadata = ResourceMetadata(
                uri=uri,
                # Last path segment; ignore a trailing '/' and fall back to the
                # full URI so the name is never empty.
                name=uri.rstrip('/').split('/')[-1] or uri,
                description=f"内存资源: {uri}",
                resource_type=resource_type,
                mime_type=mime_type,
                tags=['memory', 'cached']
            )

        if ttl:
            metadata.expires_at = datetime.now() + timedelta(seconds=ttl)

        content = ResourceContent(
            data=data,
            metadata=metadata,
            status=ResourceStatus.AVAILABLE
        )

        self.resources[uri] = content
        logger.debug(f"资源已存储到内存: {uri}")

        return content

    async def update_resource(self, uri: str, data: Union[str, bytes, Dict[str, Any]]) -> ResourceContent:
        """Replace the payload of an existing resource and refresh ``updated_at``.

        The rest of the metadata (type, MIME type, expiry) is left unchanged.

        Raises:
            ValueError: no resource is stored under *uri*.
        """
        resource = self.resources.get(uri)
        if resource is None:
            raise ValueError(f"内存中未找到资源: {uri}")

        resource.data = data
        resource.metadata.updated_at = datetime.now()

        logger.debug(f"资源已更新: {uri}")
        return resource

    async def delete_resource(self, uri: str) -> bool:
        """Remove *uri* and its access counter; return whether it existed."""
        if uri not in self.resources:
            return False

        del self.resources[uri]
        self.access_counts.pop(uri, None)
        logger.debug(f"资源已删除: {uri}")
        return True

    async def _cleanup_expired_resources(self):
        """Background loop: every ``cleanup_interval`` seconds, delete expired resources."""
        while True:
            try:
                now = datetime.now()
                # Collect first: deleting while iterating the dict would fail.
                expired_uris = [
                    uri for uri, content in self.resources.items()
                    if self._is_expired(content, now)
                ]

                for uri in expired_uris:
                    await self.delete_resource(uri)
                    logger.debug(f"清理过期资源: {uri}")

                if expired_uris:
                    logger.info(f"清理了 {len(expired_uris)} 个过期资源")

                await asyncio.sleep(self.cleanup_interval)

            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the sweeper alive after unexpected errors
                logger.error(f"清理过期资源时出错: {e}")
                await asyncio.sleep(self.cleanup_interval)

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize stored resources: count, approximate size, types, access counts.

        ``total_size_bytes`` is the UTF-8 size of str payloads, the length of
        bytes payloads and the UTF-8 size of the JSON encoding of dict payloads.
        """
        total_size = 0
        type_counts: Dict[str, int] = defaultdict(int)

        for content in self.resources.values():
            data = content.data
            if isinstance(data, str):
                total_size += len(data.encode('utf-8'))
            elif isinstance(data, bytes):
                total_size += len(data)
            elif isinstance(data, dict):
                # ensure_ascii=False: measure real UTF-8 bytes instead of \uXXXX
                # escapes; default=str keeps non-JSON values from breaking stats.
                total_size += len(json.dumps(data, ensure_ascii=False, default=str).encode('utf-8'))

            type_counts[content.metadata.resource_type.value] += 1

        return {
            "total_resources": len(self.resources),
            "total_size_bytes": total_size,
            "type_counts": dict(type_counts),
            "access_counts": dict(self.access_counts),
            "most_accessed": sorted(
                self.access_counts.items(),
                key=lambda x: x[1],
                reverse=True
            )[:5]
        }

    async def close(self):
        """Stop the background sweep and drop all stored resources."""
        if self.cleanup_task is not None:
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass
            # Allow initialize() to start a fresh sweep later
            self.cleanup_task = None

        self.resources.clear()
        self.access_counts.clear()
        logger.info("内存资源提供者已关闭")

🎯 本节小结

通过这一小节,你已经构建了一个完整的资源管理系统:

- 统一资源接口:标准化的资源访问模式
- 多种资源提供者:文件、数据库、API、内存
- 智能缓存机制:提升资源访问性能
- 动态资源发现:自动发现和注册资源
- 资源监控:实时监控资源变化

🚀 使用示例

python
import asyncio


async def main():
    """Register file, database and API providers, then read one resource from each.

    `await` is only valid inside a coroutine, so the example runs inside
    main() and is started with asyncio.run().
    """
    # Create the resource manager
    resource_manager = ResourceManager()

    # Create the providers
    file_provider = FileResourceProvider(["/data", "/docs"])
    db_provider = DatabaseResourceProvider({
        "host": "localhost",
        "database": "myapp"
    })
    api_provider = APIResourceProvider({
        "github": {
            "base_url": "https://api.github.com",
            "endpoints": {
                # "api://github/repos" -> GET https://api.github.com/repositories
                "repos": {
                    "path": "/repositories",
                    "method": "GET",
                    "description": "GitHub public repositories"
                }
            }
        }
    })

    # Register the providers
    resource_manager.registry.register_provider(file_provider)
    resource_manager.registry.register_provider(db_provider)
    resource_manager.registry.register_provider(api_provider)

    try:
        # Read resources through the manager
        content = await resource_manager.read_resource("file:///data/config.json")
        db_data = await resource_manager.read_resource("db://users?limit=10")
        api_data = await resource_manager.read_resource("api://github/repos")
    finally:
        # Release the HTTP session and the database connection pool
        await api_provider.close()
        await db_provider.close()


if __name__ == "__main__":
    asyncio.run(main())

现在你的MCP服务器具备了强大的资源管理能力,可以统一管理各种类型的数据源!


👉 下一小节:5.2 提示词模板系统