Python Project: A Web Scraper and Data Visualization System Built with Python


1. Project Overview

In today's data-driven world, web scraping and data visualization have become essential tools for collecting, analyzing, and presenting information. This article walks through building a complete web scraping and data visualization system in Python: one that automatically gathers data from the web, processes and analyzes it, and presents the results as intuitive charts.

2. Technology Stack

  • Python 3.8+: primary programming language
  • Web scraping: Requests, BeautifulSoup4, Scrapy, Selenium
  • Data processing: Pandas, NumPy
  • Data visualization: Matplotlib, Seaborn, Plotly, Dash
  • Data storage: SQLite, MongoDB
  • Other tools: Jupyter Notebook, Flask
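
The exact dependencies depend on which parts of the system you build. As a starting point, here is a minimal requirements.txt sketch covering the libraries used in the code below (illustrative and unpinned; add version pins as needed and install with pip install -r requirements.txt):

# requirements.txt (illustrative, unpinned)
requests
beautifulsoup4
scrapy
selenium
webdriver-manager
fake-useragent
pandas
numpy
matplotlib
seaborn
plotly
dash
pymongo
flask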

3. System Architecture

Web Scraper & Data Visualization System
├── Scraper module
│   ├── Data collector
│   ├── Parser
│   └── Data cleaner
├── Data storage module
│   ├── Relational database interface
│   └── NoSQL database interface
├── Data analysis module
│   ├── Statistical analysis
│   └── Data mining
└── Visualization module
    ├── Static chart generator
    ├── Interactive chart generator
    └── Web presentation interface

4. Scraper Module Implementation

4.1 Basic Scraper Implementation

First, we build a simple scraper with Requests and BeautifulSoup:

import requests
from bs4 import BeautifulSoup
import pandas as pd

class BasicScraper:
    """基础网页爬虫类"""
    
    def __init__(self, user_agent=None):
        """初始化爬虫"""
        self.session = requests.Session()
        default_ua = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        self.headers = {'User-Agent': user_agent if user_agent else default_ua}
    
    def fetch_page(self, url, params=None):
        """获取网页内容"""
        try:
            response = self.session.get(url, headers=self.headers, params=params)
            response.raise_for_status()  # 检查请求是否成功
            return response.text
        except requests.exceptions.RequestException as e:
            print(f"请求错误: {e}")
            return None
    
    def parse_html(self, html, parser='html.parser'):
        """解析HTML内容"""
        if html:
            return BeautifulSoup(html, parser)
        return None
    
    def extract_data(self, soup, selectors):
        """提取数据
        
        参数:
            soup: BeautifulSoup对象
            selectors: 字典,键为数据名称,值为CSS选择器
            
        返回:
            pandas.DataFrame: 提取的数据
        """
        data = {}
        for key, selector in selectors.items():
            elements = soup.select(selector)
            data[key] = [element.text.strip() for element in elements]
        
        # 确保所有列的长度一致
        max_length = max([len(v) for v in data.values()]) if data else 0
        for key in data:
            if len(data[key]) < max_length:
                data[key].extend([None] * (max_length - len(data[key])))
        
        return pd.DataFrame(data)
    
    def scrape(self, url, selectors, params=None):
        """执行完整的爬取过程"""
        html = self.fetch_page(url, params)
        if not html:
            return pd.DataFrame()
        
        soup = self.parse_html(html)
        if not soup:
            return pd.DataFrame()
        
        return self.extract_data(soup, selectors)


# 使用示例
def scrape_books_example():
    scraper = BasicScraper()
    url = "http://books.toscrape.com/"
    selectors = {
        "title": ".product_pod h3 a",
        "price": ".price_color",
        "rating": ".star-rating",
        "availability": ".availability"
    }
    
    # 爬取数据
    books_data = scraper.scrape(url, selectors)
    
    # 数据清洗
    if not books_data.empty:
        # 处理价格 - 移除货币符号并转换为浮点数
        books_data['price'] = books_data['price'].str.replace('£', '').astype(float)
        
        # Note: on books.toscrape.com the star rating lives in the element's class
        # attribute rather than its text, so this text-based extraction yields empty
        # strings and the value falls back to None; the Scrapy spider in section 4.2
        # reads the class attribute and extracts the rating correctly.
        books_data['rating'] = books_data['rating'].apply(lambda x: x.split()[1] + ' stars' if x else None)
        
        # 处理库存状态
        books_data['availability'] = books_data['availability'].str.strip()
    
    return books_data

# 执行爬取
if __name__ == "__main__":
    books = scrape_books_example()
    print(f"爬取到 {len(books)} 本书的信息")
    print(books.head())
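
Before pointing a scraper at a real site, it is good practice to check the site's robots.txt and only fetch pages that crawling is permitted for. A minimal sketch using the standard library's urllib.robotparser (the helper function below is our own addition, not part of the scraper above):

from urllib.parse import urlparse
from urllib.robotparser import RobotFileParser

def is_allowed(url, user_agent='*'):
    """Return True if robots.txt permits fetching the given URL."""
    parts = urlparse(url)
    parser = RobotFileParser()
    parser.set_url(f"{parts.scheme}://{parts.netloc}/robots.txt")
    try:
        parser.read()
    except Exception:
        # If robots.txt cannot be retrieved, decide per your own policy; we allow here.
        return True
    return parser.can_fetch(user_agent, url)

# Example: only run the scraper when the start page is allowed
if is_allowed("http://books.toscrape.com/"):
    books = scrape_books_example()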

4.2 Building a Scraper with Scrapy

For more complex crawling needs, we can use the Scrapy framework:

# 文件结构:
# my_scraper/
# ├── scrapy.cfg
# └── my_scraper/
#     ├── __init__.py
#     ├── items.py
#     ├── middlewares.py
#     ├── pipelines.py
#     ├── settings.py
#     └── spiders/
#         ├── __init__.py
#         └── book_spider.py

# items.py
import scrapy

class BookItem(scrapy.Item):
    """定义爬取的图书项目"""
    title = scrapy.Field()
    price = scrapy.Field()
    rating = scrapy.Field()
    availability = scrapy.Field()
    category = scrapy.Field()
    description = scrapy.Field()
    upc = scrapy.Field()
    image_url = scrapy.Field()
    url = scrapy.Field()

# book_spider.py
import scrapy
from ..items import BookItem

class BookSpider(scrapy.Spider):
    """图书爬虫"""
    name = 'bookspider'
    allowed_domains = ['books.toscrape.com']
    start_urls = ['http://books.toscrape.com/']
    
    def parse(self, response):
        """解析图书列表页面"""
        # 提取当前页面的所有图书
        books = response.css('article.product_pod')
        
        for book in books:
            # 获取图书详情页链接
            book_url = book.css('h3 a::attr(href)').get()
            if book_url:
                if 'catalogue/' not in book_url:
                    book_url = 'catalogue/' + book_url
                book_url = response.urljoin(book_url)
                yield scrapy.Request(book_url, callback=self.parse_book)
        
        # 处理分页
        next_page = response.css('li.next a::attr(href)').get()
        if next_page:
            yield response.follow(next_page, self.parse)
    
    def parse_book(self, response):
        """解析图书详情页面"""
        book = BookItem()
        
        # 提取基本信息
        book['title'] = response.css('div.product_main h1::text').get()
        book['price'] = response.css('p.price_color::text').get()
        # The availability <p> contains several text nodes; take the last non-empty one
        availability_texts = [t.strip() for t in response.css('p.availability::text').getall()]
        book['availability'] = next((t for t in reversed(availability_texts) if t), None)
        
        # 提取评分
        rating_class = response.css('p.star-rating::attr(class)').get()
        if rating_class:
            book['rating'] = rating_class.split()[1]
        
        # 提取产品信息表格
        rows = response.css('table.table-striped tr')
        for row in rows:
            header = row.css('th::text').get()
            if header == 'UPC':
                book['upc'] = row.css('td::text').get()
            elif header == 'Product Type':
                book['category'] = row.css('td::text').get()
        
        # 提取描述
        book['description'] = response.css('div#product_description + p::text').get()
        
        # 提取图片URL
        image_url = response.css('div.item.active img::attr(src)').get()
        if image_url:
            book['image_url'] = response.urljoin(image_url)
        
        book['url'] = response.url
        
        yield book

# pipelines.py (数据处理管道)
import re
from itemadapter import ItemAdapter

class BookPipeline:
    """图书数据处理管道"""
    
    def process_item(self, item, spider):
        adapter = ItemAdapter(item)
        
        # 清洗价格字段
        if adapter.get('price'):
            price_str = adapter['price']
            # 提取数字并转换为浮点数
            price_match = re.search(r'(\d+\.\d+)', price_str)
            if price_match:
                adapter['price'] = float(price_match.group(1))
        
        # 标准化评分
        rating_map = {
            'One': 1,
            'Two': 2,
            'Three': 3,
            'Four': 4,
            'Five': 5
        }
        if adapter.get('rating'):
            adapter['rating'] = rating_map.get(adapter['rating'], 0)
        
        # 处理库存信息
        if adapter.get('availability'):
            if 'In stock' in adapter['availability']:
                # 提取库存数量
                stock_match = re.search(r'(\d+)', adapter['availability'])
                if stock_match:
                    adapter['availability'] = int(stock_match.group(1))
                else:
                    adapter['availability'] = 'In stock'
            else:
                adapter['availability'] = 'Out of stock'
        
        return item
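
For the BookPipeline above to actually run, it must be enabled in the project's settings.py. A minimal sketch of the key settings (values are illustrative; tune them for the target site):

# settings.py
BOT_NAME = 'my_scraper'
SPIDER_MODULES = ['my_scraper.spiders']
NEWSPIDER_MODULE = 'my_scraper.spiders'

# Crawl politely
ROBOTSTXT_OBEY = True
DOWNLOAD_DELAY = 1

# Enable the item pipeline defined in pipelines.py
ITEM_PIPELINES = {
    'my_scraper.pipelines.BookPipeline': 300,
}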

# 运行爬虫的脚本 (run_spider.py)
from scrapy.crawler import CrawlerProcess
from scrapy.utils.project import get_project_settings

def run_spider():
    """运行Scrapy爬虫"""
    process = CrawlerProcess(get_project_settings())
    process.crawl('bookspider')
    process.start()

if __name__ == '__main__':
    run_spider()

4.3 Scraping Dynamic Web Pages

For pages rendered with JavaScript, we need Selenium:

from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from webdriver_manager.chrome import ChromeDriverManager
import pandas as pd
import time
import logging

class DynamicScraper:
    """动态网页爬虫类"""
    
    def __init__(self, headless=True, wait_time=10):
        """初始化爬虫
        
        参数:
            headless: 是否使用无头模式
            wait_time: 等待元素出现的最大时间(秒)
        """
        self.wait_time = wait_time
        self.logger = self._setup_logger()
        self.driver = self._setup_driver(headless)
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('DynamicScraper')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def _setup_driver(self, headless):
        """设置WebDriver"""
        try:
            chrome_options = Options()
            if headless:
                chrome_options.add_argument("--headless")
            
            # 添加其他有用的选项
            chrome_options.add_argument("--disable-gpu")
            chrome_options.add_argument("--no-sandbox")
            chrome_options.add_argument("--disable-dev-shm-usage")
            chrome_options.add_argument("--window-size=1920,1080")
            
            # 使用webdriver_manager自动管理ChromeDriver
            service = Service(ChromeDriverManager().install())
            driver = webdriver.Chrome(service=service, options=chrome_options)
            
            return driver
        except Exception as e:
            self.logger.error(f"设置WebDriver时出错: {e}")
            raise
    
    def navigate_to(self, url):
        """导航到指定URL"""
        try:
            self.logger.info(f"正在导航到: {url}")
            self.driver.get(url)
            return True
        except Exception as e:
            self.logger.error(f"导航到 {url} 时出错: {e}")
            return False
    
    def wait_for_element(self, by, value):
        """等待元素出现
        
        参数:
            by: 定位方式 (By.ID, By.CSS_SELECTOR 等)
            value: 定位值
            
        返回:
            找到的元素或None
        """
        try:
            element = WebDriverWait(self.driver, self.wait_time).until(
                EC.presence_of_element_located((by, value))
            )
            return element
        except Exception as e:
            self.logger.warning(f"等待元素 {value} 超时: {e}")
            return None
    
    def wait_for_elements(self, by, value):
        """等待多个元素出现"""
        try:
            elements = WebDriverWait(self.driver, self.wait_time).until(
                EC.presence_of_all_elements_located((by, value))
            )
            return elements
        except Exception as e:
            self.logger.warning(f"等待元素 {value} 超时: {e}")
            return []
    
    def extract_data(self, selectors):
        """从当前页面提取数据
        
        参数:
            selectors: 字典,键为数据名称,值为(定位方式, 定位值)元组
            
        返回:
            pandas.DataFrame: 提取的数据
        """
        data = {}
        
        for key, (by, value) in selectors.items():
            try:
                elements = self.driver.find_elements(by, value)
                data[key] = [element.text for element in elements]
                self.logger.info(f"提取了 {len(elements)} 个 '{key}' 元素")
            except Exception as e:
                self.logger.error(f"提取 '{key}' 数据时出错: {e}")
                data[key] = []
        
        # 确保所有列的长度一致
        max_length = max([len(v) for v in data.values()]) if data else 0
        for key in data:
            if len(data[key]) < max_length:
                data[key].extend([None] * (max_length - len(data[key])))
        
        return pd.DataFrame(data)
    
    def scroll_to_bottom(self, scroll_pause_time=1.0):
        """滚动到页面底部以加载更多内容"""
        self.logger.info("开始滚动页面以加载更多内容")
        
        # 获取初始页面高度
        last_height = self.driver.execute_script("return document.body.scrollHeight")
        
        while True:
            # 滚动到底部
            self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
            
            # 等待页面加载
            time.sleep(scroll_pause_time)
            
            # 计算新的页面高度并与上一个高度比较
            new_height = self.driver.execute_script("return document.body.scrollHeight")
            if new_height == last_height:
                # 如果高度没有变化,说明已经到底了
                break
            last_height = new_height
        
        self.logger.info("页面滚动完成")
    
    def click_element(self, by, value):
        """点击元素"""
        try:
            element = self.wait_for_element(by, value)
            if element:
                element.click()
                return True
            return False
        except Exception as e:
            self.logger.error(f"点击元素 {value} 时出错: {e}")
            return False
    
    def close(self):
        """关闭浏览器"""
        if self.driver:
            self.driver.quit()
            self.logger.info("浏览器已关闭")

# 使用示例
def scrape_dynamic_website_example():
    """爬取动态网站示例"""
    # 创建爬虫实例
    scraper = DynamicScraper(headless=True)
    
    try:
        # 导航到目标网站 (以SPA电商网站为例)
        url = "https://www.example-dynamic-site.com/products"
        if not scraper.navigate_to(url):
            return pd.DataFrame()
        
        # 等待页面加载完成
        scraper.wait_for_element(By.CSS_SELECTOR, ".product-grid")
        
        # 滚动页面以加载更多产品
        scraper.scroll_to_bottom(scroll_pause_time=2.0)
        
        # 定义要提取的数据选择器
        selectors = {
            "product_name": (By.CSS_SELECTOR, ".product-item .product-name"),
            "price": (By.CSS_SELECTOR, ".product-item .product-price"),
            "rating": (By.CSS_SELECTOR, ".product-item .product-rating"),
            "reviews_count": (By.CSS_SELECTOR, ".product-item .reviews-count")
        }
        
        # 提取数据
        products_data = scraper.extract_data(selectors)
        
        # 数据清洗
        if not products_data.empty:
            # 处理价格 - 移除货币符号并转换为浮点数
            products_data['price'] = (products_data['price']
                                      .str.replace('$', '', regex=False)
                                      .str.replace(',', '', regex=False)
                                      .astype(float))
            
            # 处理评分 - 提取数值
            products_data['rating'] = products_data['rating'].str.extract(r'(\d\.\d)').astype(float)
            
            # 处理评论数 - 提取数值
            products_data['reviews_count'] = products_data['reviews_count'].str.extract(r'(\d+)').astype(int)
        
        return products_data
    
    finally:
        # 确保浏览器关闭
        scraper.close()

# 执行爬取
if __name__ == "__main__":
    products = scrape_dynamic_website_example()
    print(f"爬取到 {len(products)} 个产品的信息")
    print(products.head())
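
As a side note, Selenium 4.6+ bundles Selenium Manager, which resolves a matching ChromeDriver automatically, so the webdriver-manager dependency is optional on recent versions. A minimal sketch of the simpler setup:

from selenium import webdriver
from selenium.webdriver.chrome.options import Options

options = Options()
options.add_argument("--headless")
# With Selenium >= 4.6 no explicit Service/ChromeDriverManager is needed;
# Selenium Manager downloads and caches a matching driver on first use.
driver = webdriver.Chrome(options=options)
driver.get("http://books.toscrape.com/")
print(driver.title)
driver.quit()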

4.4 Scraper Manager

To drive the different scraper types through a single entry point, we create a scraper manager:

class ScraperManager:
    """爬虫管理器,用于管理不同类型的爬虫"""
    
    def __init__(self):
        self.scrapers = {}
    
    def register_scraper(self, name, scraper_class, **kwargs):
        """注册爬虫
        
        参数:
            name: 爬虫名称
            scraper_class: 爬虫类
            kwargs: 传递给爬虫构造函数的参数
        """
        self.scrapers[name] = (scraper_class, kwargs)
        print(f"已注册爬虫: {name}")
    
    def get_scraper(self, name):
        """获取爬虫实例"""
        if name not in self.scrapers:
            raise ValueError(f"未找到名为 '{name}' 的爬虫")
        
        scraper_class, kwargs = self.scrapers[name]
        return scraper_class(**kwargs)
    
    def run_scraper(self, name, *args, **kwargs):
        """运行指定的爬虫
        
        参数:
            name: 爬虫名称
            args, kwargs: 传递给爬虫方法的参数
            
        返回:
            爬虫返回的数据
        """
        scraper = self.get_scraper(name)
        
        if hasattr(scraper, 'scrape'):
            return scraper.scrape(*args, **kwargs)
        elif hasattr(scraper, 'run'):
            return scraper.run(*args, **kwargs)
        else:
            raise AttributeError(f"爬虫 '{name}' 没有 'scrape' 或 'run' 方法")

# 使用示例
def scraper_manager_example():
    # 创建爬虫管理器
    manager = ScraperManager()
    
    # 注册基础爬虫
    manager.register_scraper('basic', BasicScraper)
    
    # 注册动态爬虫
    manager.register_scraper('dynamic', DynamicScraper, headless=True, wait_time=15)
    
    # 使用基础爬虫爬取数据
    url = "http://books.toscrape.com/"
    selectors = {
        "title": ".product_pod h3 a",
        "price": ".price_color",
        "rating": ".star-rating"
    }
    
    books_data = manager.run_scraper('basic', url, selectors)
    
    print(f"使用基础爬虫爬取到 {len(books_data)} 本书的信息")
    
    return books_data

# 执行示例
if __name__ == "__main__":
    data = scraper_manager_example()
    print(data.head())
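
Note that ScraperManager.run_scraper looks for a scrape or run method, while DynamicScraper from section 4.3 exposes neither, so running it through the manager as registered above would raise AttributeError. One way to bridge the gap is a thin wrapper subclass (a sketch; the class name is hypothetical):

class ManagedDynamicScraper(DynamicScraper):
    """DynamicScraper with a scrape() entry point so ScraperManager can drive it."""

    def scrape(self, url, selectors):
        # selectors: dict of name -> (By.<strategy>, value), as expected by extract_data
        try:
            if not self.navigate_to(url):
                return pd.DataFrame()
            return self.extract_data(selectors)
        finally:
            self.close()

# Register the wrapper instead of the raw DynamicScraper
manager = ScraperManager()
manager.register_scraper('dynamic', ManagedDynamicScraper, headless=True)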

4.5 Rotating Proxy IPs and Request Headers

To reduce the risk of being blocked by target sites, we can rotate proxy IPs and request headers:

import random
import time
from fake_useragent import UserAgent

class ProxyRotator:
    """代理IP轮换器"""
    
    def __init__(self, proxies=None):
        """初始化代理轮换器
        
        参数:
            proxies: 代理列表,格式为 [{'http': 'http://ip:port', 'https': 'https://ip:port'}, ...]
        """
        self.proxies = proxies or []
        self.current_index = 0
    
    def add_proxy(self, proxy):
        """添加代理"""
        self.proxies.append(proxy)
    
    def get_proxy(self):
        """获取下一个代理"""
        if not self.proxies:
            return None
        
        proxy = self.proxies[self.current_index]
        self.current_index = (self.current_index + 1) % len(self.proxies)
        return proxy
    
    def remove_proxy(self, proxy):
        """移除失效的代理"""
        if proxy in self.proxies:
            self.proxies.remove(proxy)
            self.current_index = self.current_index % max(1, len(self.proxies))

class UserAgentRotator:
    """User-Agent轮换器"""
    
    def __init__(self, use_fake_ua=True):
        """初始化User-Agent轮换器"""
        self.use_fake_ua = use_fake_ua
        self.ua = UserAgent() if use_fake_ua else None
        
        # 预定义的User-Agent列表(备用)
        self.user_agents = [
            'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15',
            'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
            'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36',
            'Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1'
        ]
    
    def get_random_ua(self):
        """获取随机User-Agent"""
        if self.use_fake_ua and self.ua:
            try:
                return self.ua.random
            except Exception:
                pass
        
        return random.choice(self.user_agents)

class EnhancedScraper(BasicScraper):
    """增强型爬虫,支持代理和请求头轮换"""
    
    def __init__(self, proxy_rotator=None, ua_rotator=None, retry_times=3, retry_delay=2):
        """初始化增强型爬虫
        
        参数:
            proxy_rotator: 代理轮换器
            ua_rotator: User-Agent轮换器
            retry_times: 请求失败重试次数
            retry_delay: 重试延迟时间(秒)
        """
        super().__init__()
        self.proxy_rotator = proxy_rotator or ProxyRotator()
        self.ua_rotator = ua_rotator or UserAgentRotator()
        self.retry_times = retry_times
        self.retry_delay = retry_delay
    
    def fetch_page(self, url, params=None):
        """获取网页内容,支持代理和重试"""
        for attempt in range(self.retry_times):
            try:
                # 获取代理和User-Agent
                proxy = self.proxy_rotator.get_proxy()
                user_agent = self.ua_rotator.get_random_ua()
                
                # 更新请求头
                self.headers['User-Agent'] = user_agent
                
                # 发送请求
                response = self.session.get(
                    url, 
                    headers=self.headers, 
                    params=params,
                    proxies=proxy,
                    timeout=10
                )
                
                # 检查请求是否成功
                response.raise_for_status()
                return response.text
            
            except requests.exceptions.RequestException as e:
                print(f"请求错误 (尝试 {attempt+1}/{self.retry_times}): {e}")
                
                # 如果是代理问题,移除当前代理
                if proxy and (isinstance(e, requests.exceptions.ProxyError) or 
                             isinstance(e, requests.exceptions.ConnectTimeout)):
                    self.proxy_rotator.remove_proxy(proxy)
                
                # 最后一次尝试失败
                if attempt == self.retry_times - 1:
                    return None
                
                # 延迟后重试
                time.sleep(self.retry_delay)
        
        return None

# 使用示例
def enhanced_scraper_example():
    # 创建代理轮换器
    proxy_rotator = ProxyRotator([
        {'http': 'http://proxy1.example.com:8080', 'https': 'https://proxy1.example.com:8080'},
        {'http': 'http://proxy2.example.com:8080', 'https': 'https://proxy2.example.com:8080'}
    ])
    
    # 创建User-Agent轮换器
    ua_rotator = UserAgentRotator()
    
    # 创建增强型爬虫
    scraper = EnhancedScraper(proxy_rotator, ua_rotator, retry_times=3)
    
    # 爬取数据
    url = "http://books.toscrape.com/"
    selectors = {
        "title": ".product_pod h3 a",
        "price": ".price_color",
        "rating": ".star-rating"
    }
    
    books_data = scraper.scrape(url, selectors)
    return books_data

# 执行示例
if __name__ == "__main__":
    data = enhanced_scraper_example()
    print(f"爬取到 {len(data)} 本书的信息")
    print(data.head())
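
Rotation alone may not be enough; spacing requests out with a randomized delay further lowers the chance of tripping rate limits. A minimal sketch extending EnhancedScraper with a per-request pause (the subclass and its parameters are our own addition):

import random
import time

class ThrottledScraper(EnhancedScraper):
    """EnhancedScraper that sleeps a random interval before every request."""

    def __init__(self, *args, min_delay=1.0, max_delay=3.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.min_delay = min_delay
        self.max_delay = max_delay

    def fetch_page(self, url, params=None):
        # Random pause before each request to better mimic human browsing
        time.sleep(random.uniform(self.min_delay, self.max_delay))
        return super().fetch_page(url, params)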

5. Data Storage Module

The data storage module persists the scraped data to different kinds of storage systems, including relational databases, NoSQL databases, and the file system.

5.1 SQLite Storage

SQLite is a lightweight relational database, well suited to single-machine applications and prototyping:

import sqlite3
import pandas as pd
import os
import logging
import csv
from datetime import datetime

class SQLiteStorage:
    """SQLite数据存储类"""
    
    def __init__(self, db_path):
        """初始化SQLite数据库连接
        
        参数:
            db_path: 数据库文件路径
        """
        self.db_path = db_path
        self.logger = self._setup_logger()
        
        # 确保数据库目录存在
        os.makedirs(os.path.dirname(os.path.abspath(db_path)), exist_ok=True)
        
        try:
            self.conn = sqlite3.connect(db_path)
            self.cursor = self.conn.cursor()
            self.logger.info(f"成功连接到SQLite数据库: {db_path}")
        except sqlite3.Error as e:
            self.logger.error(f"连接SQLite数据库时出错: {e}")
            raise
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('SQLiteStorage')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def create_table(self, table_name, columns):
        """创建数据表
        
        参数:
            table_name: 表名
            columns: 列定义字典,键为列名,值为数据类型
        """
        try:
            # 构建CREATE TABLE语句
            columns_str = ', '.join([f"{col} {dtype}" for col, dtype in columns.items()])
            query = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_str})"
            
            # 执行SQL
            self.cursor.execute(query)
            self.conn.commit()
            self.logger.info(f"成功创建表: {table_name}")
            return True
        except sqlite3.Error as e:
            self.logger.error(f"创建表 {table_name} 时出错: {e}")
            self.conn.rollback()
            return False
    
    def insert_data(self, table_name, data):
        """插入数据
        
        参数:
            table_name: 表名
            data: 要插入的数据,可以是DataFrame或列表
        """
        try:
            if isinstance(data, pd.DataFrame):
                # 使用pandas的to_sql方法插入DataFrame
                data.to_sql(table_name, self.conn, if_exists='append', index=False)
                self.logger.info(f"成功插入 {len(data)} 行数据到表 {table_name}")
            elif isinstance(data, list) and len(data) > 0:
                # 处理列表数据
                if isinstance(data[0], dict):
                    # 字典列表
                    
                    # 获取所有键
                    columns = list(data[0].keys())
                    
                    # 准备INSERT语句
                    placeholders = ', '.join(['?'] * len(columns))
                    columns_str = ', '.join(columns)
                    query = f"INSERT INTO {table_name} ({columns_str}) VALUES ({placeholders})"
                    
                    # 准备数据
                    values = [[row.get(col) for col in columns] for row in data]
                    
                    # 执行插入
                    self.cursor.executemany(query, values)
                else:
                    # 值列表
                    placeholders = ', '.join(['?'] * len(data[0]))
                    query = f"INSERT INTO {table_name} VALUES ({placeholders})"
                    self.cursor.executemany(query, data)
                
                self.conn.commit()
                self.logger.info(f"成功插入 {len(data)} 行数据到表 {table_name}")
            else:
                self.logger.warning(f"没有数据可插入到表 {table_name}")
            
            return True
        except Exception as e:
            self.logger.error(f"插入数据到表 {table_name} 时出错: {e}")
            self.conn.rollback()
            return False
    
    def query_data(self, query, params=None):
        """执行查询
        
        参数:
            query: SQL查询语句
            params: 查询参数(可选)
            
        返回:
            pandas.DataFrame: 查询结果
        """
        try:
            if params:
                return pd.read_sql_query(query, self.conn, params=params)
            else:
                return pd.read_sql_query(query, self.conn)
        except Exception as e:
            self.logger.error(f"执行查询时出错: {e}")
            return pd.DataFrame()
    
    def execute_query(self, query, params=None):
        """执行任意SQL查询
        
        参数:
            query: SQL查询语句
            params: 查询参数(可选)
            
        返回:
            bool: 是否成功
        """
        try:
            if params:
                self.cursor.execute(query, params)
            else:
                self.cursor.execute(query)
            
            self.conn.commit()
            return True
        except Exception as e:
            self.logger.error(f"执行查询时出错: {e}")
            self.conn.rollback()
            return False
    
    def table_exists(self, table_name):
        """检查表是否存在
        
        参数:
            table_name: 表名
            
        返回:
            bool: 表是否存在
        """
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
        self.cursor.execute(query, (table_name,))
        return self.cursor.fetchone() is not None
    
    def get_table_info(self, table_name):
        """获取表信息
        
        参数:
            table_name: 表名
            
        返回:
            list: 表的列信息
        """
        if not self.table_exists(table_name):
            return []
        
        query = f"PRAGMA table_info({table_name})"
        return self.cursor.execute(query).fetchall()
    
    def close(self):
        """关闭数据库连接"""
        if hasattr(self, 'conn') and self.conn:
            self.conn.close()
            self.logger.info("数据库连接已关闭")
    
    def __enter__(self):
        """上下文管理器入口"""
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        """上下文管理器退出"""
        self.close()

# 使用示例
def sqlite_example():
    # 创建SQLite存储实例
    db = SQLiteStorage('data/books.db')
    
    try:
        # 创建表
        db.create_table('books', {
            'id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
            'title': 'TEXT NOT NULL',
            'price': 'REAL',
            'rating': 'INTEGER',
            'category': 'TEXT',
            'description': 'TEXT',
            'created_at': 'TIMESTAMP DEFAULT CURRENT_TIMESTAMP'
        })
        
        # 准备示例数据
        books_data = pd.DataFrame({
            'title': ['Python编程', '数据科学入门', '机器学习实战'],
            'price': [59.9, 69.9, 79.9],
            'rating': [5, 4, 5],
            'category': ['编程', '数据科学', '机器学习'],
            'description': ['Python基础教程', '数据分析入门', '机器学习算法详解']
        })
        
        # 插入数据
        db.insert_data('books', books_data)
        
        # 查询数据
        results = db.query_data("SELECT * FROM books WHERE rating >= ?", (4,))
        print(f"查询结果: {len(results)} 行")
        print(results)
        
        return results
    
    finally:
        # 确保关闭连接
        db.close()

if __name__ == "__main__":
    sqlite_example()
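
Because SQLiteStorage implements __enter__ and __exit__, it can also be used as a context manager so the connection is closed automatically; for example:

# Query the table created above using a with block
with SQLiteStorage('data/books.db') as db:
    top_rated = db.query_data("SELECT title, rating FROM books ORDER BY rating DESC LIMIT 5")
    print(top_rated)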

5.2 MongoDB Storage

MongoDB is a popular NoSQL database, well suited to storing unstructured or semi-structured data:

import pymongo
import pandas as pd
import json
import logging
from bson import ObjectId
from datetime import datetime

class MongoDBStorage:
    """MongoDB数据存储类"""
    
    def __init__(self, connection_string, database_name):
        """初始化MongoDB连接
        
        参数:
            connection_string: MongoDB连接字符串
            database_name: 数据库名称
        """
        self.connection_string = connection_string
        self.database_name = database_name
        self.logger = self._setup_logger()
        
        try:
            # 连接到MongoDB
            self.client = pymongo.MongoClient(connection_string)
            self.db = self.client[database_name]
            
            # 测试连接
            self.client.server_info()
            self.logger.info(f"成功连接到MongoDB数据库: {database_name}")
        except Exception as e:
            self.logger.error(f"连接MongoDB数据库时出错: {e}")
            raise
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('MongoDBStorage')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def _convert_to_json_serializable(self, data):
        """转换数据为JSON可序列化格式
        
        参数:
            data: 要转换的数据
            
        返回:
            转换后的数据
        """
        if isinstance(data, dict):
            return {k: self._convert_to_json_serializable(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [self._convert_to_json_serializable(item) for item in data]
        elif isinstance(data, (ObjectId, datetime)):
            return str(data)
        else:
            return data
    
    def insert_document(self, collection_name, document):
        """插入单个文档
        
        参数:
            collection_name: 集合名称
            document: 要插入的文档(字典)
            
        返回:
            插入的文档ID
        """
        try:
            collection = self.db[collection_name]
            result = collection.insert_one(document)
            self.logger.info(f"成功插入文档到集合 {collection_name}, ID: {result.inserted_id}")
            return result.inserted_id
        except Exception as e:
            self.logger.error(f"插入文档到集合 {collection_name} 时出错: {e}")
            return None
    
    def insert_many(self, collection_name, documents):
        """插入多个文档
        
        参数:
            collection_name: 集合名称
            documents: 要插入的文档列表
            
        返回:
            插入的文档ID列表
        """
        try:
            collection = self.db[collection_name]
            result = collection.insert_many(documents)
            self.logger.info(f"成功插入 {len(result.inserted_ids)} 个文档到集合 {collection_name}")
            return result.inserted_ids
        except Exception as e:
            self.logger.error(f"插入多个文档到集合 {collection_name} 时出错: {e}")
            return []
    
    def insert_dataframe(self, collection_name, df):
        """插入DataFrame数据
        
        参数:
            collection_name: 集合名称
            df: pandas DataFrame
            
        返回:
            bool: 是否成功
        """
        try:
            if df.empty:
                self.logger.warning(f"DataFrame为空,未插入数据到集合 {collection_name}")
                return True
            
            # 将DataFrame转换为字典列表
            records = df.to_dict('records')
            
            # 插入数据
            collection = self.db[collection_name]
            result = collection.insert_many(records)
            
            self.logger.info(f"成功插入 {len(result.inserted_ids)} 行数据到集合 {collection_name}")
            return True
        except Exception as e:
            self.logger.error(f"插入DataFrame到集合 {collection_name} 时出错: {e}")
            return False
    
    def find_documents(self, collection_name, query=None, projection=None, limit=0):
        """查询文档
        
        参数:
            collection_name: 集合名称
            query: 查询条件(可选)
            projection: 投影字段(可选)
            limit: 结果限制数量(可选)
            
        返回:
            pandas.DataFrame: 查询结果
        """
        try:
            collection = self.db[collection_name]
            
            # 执行查询
            if query is None:
                query = {}
            
            cursor = collection.find(query, projection)
            
            if limit > 0:
                cursor = cursor.limit(limit)
            
            # 将结果转换为列表
            results = list(cursor)
            
            # 将ObjectId转换为字符串
            for doc in results:
                if '_id' in doc:
                    doc['_id'] = str(doc['_id'])
            
            # 转换为DataFrame
            if results:
                return pd.DataFrame(results)
            else:
                return pd.DataFrame()
        except Exception as e:
            self.logger.error(f"查询集合 {collection_name} 时出错: {e}")
            return pd.DataFrame()
    
    def update_document(self, collection_name, query, update_data, upsert=False):
        """更新文档
        
        参数:
            collection_name: 集合名称
            query: 查询条件
            update_data: 更新数据
            upsert: 如果不存在是否插入
            
        返回:
            int: 更新的文档数量
        """
        try:
            collection = self.db[collection_name]
            
            # 确保update_data使用$set操作符
            if not any(k.startswith('$') for k in update_data.keys()):
                update_data = {'$set': update_data}
            
            result = collection.update_one(query, update_data, upsert=upsert)
            
            self.logger.info(f"更新集合 {collection_name} 中的文档: 匹配 {result.matched_count}, 修改 {result.modified_count}")
            return result.modified_count
        except Exception as e:
            self.logger.error(f"更新集合 {collection_name} 中的文档时出错: {e}")
            return 0
    
    def update_many(self, collection_name, query, update_data):
        """更新多个文档
        
        参数:
            collection_name: 集合名称
            query: 查询条件
            update_data: 更新数据
            
        返回:
            int: 更新的文档数量
        """
        try:
            collection = self.db[collection_name]
            
            # 确保update_data使用$set操作符
            if not any(k.startswith('$') for k in update_data.keys()):
                update_data = {'$set': update_data}
            
            result = collection.update_many(query, update_data)
            
            self.logger.info(f"更新集合 {collection_name} 中的多个文档: 匹配 {result.matched_count}, 修改 {result.modified_count}")
            return result.modified_count
        except Exception as e:
            self.logger.error(f"更新集合 {collection_name} 中的多个文档时出错: {e}")
            return 0
    
    def delete_document(self, collection_name, query):
        """删除文档
        
        参数:
            collection_name: 集合名称
            query: 查询条件
            
        返回:
            int: 删除的文档数量
        """
        try:
            collection = self.db[collection_name]
            result = collection.delete_one(query)
            
            self.logger.info(f"从集合 {collection_name} 中删除了 {result.deleted_count} 个文档")
            return result.deleted_count
        except Exception as e:
            self.logger.error(f"从集合 {collection_name} 中删除文档时出错: {e}")
            return 0
    
    def delete_many(self, collection_name, query):
        """删除多个文档
        
        参数:
            collection_name: 集合名称
            query: 查询条件
            
        返回:
            int: 删除的文档数量
        """
        try:
            collection = self.db[collection_name]
            result = collection.delete_many(query)
            
            self.logger.info(f"从集合 {collection_name} 中删除了 {result.deleted_count} 个文档")
            return result.deleted_count
        except Exception as e:
            self.logger.error(f"从集合 {collection_name} 中删除多个文档时出错: {e}")
            return 0
    
    def create_index(self, collection_name, keys, **kwargs):
        """创建索引
        
        参数:
            collection_name: 集合名称
            keys: 索引键
            **kwargs: 其他索引选项
            
        返回:
            str: 创建的索引名称
        """
        try:
            collection = self.db[collection_name]
            index_name = collection.create_index(keys, **kwargs)
            
            self.logger.info(f"在集合 {collection_name} 上创建索引: {index_name}")
            return index_name
        except Exception as e:
            self.logger.error(f"在集合 {collection_name} 上创建索引时出错: {e}")
            return None
    
    def drop_collection(self, collection_name):
        """删除集合
        
        参数:
            collection_name: 集合名称
            
        返回:
            bool: 是否成功
        """
        try:
            self.db.drop_collection(collection_name)
            self.logger.info(f"成功删除集合: {collection_name}")
            return True
        except Exception as e:
            self.logger.error(f"删除集合 {collection_name} 时出错: {e}")
            return False
    
    def close(self):
        """关闭数据库连接"""
        if hasattr(self, 'client') and self.client:
            self.client.close()
            self.logger.info("MongoDB连接已关闭")
    
    def __enter__(self):
        """上下文管理器入口"""
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        """上下文管理器退出"""
        self.close()

# 使用示例
def mongodb_example():
    # 创建MongoDB存储实例
    mongo = MongoDBStorage('mongodb://localhost:27017', 'web_scraping_db')
    
    try:
        # 准备示例数据
        products_data = pd.DataFrame({
            'name': ['智能手机', '笔记本电脑', '平板电脑'],
            'price': [2999, 4999, 3999],
            'brand': ['品牌A', '品牌B', '品牌A'],
            'features': [
                ['5G', '高清摄像头', '快速充电'],
                ['高性能CPU', '大内存', 'SSD'],
                ['触控屏', '长续航', '轻薄']
            ],
            'in_stock': [True, False, True],
            'last_updated': [datetime.now() for _ in range(3)]
        })
        
        # 插入DataFrame数据
        mongo.insert_dataframe('products', products_data)
        
        # 插入单个文档
        review = {
            'product_id': '123456',
            'user': '用户A',
            'rating': 5,
            'comment': '非常好用的产品',
            'date': datetime.now()
        }
        review_id = mongo.insert_document('reviews', review)
        
        # 查询数据
        results = mongo.find_documents('products', {'brand': '品牌A'})
        print(f"查询结果: {len(results)} 行")
        print(results)
        
        # 更新数据
        mongo.update_document('products', {'name': '智能手机'}, {'$set': {'price': 2899}})
        
        # 创建索引
        mongo.create_index('products', [('name', pymongo.ASCENDING)], unique=True)
        
        return results
    
    finally:
        # 确保关闭连接
        mongo.close()

if __name__ == "__main__":
    mongodb_example()

5.3 CSV File Storage

CSV is a common data-interchange format, well suited to tabular data:

import pandas as pd
import os
import logging
import csv
from datetime import datetime

class CSVStorage:
    """CSV文件存储类"""
    
    def __init__(self, base_dir='data/csv'):
        """初始化CSV存储
        
        参数:
            base_dir: CSV文件存储的基础目录
        """
        self.base_dir = base_dir
        self.logger = self._setup_logger()
        
        # 确保目录存在
        os.makedirs(base_dir, exist_ok=True)
        self.logger.info(f"CSV存储目录: {base_dir}")
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('CSVStorage')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def _get_file_path(self, file_name):
        """获取文件完整路径
        
        参数:
            file_name: 文件名
            
        返回:
            str: 文件完整路径
        """
        # 确保文件名有.csv后缀
        if not file_name.endswith('.csv'):
            file_name += '.csv'
        
        return os.path.join(self.base_dir, file_name)
    
    def save_dataframe(self, df, file_name, index=False):
        """保存DataFrame到CSV文件
        
        参数:
            df: 要保存的DataFrame
            file_name: 文件名
            index: 是否保存索引
            
        返回:
            bool: 是否成功
        """
        try:
            file_path = self._get_file_path(file_name)
            df.to_csv(file_path, index=index, encoding='utf-8')
            self.logger.info(f"成功保存 {len(df)} 行数据到文件: {file_path}")
            return True
        except Exception as e:
            self.logger.error(f"保存数据到文件 {file_name} 时出错: {e}")
            return False
    
    def append_dataframe(self, df, file_name, index=False):
        """追加DataFrame到CSV文件
        
        参数:
            df: 要追加的DataFrame
            file_name: 文件名
            index: 是否保存索引
            
        返回:
            bool: 是否成功
        """
        try:
            file_path = self._get_file_path(file_name)
            
            # 检查文件是否存在
            file_exists = os.path.isfile(file_path)
            
            # 如果文件存在,追加数据;否则创建新文件
            df.to_csv(file_path, mode='a', header=not file_exists, index=index, encoding='utf-8')
            
            self.logger.info(f"成功追加 {len(df)} 行数据到文件: {file_path}")
            return True
        except Exception as e:
            self.logger.error(f"追加数据到文件 {file_name} 时出错: {e}")
            return False
    
    def load_csv(self, file_name, **kwargs):
        """加载CSV文件到DataFrame
        
        参数:
            file_name: 文件名
            **kwargs: 传递给pd.read_csv的参数
            
        返回:
            pandas.DataFrame: 加载的数据
        """
        try:
            file_path = self._get_file_path(file_name)
            
            if not os.path.isfile(file_path):
                self.logger.warning(f"文件不存在: {file_path}")
                return pd.DataFrame()
            
            df = pd.read_csv(file_path, **kwargs)
            self.logger.info(f"成功从文件 {file_path} 加载 {len(df)} 行数据")
            return df
        except Exception as e:
            self.logger.error(f"从文件 {file_name} 加载数据时出错: {e}")
            return pd.DataFrame()
    
    def save_records(self, records, file_name, fieldnames=None):
        """保存记录列表到CSV文件
        
        参数:
            records: 字典列表
            file_name: 文件名
            fieldnames: 字段名列表(可选)
            
        返回:
            bool: 是否成功
        """
        try:
            file_path = self._get_file_path(file_name)
            
            if not records:
                self.logger.warning(f"没有记录可保存到文件: {file_path}")
                return True
            
            # 如果未提供字段名,使用第一条记录的键
            if fieldnames is None:
                fieldnames = list(records[0].keys())
            
            with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(records)
            
            self.logger.info(f"成功保存 {len(records)} 条记录到文件: {file_path}")
            return True
        except Exception as e:
            self.logger.error(f"保存记录到文件 {file_name} 时出错: {e}")
            return False
    
    def append_records(self, records, file_name, fieldnames=None):
        """追加记录列表到CSV文件
        
        参数:
            records: 字典列表
            file_name: 文件名
            fieldnames: 字段名列表(可选)
            
        返回:
            bool: 是否成功
        """
        try:
            file_path = self._get_file_path(file_name)
            
            if not records:
                self.logger.warning(f"没有记录可追加到文件: {file_path}")
                return True
            
            # 检查文件是否存在
            file_exists = os.path.isfile(file_path)
            
            # 如果未提供字段名,使用第一条记录的键
            if fieldnames is None:
                fieldnames = list(records[0].keys())
            
            with open(file_path, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                
                # 如果文件不存在,写入表头
                if not file_exists:
                    writer.writeheader()
                
                writer.writerows(records)
            
            self.logger.info(f"成功追加 {len(records)} 条记录到文件: {file_path}")
            return True
        except Exception as e:
            self.logger.error(f"追加记录到文件 {file_name} 时出错: {e}")
            return False
    
    def file_exists(self, file_name):
        """检查文件是否存在
        
        参数:
            file_name: 文件名
            
        返回:
            bool: 文件是否存在
        """
        file_path = self._get_file_path(file_name)
        return os.path.isfile(file_path)
    
    def list_files(self):
        """列出所有CSV文件
        
        返回:
            list: CSV文件列表
        """
        try:
            files = [f for f in os.listdir(self.base_dir) if f.endswith('.csv')]
            self.logger.info(f"找到 {len(files)} 个CSV文件")
            return files
        except Exception as e:
            self.logger.error(f"列出CSV文件时出错: {e}")
            return []
    
    def delete_file(self, file_name):
        """删除CSV文件
        
        参数:
            file_name: 文件名
            
        返回:
            bool: 是否成功
        """
        try:
            file_path = self._get_file_path(file_name)
            
            if not os.path.isfile(file_path):
                self.logger.warning(f"文件不存在,无法删除: {file_path}")
                return False
            
            os.remove(file_path)
            self.logger.info(f"成功删除文件: {file_path}")
            return True
        except Exception as e:
            self.logger.error(f"删除文件 {file_name} 时出错: {e}")
            return False

# 使用示例
def csv_example():
    # 创建CSV存储实例
    csv_storage = CSVStorage('data/csv')
    
    # 准备示例数据
    data = pd.DataFrame({
        'date': [datetime.now().strftime('%Y-%m-%d %H:%M:%S') for _ in range(3)],
        'category': ['电子产品', '家居', '食品'],
        'item_count': [120, 85, 200],
        'average_price': [1500.75, 350.25, 45.50]
    })
    
    # 保存数据
    csv_storage.save_dataframe(data, 'inventory')
    
    # 加载数据
    loaded_data = csv_storage.load_csv('inventory')
    print(f"加载的数据: {len(loaded_data)} 行")
    print(loaded_data)
    
    # 追加数据
    new_data = pd.DataFrame({
        'date': [datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
        'category': ['服装'],
        'item_count': [150],
        'average_price': [250.00]
    })
    csv_storage.append_dataframe(new_data, 'inventory')
    
    return loaded_data

if __name__ == "__main__":
    csv_example()

5.4 Storage Factory

We create a storage factory class to manage the different storage back ends through a single interface:

class StorageFactory:
    """存储工厂类,用于创建和管理不同类型的存储"""
    
    def __init__(self):
        self.storage_classes = {}
        self.storage_instances = {}
    
    def register_storage(self, storage_type, storage_class):
        """注册存储类
        
        参数:
            storage_type: 存储类型名称
            storage_class: 存储类
        """
        self.storage_classes[storage_type] = storage_class
        print(f"已注册存储类型: {storage_type}")
    
    def get_storage(self, storage_type, **kwargs):
        """获取存储实例
        
        参数:
            storage_type: 存储类型名称
            **kwargs: 传递给存储类构造函数的参数
            
        返回:
            存储实例
        """
        # 检查存储类型是否已注册
        if storage_type not in self.storage_classes:
            raise ValueError(f"未注册的存储类型: {storage_type}")
        
        # 创建存储实例的键
        instance_key = f"{storage_type}_{hash(frozenset(kwargs.items()))}"
        
        # 如果实例不存在,创建新实例
        if instance_key not in self.storage_instances:
            storage_class = self.storage_classes[storage_type]
            self.storage_instances[instance_key] = storage_class(**kwargs)
        
        return self.storage_instances[instance_key]
    
    def close_all(self):
        """关闭所有存储连接"""
        for instance_key, storage in self.storage_instances.items():
            if hasattr(storage, 'close'):
                storage.close()
        
        self.storage_instances.clear()
        print("已关闭所有存储连接")

# 使用示例
def storage_factory_example():
    # 创建存储工厂
    factory = StorageFactory()
    
    # 注册存储类
    factory.register_storage('sqlite', SQLiteStorage)
    factory.register_storage('mongodb', MongoDBStorage)
    factory.register_storage('csv', CSVStorage)
    
    # 获取SQLite存储实例
    sqlite_storage = factory.get_storage('sqlite', db_path='data/example.db')
    
    # 获取MongoDB存储实例
    mongo_storage = factory.get_storage('mongodb', 
                                       connection_string='mongodb://localhost:27017', 
                                       database_name='example_db')
    
    # 获取CSV存储实例
    csv_storage = factory.get_storage('csv', base_dir='data/csv_files')
    
    # 使用存储实例...
    
    # 关闭所有连接
    factory.close_all()
    
    return "存储工厂示例完成"

if __name__ == "__main__":
    storage_factory_example()
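
To tie sections 4 and 5 together, here is a short end-to-end sketch that scrapes the book list with BasicScraper and persists it through the storage factory (it assumes the classes defined above are in scope; the table and file names are illustrative):

def scrape_and_store_example():
    # 1. Scrape the listing page
    scraper = BasicScraper()
    selectors = {"title": ".product_pod h3 a", "price": ".price_color"}
    books = scraper.scrape("http://books.toscrape.com/", selectors)
    if books.empty:
        return

    # 2. Clean the price column
    books['price'] = books['price'].str.replace('£', '').astype(float)

    # 3. Persist via the factory: a CSV snapshot plus a SQLite table
    factory = StorageFactory()
    factory.register_storage('csv', CSVStorage)
    factory.register_storage('sqlite', SQLiteStorage)

    factory.get_storage('csv', base_dir='data/csv').save_dataframe(books, 'books_snapshot')

    sqlite_storage = factory.get_storage('sqlite', db_path='data/books.db')
    sqlite_storage.create_table('books_snapshot', {'title': 'TEXT', 'price': 'REAL'})
    sqlite_storage.insert_data('books_snapshot', books)

    factory.close_all()

if __name__ == "__main__":
    scrape_and_store_example()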

6. Data Analysis Module

The data analysis module cleans, transforms, analyzes, and mines the scraped data to extract useful information and insights.

6.1 Data Cleaning and Preprocessing

Data cleaning is the first step of any analysis; it deals with missing values, outliers, and inconsistently formatted data:

import pandas as pd
import numpy as np
import re
from datetime import datetime
import logging

class DataCleaner:
    """数据清洗类"""
    
    def __init__(self):
        """初始化数据清洗器"""
        self.logger = self._setup_logger()
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('DataCleaner')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def handle_missing_values(self, df, strategy='drop', fill_value=None):
        """处理缺失值
        
        参数:
            df: 输入DataFrame
            strategy: 处理策略,可选'drop'(删除)、'fill'(填充)
            fill_value: 填充值,当strategy为'fill'时使用
            
        返回:
            处理后的DataFrame
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return df
        
        missing_count = df.isnull().sum().sum()
        self.logger.info(f"检测到 {missing_count} 个缺失值")
        
        if missing_count == 0:
            return df
        
        if strategy == 'drop':
            # 删除包含缺失值的行
            result = df.dropna()
            self.logger.info(f"删除了 {len(df) - len(result)} 行含有缺失值的数据")
            return result
        
        elif strategy == 'fill':
            # 填充缺失值
            if isinstance(fill_value, dict):
                # 对不同列使用不同的填充值
                result = df.fillna(fill_value)
                self.logger.info(f"使用指定值填充了缺失值: {fill_value}")
            else:
                # 使用相同的值填充所有缺失值
                result = df.fillna(fill_value)
                self.logger.info(f"使用 {fill_value} 填充了所有缺失值")
            return result
        
        else:
            self.logger.error(f"未知的缺失值处理策略: {strategy}")
            return df
    
    def remove_duplicates(self, df, subset=None):
        """删除重复行
        
        参数:
            df: 输入DataFrame
            subset: 用于判断重复的列,默认使用所有列
            
        返回:
            处理后的DataFrame
        """
        if df.empty:
            return df
        
        # 删除重复行
        result = df.drop_duplicates(subset=subset)
        
        removed_count = len(df) - len(result)
        self.logger.info(f"删除了 {removed_count} 行重复数据")
        
        return result
    
    def handle_outliers(self, df, columns, method='zscore', threshold=3.0):
        """处理异常值
        
        参数:
            df: 输入DataFrame
            columns: 要处理的列名列表
            method: 异常值检测方法,可选'zscore'、'iqr'
            threshold: 阈值,zscore方法使用
            
        返回:
            处理后的DataFrame
        """
        if df.empty:
            return df
        
        result = df.copy()
        outliers_count = 0
        
        for col in columns:
            if col not in df.columns:
                self.logger.warning(f"列 {col} 不存在")
                continue
            
            if not pd.api.types.is_numeric_dtype(df[col]):
                self.logger.warning(f"列 {col} 不是数值类型,跳过异常值检测")
                continue
            
            # 获取非缺失值
            values = df[col].dropna()
            
            if method == 'zscore':
                # 使用Z-score方法检测异常值
                mean = values.mean()
                std = values.std()
                if std == 0:
                    self.logger.warning(f"列 {col} 的标准差为0,跳过异常值检测")
                    continue
                
                z_scores = np.abs((values - mean) / std)
                outliers = values[z_scores > threshold].index
                
            elif method == 'iqr':
                # 使用IQR方法检测异常值
                q1 = values.quantile(0.25)
                q3 = values.quantile(0.75)
                iqr = q3 - q1
                lower_bound = q1 - 1.5 * iqr
                upper_bound = q3 + 1.5 * iqr
                outliers = values[(values < lower_bound) | (values > upper_bound)].index
                
            else:
                self.logger.error(f"未知的异常值检测方法: {method}")
                continue
            
            # 将异常值设为NaN
            result.loc[outliers, col] = np.nan
            outliers_count += len(outliers)
        
        self.logger.info(f"检测并处理了 {outliers_count} 个异常值")
        return result
    
    def normalize_text(self, df, text_columns):
        """文本标准化处理
        
        参数:
            df: 输入DataFrame
            text_columns: 要处理的文本列名列表
            
        返回:
            处理后的DataFrame
        """
        if df.empty:
            return df
        
        result = df.copy()
        
        for col in text_columns:
            if col not in df.columns:
                self.logger.warning(f"列 {col} 不存在")
                continue
            
            if not pd.api.types.is_string_dtype(df[col]):
                self.logger.warning(f"列 {col} 不是文本类型")
                continue
            
            # 文本处理:去除多余空格、转为小写
            result[col] = df[col].str.strip().str.lower()
            
            # 去除特殊字符
            result[col] = result[col].apply(lambda x: re.sub(r'[^\w\s]', '', str(x)) if pd.notna(x) else x)
            
            self.logger.info(f"完成列 {col} 的文本标准化处理")
        
        return result
    
    def convert_data_types(self, df, type_dict):
        """转换数据类型
        
        参数:
            df: 输入DataFrame
            type_dict: 类型转换字典,键为列名,值为目标类型
            
        返回:
            处理后的DataFrame
        """
        if df.empty:
            return df
        
        result = df.copy()
        
        for col, dtype in type_dict.items():
            if col not in df.columns:
                self.logger.warning(f"列 {col} 不存在")
                continue
            
            try:
                result[col] = result[col].astype(dtype)
                self.logger.info(f"将列 {col} 的类型转换为 {dtype}")
            except Exception as e:
                self.logger.error(f"转换列 {col} 的类型时出错: {e}")
        
        return result
    
    def parse_dates(self, df, date_columns, date_format=None):
        """解析日期列
        
        参数:
            df: 输入DataFrame
            date_columns: 日期列名列表
            date_format: 日期格式字符串(可选)
            
        返回:
            处理后的DataFrame
        """
        if df.empty:
            return df
        
        result = df.copy()
        
        for col in date_columns:
            if col not in df.columns:
                self.logger.warning(f"列 {col} 不存在")
                continue
            
            try:
                if date_format:
                    result[col] = pd.to_datetime(result[col], format=date_format)
                else:
                    result[col] = pd.to_datetime(result[col])
                
                self.logger.info(f"将列 {col} 转换为日期时间类型")
            except Exception as e:
                self.logger.error(f"转换列 {col} 为日期时间类型时出错: {e}")
        
        return result
    
    def clean_data(self, df, config=None):
        """综合数据清洗
        
        参数:
            df: 输入DataFrame
            config: 清洗配置字典
            
        返回:
            清洗后的DataFrame
        """
        if df.empty:
            return df
        
        if config is None:
            config = {}
        
        result = df.copy()
        
        # 处理缺失值
        if 'missing_values' in config:
            missing_config = config['missing_values']
            result = self.handle_missing_values(
                result, 
                strategy=missing_config.get('strategy', 'drop'),
                fill_value=missing_config.get('fill_value')
            )
        
        # 删除重复行
        if config.get('remove_duplicates', True):
            subset = config.get('duplicate_subset')
            result = self.remove_duplicates(result, subset=subset)
        
        # 处理异常值
        if 'outliers' in config:
            outlier_config = config['outliers']
            result = self.handle_outliers(
                result,
                columns=outlier_config.get('columns', []),
                method=outlier_config.get('method', 'zscore'),
                threshold=outlier_config.get('threshold', 3.0)
            )
        
        # 文本标准化
        if 'text_columns' in config:
            result = self.normalize_text(result, config['text_columns'])
        
        # 转换数据类型
        if 'type_conversions' in config:
            result = self.convert_data_types(result, config['type_conversions'])
        
        # 解析日期
        if 'date_columns' in config:
            date_config = config['date_columns']
            if isinstance(date_config, list):
                result = self.parse_dates(result, date_config)
            elif isinstance(date_config, dict):
                for col, format_str in date_config.items():
                    result = self.parse_dates(result, [col], date_format=format_str)
        
        self.logger.info(f"数据清洗完成,从 {len(df)} 行处理为 {len(result)} 行")
        return result

# 使用示例
def data_cleaning_example():
    # 创建示例数据
    data = {
        'product_name': ['iPhone 13', 'Samsung Galaxy', 'Xiaomi Mi 11  ', 'iPhone 13', 'Huawei P50', 'OPPO Find X5', None],
        'price': [5999, 4999, 3999, 5999, 4499, 4299, 2999],
        'rating': [4.8, 4.6, 10.0, 4.8, 4.7, 4.5, 4.4],  # 第3行的10.0为异常值
        'reviews_count': ['120', '98', '75', '120', '88', '64', '30'],  # 字符串类型
        'release_date': ['2021-09-15', '2021-08-20', '2021-03-10', '2021-09-15', '2021-07-29', '2022-02-24', '2022-01-01']
    }
    df = pd.DataFrame(data)
    
    # 创建数据清洗器
    cleaner = DataCleaner()
    
    # 配置清洗参数
    config = {
        'missing_values': {'strategy': 'drop'},
        'remove_duplicates': True,
        'outliers': {
            'columns': ['rating', 'price'],
            'method': 'iqr'  # 小样本下IQR规则比Z-score更容易检出异常值
        },
        'text_columns': ['product_name'],
        'type_conversions': {'reviews_count': 'int'},
        'date_columns': {'release_date': '%Y-%m-%d'}
    }
    
    # 执行数据清洗
    cleaned_df = cleaner.clean_data(df, config)
    
    print("原始数据:")
    print(df)
    print("\n清洗后的数据:")
    print(cleaned_df)
    
    return cleaned_df

if __name__ == "__main__":
    data_cleaning_example()
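
补充说明:为了直观理解 handle_outliers 中两种判定规则的差异,下面给出一个独立的小例子(示例数据与阈值均为假设),直接用 pandas 演示 Z-score 规则与 IQR 规则对同一组评分数据的判定结果:

import numpy as np
import pandas as pd

# 假设的评分数据:大部分集中在4.5左右,10.0为明显异常值
ratings = pd.Series([4.8, 4.6, 4.7, 4.5, 4.4, 4.6, 4.9, 10.0])

# Z-score规则:|x - mean| / std 超过阈值(此处取2)即视为异常
z_scores = (ratings - ratings.mean()).abs() / ratings.std()
print(ratings[z_scores > 2])  # 仅10.0被标记

# IQR规则:落在 [Q1 - 1.5*IQR, Q3 + 1.5*IQR] 之外即视为异常
q1, q3 = ratings.quantile(0.25), ratings.quantile(0.75)
iqr = q3 - q1
print(ratings[(ratings < q1 - 1.5 * iqr) | (ratings > q3 + 1.5 * iqr)])  # 同样仅10.0被标记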

6.2 统计分析

统计分析用于计算数据的基本统计量和分布特征:

import pandas as pd
import numpy as np
import scipy.stats as stats
import logging

class StatisticalAnalyzer:
    """统计分析类"""
    
    def __init__(self):
        """初始化统计分析器"""
        self.logger = self._setup_logger()
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('StatisticalAnalyzer')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def describe_data(self, df, include=None):
        """生成数据描述性统计
        
        参数:
            df: 输入DataFrame
            include: 包含的数据类型,默认为None(所有数值列)
            
        返回:
            描述性统计结果DataFrame
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return pd.DataFrame()
        
        try:
            stats_df = df.describe(include=include, percentiles=[.1, .25, .5, .75, .9])
            self.logger.info("生成描述性统计完成")
            return stats_df
        except Exception as e:
            self.logger.error(f"生成描述性统计时出错: {e}")
            return pd.DataFrame()
    
    def correlation_analysis(self, df, method='pearson'):
        """相关性分析
        
        参数:
            df: 输入DataFrame
            method: 相关系数计算方法,可选'pearson'、'spearman'、'kendall'
            
        返回:
            相关系数矩阵DataFrame
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return pd.DataFrame()
        
        # 筛选数值型列
        numeric_df = df.select_dtypes(include=['number'])
        
        if numeric_df.empty:
            self.logger.warning("没有数值型列可进行相关性分析")
            return pd.DataFrame()
        
        try:
            corr_matrix = numeric_df.corr(method=method)
            self.logger.info(f"使用 {method} 方法完成相关性分析")
            return corr_matrix
        except Exception as e:
            self.logger.error(f"计算相关系数时出错: {e}")
            return pd.DataFrame()
    
    def frequency_analysis(self, df, column, normalize=False, bins=None):
        """频率分析
        
        参数:
            df: 输入DataFrame
            column: 要分析的列名
            normalize: 是否归一化频率
            bins: 数值型数据的分箱数量
            
        返回:
            频率分析结果Series
        """
        if df.empty or column not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含列 {column}")
            return pd.Series()
        
        try:
            # 检查列的数据类型
            if pd.api.types.is_numeric_dtype(df[column]) and bins is not None:
                # 数值型数据,进行分箱
                freq = pd.cut(df[column], bins=bins).value_counts(normalize=normalize)
                self.logger.info(f"对数值列 {column} 进行分箱频率分析,分箱数量: {bins}")
            else:
                # 分类数据,直接计算频率
                freq = df[column].value_counts(normalize=normalize)
                self.logger.info(f"对列 {column} 进行频率分析")
            
            return freq
        except Exception as e:
            self.logger.error(f"进行频率分析时出错: {e}")
            return pd.Series()
    
    def group_analysis(self, df, group_by, agg_dict):
        """分组分析
        
        参数:
            df: 输入DataFrame
            group_by: 分组列名或列名列表
            agg_dict: 聚合字典,键为列名,值为聚合函数或函数列表
            
        返回:
            分组分析结果DataFrame
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return pd.DataFrame()
        
        try:
            result = df.groupby(group_by).agg(agg_dict)
            self.logger.info(f"按 {group_by} 完成分组分析")
            return result
        except Exception as e:
            self.logger.error(f"进行分组分析时出错: {e}")
            return pd.DataFrame()
    
    def time_series_analysis(self, df, date_column, value_column, freq='D'):
        """时间序列分析
        
        参数:
            df: 输入DataFrame
            date_column: 日期列名
            value_column: 值列名
            freq: 重采样频率,如'D'(天)、'W'(周)、'M'(月)
            
        返回:
            重采样后的时间序列DataFrame
        """
        if df.empty or date_column not in df.columns or value_column not in df.columns:
            self.logger.warning(f"输入DataFrame为空或缺少必要的列")
            return pd.DataFrame()
        
        try:
            # 确保日期列是datetime类型
            if not pd.api.types.is_datetime64_dtype(df[date_column]):
                df = df.copy()
                df[date_column] = pd.to_datetime(df[date_column])
            
            # 设置日期索引
            ts_df = df.set_index(date_column)
            
            # 按指定频率重采样并计算均值
            resampled = ts_df[value_column].resample(freq).mean()
            
            self.logger.info(f"完成时间序列分析,重采样频率: {freq}")
            return resampled.reset_index()
        except Exception as e:
            self.logger.error(f"进行时间序列分析时出错: {e}")
            return pd.DataFrame()
    
    def hypothesis_testing(self, df, column1, column2=None, test_type='ttest'):
        """假设检验
        
        参数:
            df: 输入DataFrame
            column1: 第一个数据列名
            column2: 第二个数据列名(对于双样本检验)
            test_type: 检验类型,可选'ttest'、'anova'、'chi2'等
            
        返回:
            检验结果字典
        """
        if df.empty or column1 not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含列 {column1}")
            return {}
        
        try:
            result = {}
            
            if test_type == 'ttest':
                # t检验
                if column2 and column2 in df.columns:
                    # 双样本t检验
                    t_stat, p_value = stats.ttest_ind(
                        df[column1].dropna(), 
                        df[column2].dropna(),
                        equal_var=False  # 不假设方差相等
                    )
                    result = {
                        'test': 'Independent Samples t-test',
                        't_statistic': t_stat,
                        'p_value': p_value,
                        'significant': p_value < 0.05
                    }
                    self.logger.info(f"完成独立样本t检验: {column1} vs {column2}")
                else:
                    # 单样本t检验(与0比较)
                    t_stat, p_value = stats.ttest_1samp(df[column1].dropna(), 0)
                    result = {
                        'test': 'One Sample t-test',
                        't_statistic': t_stat,
                        'p_value': p_value,
                        'significant': p_value < 0.05
                    }
                    self.logger.info(f"完成单样本t检验: {column1}")
            
            elif test_type == 'chi2' and column2 and column2 in df.columns:
                # 卡方检验(分类变量)
                contingency_table = pd.crosstab(df[column1], df[column2])
                chi2, p_value, dof, expected = stats.chi2_contingency(contingency_table)
                result = {
                    'test': 'Chi-square Test',
                    'chi2_statistic': chi2,
                    'p_value': p_value,
                    'degrees_of_freedom': dof,
                    'significant': p_value < 0.05
                }
                self.logger.info(f"完成卡方检验: {column1} vs {column2}")
            
            else:
                self.logger.warning(f"不支持的检验类型: {test_type}")
            
            return result
        
        except Exception as e:
            self.logger.error(f"进行假设检验时出错: {e}")
            return {'error': str(e)}

# 使用示例
def statistical_analysis_example():
    # 创建示例数据
    np.random.seed(42)
    n_samples = 200
    
    # 生成特征
    X = np.random.randn(n_samples, 3)  # 3个特征
    
    # 生成分类目标变量
    y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0
    
    # 生成回归目标变量
    y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
    
    # 创建DataFrame
    data = pd.DataFrame(
        X, 
        columns=['feature_1', 'feature_2', 'feature_3']
    )
    data['target_class'] = y_class.astype(int)
    data['target_reg'] = y_reg
    
    # 添加一些派生列
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
    data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)
    # 添加日期列,供时间序列分析使用
    data['date'] = pd.date_range('2022-01-01', periods=n_samples, freq='D')
    
    # 创建统计分析器
    analyzer = StatisticalAnalyzer()
    
    # 描述性统计
    desc_stats = analyzer.describe_data(data)
    print("描述性统计:")
    print(desc_stats)
    
    # 相关性分析
    corr_matrix = analyzer.correlation_analysis(data)
    print("\n相关性矩阵:")
    print(corr_matrix)
    
    # 频率分析(对分类列进行频率统计)
    category_freq = analyzer.frequency_analysis(data, 'month', normalize=True)
    print("\n月份频率分析:")
    print(category_freq)
    
    # 分组分析(按月份分组)
    group_result = analyzer.group_analysis(
        data, 
        'month', 
        {'target_class': ['mean', 'sum'], 'feature_2': 'mean', 'feature_3': 'mean'}
    )
    print("\n分组分析结果:")
    print(group_result)
    
    # 时间序列分析(使用日期列)
    ts_result = analyzer.time_series_analysis(data, 'date', 'target_reg', freq='W')
    print("\n时间序列分析结果(周均值):")
    print(ts_result.head())
    
    # 假设检验
    test_result = analyzer.hypothesis_testing(data, 'feature_1', test_type='ttest')
    print("\n假设检验结果:")
    print(test_result)
    
    return {
        'desc_stats': desc_stats,
        'corr_matrix': corr_matrix,
        'category_freq': category_freq,
        'group_result': group_result,
        'ts_result': ts_result,
        'test_result': test_result
    }

if __name__ == "__main__":
    statistical_analysis_example()
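
作为补充,下面用一个简短示例演示 hypothesis_testing 的卡方检验分支(其中 survey、channel 等数据与列名均为假设的示例),用于检验两个分类变量之间是否存在关联:

# 卡方检验示例:检验月份与渠道两个分类变量是否相互独立
analyzer = StatisticalAnalyzer()
survey = pd.DataFrame({
    'month': np.random.choice(['Jan', 'Feb', 'Mar'], 300),
    'channel': np.random.choice(['App', 'Web'], 300)
})
chi2_result = analyzer.hypothesis_testing(survey, 'month', column2='channel', test_type='chi2')
print(chi2_result)  # 返回包含chi2_statistic、p_value、significant等字段的字典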

6.3 数据挖掘

数据挖掘是从大量数据中发现模式和关系的过程,包括聚类分析、分类和回归模型等:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans, DBSCAN
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    mean_squared_error, r2_score, silhouette_score
)
import logging

class DataMiner:
    """数据挖掘类"""
    
    def __init__(self):
        """初始化数据挖掘器"""
        self.logger = self._setup_logger()
        self.models = {}  # 存储训练好的模型
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('DataMiner')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def preprocess_data(self, df, scale_method='standard', categorical_cols=None, exclude_cols=None):
        """数据预处理
        
        参数:
            df: 输入DataFrame
            scale_method: 缩放方法,可选'standard'、'minmax'
            categorical_cols: 分类变量列名列表
            exclude_cols: 不参与缩放的列名列表(如目标变量)
            
        返回:
            预处理后的DataFrame和预处理器
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return df, None
        
        # 处理缺失值
        df_clean = df.dropna()
        if len(df_clean) < len(df):
            self.logger.info(f"删除了 {len(df) - len(df_clean)} 行含有缺失值的数据")
        
        # 处理分类变量
        if categorical_cols:
            df_encoded = pd.get_dummies(df_clean, columns=categorical_cols)
            self.logger.info(f"对 {len(categorical_cols)} 个分类变量进行了独热编码")
        else:
            df_encoded = df_clean
        
        # 数值变量缩放(排除指定的列,例如目标变量)
        numeric_cols = df_encoded.select_dtypes(include=['number']).columns
        if exclude_cols:
            numeric_cols = [col for col in numeric_cols if col not in exclude_cols]
        
        if scale_method == 'standard':
            scaler = StandardScaler()
            self.logger.info("使用StandardScaler进行标准化")
        elif scale_method == 'minmax':
            scaler = MinMaxScaler()
            self.logger.info("使用MinMaxScaler进行归一化")
        else:
            self.logger.warning(f"未知的缩放方法: {scale_method},不进行缩放")
            return df_encoded, None
        
        if len(numeric_cols) > 0:
            df_encoded[numeric_cols] = scaler.fit_transform(df_encoded[numeric_cols])
            self.logger.info(f"对 {len(numeric_cols)} 个数值变量进行了缩放")
        
        return df_encoded, scaler
    
    def reduce_dimensions(self, df, n_components=2, method='pca'):
        """降维
        
        参数:
            df: 输入DataFrame
            n_components: 目标维度
            method: 降维方法,目前支持'pca'
            
        返回:
            降维后的DataFrame和降维器
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return df, None
        
        # 确保数据为数值型
        numeric_df = df.select_dtypes(include=['number'])
        
        if numeric_df.empty:
            self.logger.warning("没有数值型列可进行降维")
            return df, None
        
        try:
            if method == 'pca':
                reducer = PCA(n_components=n_components)
                reduced_data = reducer.fit_transform(numeric_df)
                
                # 创建包含降维结果的DataFrame
                result_df = pd.DataFrame(
                    reduced_data,
                    columns=[f'PC{i+1}' for i in range(n_components)],
                    index=df.index
                )
                
                # 计算解释方差比例
                explained_variance = reducer.explained_variance_ratio_.sum()
                self.logger.info(f"PCA降维完成,保留了 {n_components} 个主成分,解释了 {explained_variance:.2%} 的方差")
                
                return result_df, reducer
            else:
                self.logger.warning(f"不支持的降维方法: {method}")
                return df, None
        except Exception as e:
            self.logger.error(f"降维过程中出错: {e}")
            return df, None
    
    def cluster_data(self, df, method='kmeans', n_clusters=3, eps=0.5, min_samples=5):
        """聚类分析
        
        参数:
            df: 输入DataFrame
            method: 聚类方法,可选'kmeans'、'dbscan'
            n_clusters: KMeans的簇数量
            eps: DBSCAN的邻域半径
            min_samples: DBSCAN的最小样本数
            
        返回:
            带有聚类标签的DataFrame和聚类器
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return df, None
        
        # 确保数据为数值型
        numeric_df = df.select_dtypes(include=['number'])
        
        if numeric_df.empty:
            self.logger.warning("没有数值型列可进行聚类")
            return df, None
        
        try:
            result_df = df.copy()
            
            if method == 'kmeans':
                # K-means聚类
                clusterer = KMeans(n_clusters=n_clusters, random_state=42)
                labels = clusterer.fit_predict(numeric_df)
                
                # 计算轮廓系数
                if n_clusters > 1 and len(numeric_df) > n_clusters:
                    silhouette = silhouette_score(numeric_df, labels)
                    self.logger.info(f"K-means聚类完成,轮廓系数: {silhouette:.4f}")
                else:
                    self.logger.info("K-means聚类完成,但无法计算轮廓系数(簇数过少或数据量不足)")
                
            elif method == 'dbscan':
                # DBSCAN聚类
                clusterer = DBSCAN(eps=eps, min_samples=min_samples)
                labels = clusterer.fit_predict(numeric_df)
                
                # 计算聚类统计信息
                n_clusters_found = len(set(labels)) - (1 if -1 in labels else 0)
                n_noise = list(labels).count(-1)
                self.logger.info(f"DBSCAN聚类完成,发现 {n_clusters_found} 个簇,{n_noise} 个噪声点")
                
            else:
                self.logger.warning(f"不支持的聚类方法: {method}")
                return df, None
            
            # 添加聚类标签
            result_df['cluster'] = labels
            
            return result_df, clusterer
        
        except Exception as e:
            self.logger.error(f"聚类过程中出错: {e}")
            return df, None
    
    def train_classifier(self, df, target_col, feature_cols=None, model_type='random_forest', test_size=0.2):
        """训练分类模型
        
        参数:
            df: 输入DataFrame
            target_col: 目标变量列名
            feature_cols: 特征列名列表,默认使用所有数值列
            model_type: 模型类型,可选'logistic'、'random_forest'
            test_size: 测试集比例
            
        返回:
            模型评估指标字典和训练好的模型
        """
        if df.empty or target_col not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含目标列 {target_col}")
            return {}, None
        
        try:
            # 准备特征和目标变量
            if feature_cols is None:
                # 使用除目标列外的所有数值列作为特征
                feature_cols = df.select_dtypes(include=['number']).columns.tolist()
                if target_col in feature_cols:
                    feature_cols.remove(target_col)
            
            if not feature_cols:
                self.logger.warning("没有可用的特征列")
                return {}, None
            
            X = df[feature_cols]
            y = df[target_col]
            
            # 划分训练集和测试集
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)
            
            # 训练模型
            if model_type == 'logistic':
                model = LogisticRegression(max_iter=1000, random_state=42)
                model_name = 'Logistic Regression'
            elif model_type == 'random_forest':
                model = RandomForestClassifier(n_estimators=100, random_state=42)
                model_name = 'Random Forest'
            else:
                self.logger.warning(f"不支持的分类模型类型: {model_type}")
                return {}, None
            
            model.fit(X_train, y_train)
            
            # 在测试集上评估
            y_pred = model.predict(X_test)
            
            # 计算评估指标
            metrics = {
                'accuracy': accuracy_score(y_test, y_pred),
                'precision': precision_score(y_test, y_pred, average='weighted'),
                'recall': recall_score(y_test, y_pred, average='weighted'),
                'f1': f1_score(y_test, y_pred, average='weighted')
            }
            
            # 交叉验证
            cv_scores = cross_val_score(model, X, y, cv=5)
            metrics['cv_accuracy_mean'] = cv_scores.mean()
            metrics['cv_accuracy_std'] = cv_scores.std()
            
            self.logger.info(f"{model_name}分类模型训练完成,准确率: {metrics['accuracy']:.4f}")
            
            # 存储模型
            model_id = f"{model_type}_classifier_{target_col}"
            self.models[model_id] = {
                'model': model,
                'feature_cols': feature_cols,
                'target_col': target_col,
                'metrics': metrics
            }
            
            return metrics, model
        
        except Exception as e:
            self.logger.error(f"训练分类模型时出错: {e}")
            return {}, None
    
    def train_regressor(self, df, target_col, feature_cols=None, model_type='linear', test_size=0.2):
        """训练回归模型
        
        参数:
            df: 输入DataFrame
            target_col: 目标变量列名
            feature_cols: 特征列名列表,默认使用所有数值列
            model_type: 模型类型,可选'linear'、'random_forest'
            test_size: 测试集比例
            
        返回:
            模型评估指标字典和训练好的模型
        """
        if df.empty or target_col not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含目标列 {target_col}")
            return {}, None
        
        try:
            # 准备特征和目标变量
            if feature_cols is None:
                # 使用除目标列外的所有数值列作为特征
                feature_cols = df.select_dtypes(include=['number']).columns.tolist()
                if target_col in feature_cols:
                    feature_cols.remove(target_col)
            
            if not feature_cols:
                self.logger.warning("没有可用的特征列")
                return {}, None
            
            X = df[feature_cols]
            y = df[target_col]
            
            # 划分训练集和测试集
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)
            
            # 训练模型
            if model_type == 'linear':
                model = LinearRegression()
                model_name = 'Linear Regression'
            elif model_type == 'random_forest':
                model = RandomForestRegressor(n_estimators=100, random_state=42)
                model_name = 'Random Forest'
            else:
                self.logger.warning(f"不支持的回归模型类型: {model_type}")
                return {}, None
            
            model.fit(X_train, y_train)
            
            # 在测试集上评估
            y_pred = model.predict(X_test)
            
            # 计算评估指标
            metrics = {
                'mse': mean_squared_error(y_test, y_pred),
                'rmse': np.sqrt(mean_squared_error(y_test, y_pred)),
                'r2': r2_score(y_test, y_pred)
            }
            
            # 交叉验证
            cv_scores = cross_val_score(model, X, y, cv=5, scoring='r2')
            metrics['cv_r2_mean'] = cv_scores.mean()
            metrics['cv_r2_std'] = cv_scores.std()
            
            self.logger.info(f"{model_name}回归模型训练完成,R²: {metrics['r2']:.4f}")
            
            # 存储模型
            model_id = f"{model_type}_regressor_{target_col}"
            self.models[model_id] = {
                'model': model,
                'feature_cols': feature_cols,
                'target_col': target_col,
                'metrics': metrics
            }
            
            return metrics, model
        
        except Exception as e:
            self.logger.error(f"训练回归模型时出错: {e}")
            return {}, None
    
    def predict(self, model_id, new_data):
        """使用训练好的模型进行预测
        
        参数:
            model_id: 模型ID
            new_data: 新数据DataFrame
            
        返回:
            预测结果
        """
        if model_id not in self.models:
            self.logger.warning(f"模型ID {model_id} 不存在")
            return None
        
        model_info = self.models[model_id]
        model = model_info['model']
        feature_cols = model_info['feature_cols']
        
        # 检查新数据是否包含所有特征列
        missing_cols = [col for col in feature_cols if col not in new_data.columns]
        if missing_cols:
            self.logger.warning(f"新数据缺少特征列: {missing_cols}")
            return None
        
        try:
            # 提取特征
            X_new = new_data[feature_cols]
            
            # 进行预测
            predictions = model.predict(X_new)
            
            self.logger.info(f"使用模型 {model_id} 完成预测,预测样本数: {len(predictions)}")
            
            return predictions
        
        except Exception as e:
            self.logger.error(f"预测过程中出错: {e}")
            return None
    
    def get_feature_importance(self, model_id):
        """获取特征重要性
        
        参数:
            model_id: 模型ID
            
        返回:
            特征重要性DataFrame
        """
        if model_id not in self.models:
            self.logger.warning(f"模型ID {model_id} 不存在")
            return pd.DataFrame()
        
        model_info = self.models[model_id]
        model = model_info['model']
        feature_cols = model_info['feature_cols']
        
        # 优先使用feature_importances_;对线性模型退而使用系数绝对值作为特征重要性
        if hasattr(model, 'feature_importances_'):
            importances = model.feature_importances_
        elif hasattr(model, 'coef_'):
            importances = np.abs(model.coef_)
            if importances.ndim > 1:
                importances = importances.mean(axis=0)
        else:
            self.logger.warning(f"模型 {model_id} 不支持特征重要性分析")
            return pd.DataFrame()
        
        # 创建特征重要性DataFrame
        importance_df = pd.DataFrame({
            'feature': feature_cols,
            'importance': importances
        })
        
        # 按重要性降序排序
        importance_df = importance_df.sort_values('importance', ascending=False)
        
        self.logger.info(f"获取模型 {model_id} 的特征重要性")
        
        return importance_df

# 使用示例
def data_mining_example():
    # 创建示例数据
    np.random.seed(42)
    n_samples = 200
    
    # 生成特征
    X = np.random.randn(n_samples, 5)  # 5个特征
    
    # 生成分类目标变量
    y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0
    
    # 生成回归目标变量
    y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
    
    # 创建DataFrame
    data = pd.DataFrame(
        X, 
        columns=[f'feature_{i+1}' for i in range(5)]
    )
    data['target_class'] = y_class.astype(int)
    data['target_reg'] = y_reg
    
    # 添加一些派生列
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
    data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)
    
    # 创建数据挖掘器
    miner = DataMiner()
    
    # 数据预处理
    print("数据预处理...")
    data_processed, scaler = miner.preprocess_data(
        data,
        scale_method='standard',
        categorical_cols=['month', 'day_of_week'],
        exclude_cols=['target_class', 'target_reg']  # 目标变量不参与缩放,避免分类标签被破坏
    )
    
    # 降维分析
    print("\n降维分析...")
    data_reduced, pca = miner.reduce_dimensions(
        data_processed.drop(['target_class', 'target_reg'], axis=1),
        n_components=2
    )
    
    # 聚类分析
    print("\n聚类分析...")
    data_clustered, kmeans = miner.cluster_data(
        data_reduced,
        method='kmeans',
        n_clusters=3
    )
    
    # 分类模型
    print("\n训练分类模型...")
    class_metrics, classifier = miner.train_classifier(
        data_processed,
        target_col='target_class',
        model_type='random_forest'
    )
    print(f"分类模型评估指标: {class_metrics}")
    
    # 回归模型
    print("\n训练回归模型...")
    reg_metrics, regressor = miner.train_regressor(
        data_processed,
        target_col='target_reg',
        model_type='random_forest'
    )
    print(f"回归模型评估指标: {reg_metrics}")
    
    # 特征重要性
    print("\n特征重要性分析...")
    importance = miner.get_feature_importance('random_forest_regressor_target_reg')
    print(importance)
    
    return {
        'data_processed': data_processed,
        'data_reduced': data_reduced,
        'data_clustered': data_clustered,
        'class_metrics': class_metrics,
        'reg_metrics': reg_metrics,
        'feature_importance': importance
    }

if __name__ == "__main__":
    data_mining_example()
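
上面的示例没有演示 DataMiner.predict 接口。下面是一个简短的补充示意(demo、x1 等数据与列名均为假设),展示训练完成后如何通过模型ID对新数据进行预测:

# 先训练一个简单的线性回归模型,再用模型ID对新数据预测
miner = DataMiner()
demo = pd.DataFrame(np.random.randn(100, 3), columns=['x1', 'x2', 'x3'])
demo['y'] = 3 * demo['x1'] - demo['x2'] + np.random.randn(100) * 0.1

metrics, model = miner.train_regressor(demo, target_col='y', model_type='linear')
print(f"R²: {metrics['r2']:.4f}")

# 新数据必须包含训练时使用的全部特征列
new_data = pd.DataFrame(np.random.randn(5, 3), columns=['x1', 'x2', 'x3'])
predictions = miner.predict('linear_regressor_y', new_data)
print(predictions)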

6.4 特征工程

特征工程是数据分析和机器学习中至关重要的一步,它可以显著提高模型性能:

import pandas as pd
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.feature_selection import SelectKBest, f_regression, mutual_info_regression
import logging

class FeatureEngineer:
    """特征工程类"""
    
    def __init__(self):
        """初始化特征工程器"""
        self.logger = self._setup_logger()
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('FeatureEngineer')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def create_polynomial_features(self, df, feature_cols, degree=2, include_bias=False):
        """创建多项式特征
        
        参数:
            df: 输入DataFrame
            feature_cols: 特征列名列表
            degree: 多项式次数
            include_bias: 是否包含偏置项
            
        返回:
            包含多项式特征的DataFrame
        """
        if df.empty or not feature_cols:
            self.logger.warning("输入DataFrame为空或未指定特征列")
            return df
        
        try:
            # 提取特征
            X = df[feature_cols].values
            
            # 创建多项式特征
            poly = PolynomialFeatures(degree=degree, include_bias=include_bias)
            poly_features = poly.fit_transform(X)
            
            # 创建特征名称
            feature_names = poly.get_feature_names_out(feature_cols)
            
            # 创建包含多项式特征的DataFrame
            poly_df = pd.DataFrame(poly_features, columns=feature_names, index=df.index)
            
            # 合并原始DataFrame和多项式特征
            result_df = pd.concat([df.drop(feature_cols, axis=1), poly_df], axis=1)
            
            self.logger.info(f"创建了 {poly_features.shape[1]} 个多项式特征,次数: {degree}")
            
            return result_df
        
        except Exception as e:
            self.logger.error(f"创建多项式特征时出错: {e}")
            return df
    
    def create_interaction_features(self, df, feature_cols):
        """创建交互特征
        
        参数:
            df: 输入DataFrame
            feature_cols: 特征列名列表
            
        返回:
            包含交互特征的DataFrame
        """
        if df.empty or len(feature_cols) < 2:
            self.logger.warning("输入DataFrame为空或特征列不足")
            return df
        
        try:
            result_df = df.copy()
            interaction_count = 0
            
            # 创建两两特征的交互项
            for i in range(len(feature_cols)):
                for j in range(i+1, len(feature_cols)):
                    col1 = feature_cols[i]
                    col2 = feature_cols[j]
                    
                    # 创建交互特征
                    interaction_name = f"{col1}_x_{col2}"
                    result_df[interaction_name] = df[col1] * df[col2]
                    interaction_count += 1
            
            self.logger.info(f"创建了 {interaction_count} 个交互特征")
            
            return result_df
        
        except Exception as e:
            self.logger.error(f"创建交互特征时出错: {e}")
            return df
    
    def create_binning_features(self, df, feature_col, bins=5, strategy='uniform'):
        """创建分箱特征
        
        参数:
            df: 输入DataFrame
            feature_col: 要分箱的特征列名
            bins: 分箱数量或边界列表
            strategy: 分箱策略,可选'uniform'、'quantile'
            
        返回:
            包含分箱特征的DataFrame
        """
        if df.empty or feature_col not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含列 {feature_col}")
            return df
        
        try:
            result_df = df.copy()
            
            # 确定分箱边界
            if isinstance(bins, int):
                if strategy == 'uniform':
                    # 均匀分箱
                    bin_edges = np.linspace(
                        df[feature_col].min(),
                        df[feature_col].max(),
                        bins + 1
                    )
                elif strategy == 'quantile':
                    # 分位数分箱
                    bin_edges = np.percentile(
                        df[feature_col],
                        np.linspace(0, 100, bins + 1)
                    )
                else:
                    self.logger.warning(f"不支持的分箱策略: {strategy}")
                    return df
            else:
                # 使用指定的分箱边界
                bin_edges = bins
            
            # 创建分箱特征
            binned_feature = pd.cut(
                df[feature_col],
                bins=bin_edges,
                labels=False,
                include_lowest=True
            )
            
            # 添加分箱特征
            result_df[f"{feature_col}_bin"] = binned_feature
            
            # 创建独热编码的分箱特征
            bin_dummies = pd.get_dummies(
                binned_feature,
                prefix=f"{feature_col}_bin",
                prefix_sep="_"
            )
            
            # 合并结果
            result_df = pd.concat([result_df, bin_dummies], axis=1)
            
            self.logger.info(f"对特征 {feature_col} 创建了 {len(bin_edges)-1} 个分箱特征")
            
            return result_df
        
        except Exception as e:
            self.logger.error(f"创建分箱特征时出错: {e}")
            return df
    
    def select_best_features(self, df, feature_cols, target_col, k=5, method='f_regression'):
        """选择最佳特征
        
        参数:
            df: 输入DataFrame
            feature_cols: 特征列名列表
            target_col: 目标变量列名
            k: 选择的特征数量
            method: 特征选择方法,可选'f_regression'、'mutual_info'
            
        返回:
            包含选定特征的DataFrame和特征得分
        """
        if df.empty or not feature_cols or target_col not in df.columns:
            self.logger.warning("输入DataFrame为空或未指定特征列或目标列")
            return df, {}
        
        try:
            # 提取特征和目标变量
            X = df[feature_cols]
            y = df[target_col]
            
            # 选择特征选择器
            if method == 'f_regression':
                selector = SelectKBest(score_func=f_regression, k=k)
                method_name = "F回归"
            elif method == 'mutual_info':
                selector = SelectKBest(score_func=mutual_info_regression, k=k)
                method_name = "互信息"
            else:
                self.logger.warning(f"不支持的特征选择方法: {method}")
                return df, {}
            
            # 拟合选择器
            selector.fit(X, y)
            
            # 获取选定的特征索引
            selected_indices = selector.get_support(indices=True)
            selected_features = [feature_cols[i] for i in selected_indices]
            
            # 创建特征得分字典
            feature_scores = dict(zip(feature_cols, selector.scores_))
            
            # 创建包含选定特征的DataFrame
            result_df = df.copy()
            dropped_features = [col for col in feature_cols if col not in selected_features]
            if dropped_features:
                result_df = result_df.drop(dropped_features, axis=1)
            
            self.logger.info(f"使用 {method_name} 方法选择了 {len(selected_features)} 个最佳特征")
            
            return result_df, feature_scores
        
        except Exception as e:
            self.logger.error(f"选择最佳特征时出错: {e}")
            return df, {}

# 使用示例
def feature_engineering_example():
    # 创建示例数据
    np.random.seed(42)
    n_samples = 200
    
    # 生成特征
    X = np.random.randn(n_samples, 3)  # 3个特征
    
    # 生成目标变量(回归)
    y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
    
    # 创建DataFrame
    data = pd.DataFrame(
        X, 
        columns=['feature_1', 'feature_2', 'feature_3']
    )
    data['target'] = y
    
    # 创建特征工程器
    engineer = FeatureEngineer()
    
    # 创建多项式特征
    print("创建多项式特征...")
    poly_data = engineer.create_polynomial_features(
        data,
        ['feature_1', 'feature_2', 'feature_3'],
        degree=2
    )
    print(f"多项式特征后的列: {poly_data.columns.tolist()}")
    
    # 创建交互特征
    print("\n创建交互特征...")
    interaction_data = engineer.create_interaction_features(
        data,
        ['feature_1', 'feature_2', 'feature_3']
    )
    print(f"交互特征后的列: {interaction_data.columns.tolist()}")
    
    # 创建分箱特征
    print("\n创建分箱特征...")
    binned_data = engineer.create_binning_features(
        data,
        'feature_1',
        bins=5,
        strategy='quantile'
    )
    print(f"分箱特征后的列: {binned_data.columns.tolist()}")
    
    # 特征选择
    print("\n特征选择...")
    # 首先创建更多特征用于选择
    combined_data = engineer.create_polynomial_features(
        data,
        ['feature_1', 'feature_2', 'feature_3'],
        degree=2
    )
    
    # 选择最佳特征
    selected_data, feature_scores = engineer.select_best_features(
        combined_data,
        [col for col in combined_data.columns if col != 'target'],
        'target',
        k=5,
        method='f_regression'
    )
    
    print("特征得分:")
    for feature, score in sorted(feature_scores.items(), key=lambda x: x[1], reverse=True):
        print(f"{feature}: {score:.4f}")
    
    print(f"\n选择的特征: {[col for col in selected_data.columns if col != 'target']}")
    
    return {
        'original_data': data,
        'poly_data': poly_data,
        'interaction_data': interaction_data,
        'binned_data': binned_data,
        'selected_data': selected_data,
        'feature_scores': feature_scores
    }

if __name__ == "__main__":
    feature_engineering_example()
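
作为上面示例的延伸(假设其中的 engineer 与 combined_data 仍在作用域内),也可以把特征选择方法换成互信息,它对非线性关系更敏感:

# 使用互信息方法选择特征
selected_mi, mi_scores = engineer.select_best_features(
    combined_data,
    [col for col in combined_data.columns if col != 'target'],
    'target',
    k=5,
    method='mutual_info'
)
print("互信息得分(从高到低):")
for feature, score in sorted(mi_scores.items(), key=lambda x: x[1], reverse=True)[:5]:
    print(f"{feature}: {score:.4f}")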

6.5 数据分析模块集成

以下是如何将数据清洗、统计分析、数据挖掘和特征工程组件集成到一个完整的数据分析流程中:

def complete_data_analysis_pipeline(data, config=None):
    """完整的数据分析流程
    
    参数:
        data: 输入DataFrame
        config: 配置字典
        
    返回:
        分析结果字典
    """
    if config is None:
        config = {}
    
    results = {'original_data': data}
    
    # 1. 数据清洗
    print("1. 执行数据清洗...")
    cleaner = DataCleaner()
    clean_config = config.get('cleaning', {})
    cleaned_data = cleaner.clean_data(data, clean_config)
    results['cleaned_data'] = cleaned_data
    
    # 2. 统计分析
    print("\n2. 执行统计分析...")
    analyzer = StatisticalAnalyzer()
    
    # 描述性统计
    desc_stats = analyzer.describe_data(cleaned_data)
    results['descriptive_stats'] = desc_stats
    
    # 相关性分析
    corr_matrix = analyzer.correlation_analysis(cleaned_data)
    results['correlation_matrix'] = corr_matrix
    
    # 3. 特征工程
    print("\n3. 执行特征工程...")
    engineer = FeatureEngineer()
    feature_config = config.get('feature_engineering', {})
    
    # 异常值处理会把异常点置为NaN,这里先丢弃含NaN的行,避免影响多项式特征等后续计算
    engineered_data = cleaned_data.dropna()
    
    # 应用多项式特征
    if 'polynomial' in feature_config:
        poly_config = feature_config['polynomial']
        engineered_data = engineer.create_polynomial_features(
            engineered_data,
            poly_config.get('features', []),
            degree=poly_config.get('degree', 2)
        )
    
    # 应用交互特征
    if 'interaction' in feature_config:
        interaction_config = feature_config['interaction']
        engineered_data = engineer.create_interaction_features(
            engineered_data,
            interaction_config.get('features', [])
        )
    
    # 应用分箱特征
    if 'binning' in feature_config:
        for bin_config in feature_config['binning']:
            engineered_data = engineer.create_binning_features(
                engineered_data,
                bin_config.get('feature'),
                bins=bin_config.get('bins', 5),
                strategy=bin_config.get('strategy', 'uniform')
            )
    
    results['engineered_data'] = engineered_data
    
    # 4. 数据挖掘
    print("\n4. 执行数据挖掘...")
    miner = DataMiner()
    mining_config = config.get('mining', {})
    
    # 数据预处理
    processed_data, scaler = miner.preprocess_data(
        engineered_data,
        scale_method=mining_config.get('scale_method', 'standard'),
        categorical_cols=mining_config.get('categorical_cols', []),
        exclude_cols=mining_config.get('exclude_cols', [])
    )
    results['processed_data'] = processed_data
    
    # 降维分析
    if 'dimensionality_reduction' in mining_config:
        dr_config = mining_config['dimensionality_reduction']
        reduced_data, reducer = miner.reduce_dimensions(
            processed_data,
            n_components=dr_config.get('n_components', 2),
            method=dr_config.get('method', 'pca')
        )
        results['reduced_data'] = reduced_data
    
    # 聚类分析
    if 'clustering' in mining_config:
        cluster_config = mining_config['clustering']
        data_to_cluster = results.get('reduced_data', processed_data)
        clustered_data, clusterer = miner.cluster_data(
            data_to_cluster,
            method=cluster_config.get('method', 'kmeans'),
            n_clusters=cluster_config.get('n_clusters', 3)
        )
        results['clustered_data'] = clustered_data
    
    # 模型训练
    if 'models' in mining_config:
        models_results = {}
        
        for model_config in mining_config['models']:
            model_type = model_config.get('type')
            target = model_config.get('target')
            features = model_config.get('features')
            
            if model_type == 'classifier':
                metrics, model = miner.train_classifier(
                    processed_data,
                    target_col=target,
                    feature_cols=features,
                    model_type=model_config.get('algorithm', 'random_forest')
                )
                models_results[f'classifier_{target}'] = {
                    'metrics': metrics,
                    'model_id': f"{model_config.get('algorithm', 'random_forest')}_classifier_{target}"
                }
                
            elif model_type == 'regressor':
                metrics, model = miner.train_regressor(
                    processed_data,
                    target_col=target,
                    feature_cols=features,
                    model_type=model_config.get('algorithm', 'random_forest')
                )
                models_results[f'regressor_{target}'] = {
                    'metrics': metrics,
                    'model_id': f"{model_config.get('algorithm', 'random_forest')}_regressor_{target}"
                }
        
        results['models'] = models_results
    
    print("\n数据分析流程完成!")
    return results

# 使用示例
def data_analysis_example():
    # 创建示例数据
    np.random.seed(42)
    n_samples = 500
    
    # 生成特征
    X = np.random.randn(n_samples, 4)  # 4个特征
    
    # 生成分类目标变量
    y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0
    
    # 生成回归目标变量
    y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
    
    # 创建DataFrame
    data = pd.DataFrame(
        X, 
        columns=[f'feature_{i+1}' for i in range(4)]
    )
    data['category'] = np.random.choice(['A', 'B', 'C', 'D'], n_samples)
    data['target_class'] = y_class.astype(int)
    data['target_reg'] = y_reg
    
    # 添加一些派生列
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
    data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)
    
    # 配置分析流程
    config = {
        'cleaning': {
            'missing_values': {'strategy': 'drop'},
            'remove_duplicates': True,
            'outliers': {
                'columns': ['feature_1', 'feature_2', 'feature_3', 'feature_4'],
                'method': 'zscore',
                'threshold': 3.0
            },
            'text_columns': [],
            'type_conversions': {},
            'date_columns': {}
        },
        'feature_engineering': {
            'polynomial': {
                'features': ['feature_1', 'feature_2'],
                'degree': 2
            },
            'interaction': {
                'features': ['feature_1', 'feature_2', 'feature_3']
            },
            'binning': [
                {
                    'feature': 'feature_4',
                    'bins': 5,
                    'strategy': 'quantile'
                }
            ]
        },
        'mining': {
            'scale_method': 'standard',
            'categorical_cols': ['category'],
            'exclude_cols': ['target_class', 'target_reg'],
            'dimensionality_reduction': {
                'n_components': 2,
                'method': 'pca'
            },
            'clustering': {
                'method': 'kmeans',
                'n_clusters': 3
            },
            'models': [
                {
                    'type': 'classifier',
                    'target': 'target_class',
                    'algorithm': 'random_forest'
                },
                {
                    'type': 'regressor',
                    'target': 'target_reg',
                    'algorithm': 'random_forest'
                }
            ]
        }
    }
    
    # 执行分析流程
    results = complete_data_analysis_pipeline(data, config)
    
    # 打印部分结果
    print("\n描述性统计:")
    print(results['descriptive_stats'])
    
    print("\n模型性能:")
    for model_name, model_info in results['models'].items():
        print(f"{model_name}: {model_info['metrics']}")
    
    return results

if __name__ == "__main__":
    data_analysis_example()
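
结合系统架构中的数据存储模块,分析流程的输出还可以持久化,供后续的可视化模块或Web展示界面读取。下面是一个简要示意(函数名 save_analysis_results 与数据库文件名均为假设),使用SQLite保存清洗后的数据和描述性统计:

import sqlite3

def save_analysis_results(results, db_path='analysis_results.db'):
    """将清洗后的数据与描述性统计写入SQLite数据库"""
    conn = sqlite3.connect(db_path)
    try:
        results['cleaned_data'].to_sql('cleaned_data', conn, if_exists='replace', index=False)
        results['descriptive_stats'].to_sql('descriptive_stats', conn, if_exists='replace')
        print(f"分析结果已保存到 {db_path}")
    finally:
        conn.close()

# 用法示例:
# results = data_analysis_example()
# save_analysis_results(results)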

7. 数据可视化模块

数据可视化是将数据转化为图形表示的过程,通过视觉元素如图表、图形和地图,使复杂数据更容易理解和分析。

7.1 静态可视化

静态可视化是指生成不可交互的图表,主要使用Matplotlib和Seaborn库:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.ticker as ticker
from matplotlib.colors import LinearSegmentedColormap
import logging
from pathlib import Path

class StaticVisualizer:
    """静态可视化类"""
    
    def __init__(self, output_dir='visualizations'):
        """初始化可视化器
        
        参数:
            output_dir: 输出目录
        """
        self.output_dir = output_dir
        self.logger = self._setup_logger()
        self._setup_style()
        self._ensure_output_dir()
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('StaticVisualizer')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def _setup_style(self):
        """设置可视化样式"""
        # 设置Seaborn样式
        sns.set_theme(style="whitegrid")
        
        # 设置Matplotlib参数
        plt.rcParams['figure.figsize'] = (10, 6)
        plt.rcParams['font.size'] = 12
        plt.rcParams['axes.labelsize'] = 14
        plt.rcParams['axes.titlesize'] = 16
        plt.rcParams['xtick.labelsize'] = 12
        plt.rcParams['ytick.labelsize'] = 12
        plt.rcParams['legend.fontsize'] = 12
        plt.rcParams['figure.titlesize'] = 20
    
    def _ensure_output_dir(self):
        """确保输出目录存在"""
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        self.logger.info(f"输出目录: {self.output_dir}")
    
    def save_figure(self, fig, filename, dpi=300):
        """保存图表
        
        参数:
            fig: 图表对象
            filename: 文件名
            dpi: 分辨率
        """
        filepath = Path(self.output_dir) / filename
        fig.savefig(filepath, dpi=dpi, bbox_inches='tight')
        self.logger.info(f"图表已保存: {filepath}")
        
        return filepath
    
    def plot_bar_chart(self, data, x, y, title=None, xlabel=None, ylabel=None, 
                      color='skyblue', figsize=(10, 6), save_as=None, **kwargs):
        """绘制条形图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 条形颜色
            figsize: 图表大小
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            matplotlib图表对象
        """
        try:
            # 创建图表
            fig, ax = plt.subplots(figsize=figsize)
            
            # 绘制条形图
            sns.barplot(x=x, y=y, data=data, color=color, ax=ax, **kwargs)
            
            # 设置标题和标签
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)
            
            # 格式化y轴标签
            ax.yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
            
            # 添加数值标签
            for p in ax.patches:
                ax.annotate(f'{p.get_height():,.0f}', 
                           (p.get_x() + p.get_width() / 2., p.get_height()),
                           ha='center', va='bottom', fontsize=10)
            
            # 调整布局
            plt.tight_layout()
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制条形图时出错: {e}")
            return None
    
    def plot_line_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
                       color='royalblue', figsize=(12, 6), save_as=None, **kwargs):
        """绘制折线图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名或列名列表
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 线条颜色或颜色列表
            figsize: 图表大小
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            matplotlib图表对象
        """
        try:
            # 创建图表
            fig, ax = plt.subplots(figsize=figsize)
            
            # 处理多条线的情况
            if isinstance(y, list):
                if not isinstance(color, list):
                    color = [plt.cm.tab10(i) for i in range(len(y))]
                
                for i, col in enumerate(y):
                    data.plot(x=x, y=col, ax=ax, label=col, color=color[i % len(color)], **kwargs)
            else:
                data.plot(x=x, y=y, ax=ax, color=color, **kwargs)
            
            # 设置标题和标签
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)
            
            # 添加网格线
            ax.grid(True, linestyle='--', alpha=0.7)
            
            # 添加图例
            if isinstance(y, list) and len(y) > 1:
                ax.legend()
            
            # 调整布局
            plt.tight_layout()
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制折线图时出错: {e}")
            return None
    
    def plot_pie_chart(self, data, values, names, title=None, figsize=(10, 10),
                      colors=None, autopct='%1.1f%%', save_as=None, **kwargs):
        """绘制饼图
        
        参数:
            data: DataFrame
            values: 值列名
            names: 名称列名
            title: 图表标题
            figsize: 图表大小
            colors: 颜色列表
            autopct: 百分比格式
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            matplotlib图表对象
        """
        try:
            # 准备数据
            if isinstance(data, pd.DataFrame):
                values_data = data[values].values
                names_data = data[names].values
            else:
                values_data = values
                names_data = names
            
            # 创建图表
            fig, ax = plt.subplots(figsize=figsize)
            
            # 绘制饼图
            wedges, texts, autotexts = ax.pie(
                values_data, 
                labels=names_data,
                autopct=autopct,
                colors=colors,
                startangle=90,
                **kwargs
            )
            
            # 设置标题
            if title:
                ax.set_title(title)
            
            # 设置等比例
            ax.axis('equal')
            
            # 调整文本样式
            plt.setp(autotexts, size=10, weight='bold')
            
            # 调整布局
            plt.tight_layout()
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制饼图时出错: {e}")
            return None
    
    def plot_histogram(self, data, column, bins=30, title=None, xlabel=None, ylabel='频数',
                      color='skyblue', kde=True, figsize=(10, 6), save_as=None, **kwargs):
        """绘制直方图
        
        参数:
            data: DataFrame
            column: 列名
            bins: 分箱数量
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 直方图颜色
            kde: 是否显示核密度估计
            figsize: 图表大小
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            matplotlib图表对象
        """
        try:
            # 创建图表
            fig, ax = plt.subplots(figsize=figsize)
            
            # 绘制直方图
            sns.histplot(data=data, x=column, bins=bins, kde=kde, color=color, ax=ax, **kwargs)
            
            # 设置标题和标签
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)
            
            # 调整布局
            plt.tight_layout()
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制直方图时出错: {e}")
            return None
    
    def plot_scatter(self, data, x, y, title=None, xlabel=None, ylabel=None,
                    hue=None, palette='viridis', size=None, figsize=(10, 8), save_as=None, **kwargs):
        """绘制散点图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            hue: 分组变量
            palette: 颜色调色板
            size: 点大小变量
            figsize: 图表大小
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            matplotlib图表对象
        """
        try:
            # 创建图表
            fig, ax = plt.subplots(figsize=figsize)
            
            # 绘制散点图
            scatter = sns.scatterplot(
                data=data, 
                x=x, 
                y=y,
                hue=hue,
                palette=palette,
                size=size,
                ax=ax,
                **kwargs
            )
            
            # 设置标题和标签
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)
            
            # 添加网格线
            ax.grid(True, linestyle='--', alpha=0.7)
            
            # 如果有分组变量,调整图例
            if hue:
                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            
            # 调整布局
            plt.tight_layout()
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制散点图时出错: {e}")
            return None
    
    def plot_heatmap(self, data, title=None, cmap='viridis', annot=True, fmt='.2f',
                    figsize=(12, 10), save_as=None, **kwargs):
        """绘制热力图
        
        参数:
            data: DataFrame或矩阵
            title: 图表标题
            cmap: 颜色映射
            annot: 是否显示数值
            fmt: 数值格式
            figsize: 图表大小
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            matplotlib图表对象
        """
        try:
            # 创建图表
            fig, ax = plt.subplots(figsize=figsize)
            
            # 绘制热力图
            heatmap = sns.heatmap(
                data,
                cmap=cmap,
                annot=annot,
                fmt=fmt,
                linewidths=.5,
                ax=ax,
                **kwargs
            )
            
            # 设置标题
            if title:
                ax.set_title(title)
            
            # 调整布局
            plt.tight_layout()
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制热力图时出错: {e}")
            return None
    
    def plot_box(self, data, x=None, y=None, title=None, xlabel=None, ylabel=None,
               hue=None, palette='Set3', figsize=(12, 8), save_as=None, **kwargs):
        """绘制箱线图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            hue: 分组变量
            palette: 颜色调色板
            figsize: 图表大小
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            matplotlib图表对象
        """
        try:
            # 创建图表
            fig, ax = plt.subplots(figsize=figsize)
            
            # 绘制箱线图
            sns.boxplot(
                data=data, 
                x=x, 
                y=y,
                hue=hue,
                palette=palette,
                ax=ax,
                **kwargs
            )
            
            # 设置标题和标签
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)
            
            # 如果有分组变量,调整图例
            if hue:
                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            
            # 调整布局
            plt.tight_layout()
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制箱线图时出错: {e}")
            return None
    
    def plot_multiple_charts(self, data, chart_configs, title=None, figsize=(15, 10),
                           nrows=None, ncols=None, save_as=None):
        """绘制多个子图
        
        参数:
            data: DataFrame
            chart_configs: 子图配置列表,每个配置是一个字典,包含:
                - 'type': 图表类型 ('bar', 'line', 'scatter', 'hist', 'box', 'pie')
                - 'x', 'y': 数据列名
                - 'title': 子图标题
                - 其他特定图表类型的参数
            title: 总标题
            figsize: 图表大小
            nrows: 行数,如果为None则自动计算
            ncols: 列数,如果为None则自动计算
            save_as: 保存文件名
            
        返回:
            matplotlib图表对象
        """
        try:
            # 确定子图布局
            n_charts = len(chart_configs)
            
            if nrows is None and ncols is None:
                # 自动计算行列数
                ncols = min(3, n_charts)
                nrows = (n_charts + ncols - 1) // ncols
            elif nrows is None:
                nrows = (n_charts + ncols - 1) // ncols
            elif ncols is None:
                ncols = (n_charts + nrows - 1) // nrows
            
            # 创建图表
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
            
            # 确保axes是二维数组
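            # (当nrows或ncols为1时,plt.subplots返回一维数组或单个Axes对象,这里统一成二维,便于按[row, col]索引)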
            if nrows == 1 and ncols == 1:
                axes = np.array([[axes]])
            elif nrows == 1:
                axes = axes.reshape(1, -1)
            elif ncols == 1:
                axes = axes.reshape(-1, 1)
            
            # 绘制每个子图
            for i, config in enumerate(chart_configs):
                if i >= nrows * ncols:
                    self.logger.warning(f"子图数量超过布局容量,跳过第{i+1}个子图")
                    break
                
                # 获取当前子图的轴
                row, col = i // ncols, i % ncols
                ax = axes[row, col]
                
                # 根据类型绘制不同的图表
                chart_type = config.get('type', 'bar').lower()
                
                if chart_type == 'bar':
                    sns.barplot(
                        data=data,
                        x=config.get('x'),
                        y=config.get('y'),
                        hue=config.get('hue'),
                        palette=config.get('palette', 'viridis'),
                        ax=ax
                    )
                elif chart_type == 'line':
                    if isinstance(config.get('y'), list):
                        for y_col in config.get('y'):
                            data.plot(
                                x=config.get('x'),
                                y=y_col,
                                ax=ax,
                                label=y_col
                            )
                    else:
                        data.plot(
                            x=config.get('x'),
                            y=config.get('y'),
                            ax=ax
                        )
                elif chart_type == 'scatter':
                    sns.scatterplot(
                        data=data,
                        x=config.get('x'),
                        y=config.get('y'),
                        hue=config.get('hue'),
                        palette=config.get('palette', 'viridis'),
                        ax=ax
                    )
                elif chart_type == 'hist':
                    sns.histplot(
                        data=data,
                        x=config.get('x'),
                        bins=config.get('bins', 30),
                        kde=config.get('kde', True),
                        ax=ax
                    )
                elif chart_type == 'box':
                    sns.boxplot(
                        data=data,
                        x=config.get('x'),
                        y=config.get('y'),
                        hue=config.get('hue'),
                        palette=config.get('palette', 'viridis'),
                        ax=ax
                    )
                elif chart_type == 'pie':
                    # 饼图需要特殊处理
                    values = data[config.get('values')].values
                    names = data[config.get('names')].values
                    ax.pie(
                        values,
                        labels=names,
                        autopct='%1.1f%%',
                        startangle=90
                    )
                    ax.axis('equal')
                
                # 设置子图标题和标签
                if 'title' in config:
                    ax.set_title(config['title'])
                if 'xlabel' in config:
                    ax.set_xlabel(config['xlabel'])
                if 'ylabel' in config:
                    ax.set_ylabel(config['ylabel'])
            
            # 隐藏多余的子图
            for i in range(n_charts, nrows * ncols):
                row, col = i // ncols, i % ncols
                fig.delaxes(axes[row, col])
            
            # 调整布局
            plt.tight_layout()
            
            # 设置总标题(在tight_layout之后再设置并预留顶部空间,避免标题与子图重叠)
            if title:
                fig.suptitle(title, fontsize=16)
                plt.subplots_adjust(top=0.9)
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制多个子图时出错: {e}")
            return None

# 使用示例
def static_visualization_example():
    """静态可视化示例"""
    # 创建示例数据
    np.random.seed(42)
    n_samples = 200
    
    # 生成特征
    X = np.random.randn(n_samples, 3)  # 3个特征
    
    # 生成目标变量(回归)
    y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
    
    # 创建DataFrame
    data = pd.DataFrame(
        X, 
        columns=['feature_1', 'feature_2', 'feature_3']
    )
    data['target'] = y
    
    # 添加一些派生列
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
    data['sales_per_customer'] = data['target'] / (np.random.poisson(10, n_samples) + 1)  # 加1避免泊松采样出现0导致除零
    
    # 创建可视化器
    visualizer = StaticVisualizer(output_dir='visualizations/static')
    
    # 1. 绘制条形图 - 按月份的销售额
    monthly_sales = data.groupby('month')['target'].sum().reset_index()
    monthly_sales['month'] = pd.Categorical(monthly_sales['month'], 
                                           categories=['Jan', 'Feb', 'Mar', 'Apr'],
                                           ordered=True)
    monthly_sales = monthly_sales.sort_values('month')
    
    visualizer.plot_bar_chart(
        data=monthly_sales,
        x='month',
        y='target',
        title='Monthly Sales',
        xlabel='Month',
        ylabel='Total Sales',
        color='skyblue',
        save_as='monthly_sales_bar.png'
    )
    
    # 2. 绘制折线图 - 目标值与每客户销售额随feature_1变化的趋势
    visualizer.plot_line_chart(
        data=data.sort_values('feature_1'),  # 先按x轴排序,避免折线来回折返
        x='feature_1',
        y=['target', 'sales_per_customer'],
        title='Target and Sales-per-Customer Trends',
        xlabel='Feature 1',
        ylabel='Amount',
        figsize=(14, 7),
        save_as='target_trends_line.png'
    )
    
    # 3. 绘制饼图 - 按星期几的销售额分布
    day_sales = data.groupby('day_of_week')['target'].sum().reset_index()
    
    visualizer.plot_pie_chart(
        data=day_sales,
        values='target',
        names='day_of_week',
        title='Sales Distribution by Day of Week',
        save_as='day_sales_pie.png'
    )
    
    # 4. 绘制直方图 - 每位客户销售额分布
    visualizer.plot_histogram(
        data=data,
        column='sales_per_customer',
        bins=20,
        title='Distribution of Sales per Customer',
        xlabel='Sales per Customer',
        ylabel='Frequency',
        save_as='sales_per_customer_hist.png'
    )
    
    # 5. 绘制散点图 - feature_2与目标值的关系
    visualizer.plot_scatter(
        data=data,
        x='feature_2',
        y='target',
        title='Relationship between Feature 2 and Target',
        xlabel='Feature 2',
        ylabel='Target',
        hue='month',
        save_as='feature2_target_scatter.png'
    )
    
    # 6. 绘制热力图 - 相关性矩阵
    correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'target']].corr()
    
    visualizer.plot_heatmap(
        data=correlation_matrix,
        title='Correlation Matrix',
        save_as='correlation_heatmap.png'
    )
    
    # 7. 绘制箱线图 - 按星期几的销售额分布
    visualizer.plot_box(
        data=data,
        x='day_of_week',
        y='target',
        title='Sales Distribution by Day of Week',
        xlabel='Day of Week',
        ylabel='Sales',
        save_as='day_sales_box.png'
    )
    
    # 8. 绘制多个子图
    chart_configs = [
        {
            'type': 'bar',
            'x': 'month',
            'y': 'target',
            'title': 'Sales by Month'
        },
        {
            'type': 'line',
            'x': 'feature_1',
            'y': 'target',
            'title': 'Sales Trend'
        },
        {
            'type': 'scatter',
            'x': 'feature_2',
            'y': 'target',
            'title': 'Sales vs Feature 2'
        },
        {
            'type': 'hist',
            'x': 'sales_per_customer',
            'title': 'Sales per Customer Distribution'
        }
    ]
    
    visualizer.plot_multiple_charts(
        data=data,
        chart_configs=chart_configs,
        title='Sales Dashboard',
        save_as='sales_dashboard.png'
    )
    
    print("静态可视化示例完成,图表已保存到 'visualizations/static' 目录")
    
    return {
        'sales_data': data,
        'visualizer': visualizer
    }

if __name__ == "__main__":
    static_visualization_example()

7.2 交互式可视化

交互式可视化允许用户与图表进行交互,例如缩放、平移、悬停查看详情、筛选数据等,本节主要使用Plotly实现(后文的仪表盘部分再结合Dash)。
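下面先看一个最小的交互式图表示意(示例数据与输出文件名quick_demo.html均为演示用的假设,并非本系统的实际输出),生成的HTML在浏览器中即可缩放、平移和悬停查看数值:

import plotly.express as px
import pandas as pd

# 构造一小份演示数据(列名为假设)
demo = pd.DataFrame({
    'month': ['Jan', 'Feb', 'Mar', 'Apr'],
    'sales': [120, 150, 90, 180],
    'profit': [30, 45, 20, 60]
})

# hover_data指定悬停时额外显示的列;include_plotlyjs='cdn'通过CDN引入plotly.js以减小文件体积
fig = px.bar(demo, x='month', y='sales', color='month', hover_data=['profit'])
fig.write_html('quick_demo.html', include_plotlyjs='cdn')

完整的InteractiveVisualizer类实现如下: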

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
import logging
from pathlib import Path
import json
import plotly.io as pio

class InteractiveVisualizer:
    """交互式可视化类"""
    
    def __init__(self, output_dir='visualizations/interactive'):
        """初始化可视化器
        
        参数:
            output_dir: 输出目录
        """
        self.output_dir = output_dir
        self.logger = self._setup_logger()
        self._setup_style()
        self._ensure_output_dir()
    
    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('InteractiveVisualizer')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def _setup_style(self):
        """设置可视化样式"""
        # 设置Plotly模板
        pio.templates.default = "plotly_white"
    
    def _ensure_output_dir(self):
        """确保输出目录存在"""
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        self.logger.info(f"输出目录: {self.output_dir}")
    
    def save_figure(self, fig, filename, include_plotlyjs='cdn'):
        """保存图表
        
        参数:
            fig: 图表对象
            filename: 文件名
            include_plotlyjs: plotly.js的引入方式(True表示内嵌,'cdn'表示通过CDN引用)
        """
        filepath = Path(self.output_dir) / filename
        
        # 保存为HTML
        if filename.endswith('.html'):
            fig.write_html(filepath, include_plotlyjs=include_plotlyjs)
        # 保存为JSON
        elif filename.endswith('.json'):
            with open(filepath, 'w') as f:
                json.dump(fig.to_dict(), f)
        # 保存为图像
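        # 注意:导出为png/svg等静态图像通常需要额外安装kaleido包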
        else:
            fig.write_image(filepath)
        
        self.logger.info(f"图表已保存: {filepath}")
        
        return filepath
    
    def plot_bar_chart(self, data, x, y, title=None, xlabel=None, ylabel=None, 
                      color=None, barmode='group', figsize=(900, 600), 
                      save_as=None, **kwargs):
        """绘制交互式条形图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名或列名列表
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 分组变量
            barmode: 条形模式 ('group', 'stack', 'relative', 'overlay')
            figsize: 图表大小 (宽, 高)
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 处理y为列表的情况
            if isinstance(y, list):
                fig = go.Figure()
                
                for col in y:
                    fig.add_trace(go.Bar(
                        x=data[x],
                        y=data[col],
                        name=col
                    ))
                
                fig.update_layout(barmode=barmode)
            else:
                # 使用Plotly Express创建条形图
                fig = px.bar(
                    data, 
                    x=x, 
                    y=y,
                    color=color,
                    barmode=barmode,
                    **kwargs
                )
            
            # 更新布局
            fig.update_layout(
                title=title,
                xaxis_title=xlabel,
                yaxis_title=ylabel,
                width=figsize[0],
                height=figsize[1],
                hovermode='closest'
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式条形图时出错: {e}")
            return None
    
    def plot_line_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
                       color=None, line_shape='linear', figsize=(900, 600),
                       save_as=None, **kwargs):
        """绘制交互式折线图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名或列名列表
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 分组变量
            line_shape: 线条形状 ('linear', 'spline', 'hv', 'vh', 'hvh', 'vhv')
            figsize: 图表大小 (宽, 高)
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 处理y为列表的情况
            if isinstance(y, list):
                fig = go.Figure()
                
                for col in y:
                    fig.add_trace(go.Scatter(
                        x=data[x],
                        y=data[col],
                        mode='lines+markers',
                        name=col,
                        line=dict(shape=line_shape)
                    ))
            else:
                # 使用Plotly Express创建折线图
                fig = px.line(
                    data, 
                    x=x, 
                    y=y,
                    color=color,
                    line_shape=line_shape,
                    **kwargs
                )
                
                # 添加标记点
                fig.update_traces(mode='lines+markers')
            
            # 更新布局
            fig.update_layout(
                title=title,
                xaxis_title=xlabel,
                yaxis_title=ylabel,
                width=figsize[0],
                height=figsize[1],
                hovermode='closest'
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式折线图时出错: {e}")
            return None
    
    def plot_pie_chart(self, data, values, names, title=None, figsize=(800, 800),
                      hole=0, save_as=None, **kwargs):
        """绘制交互式饼图/环形图
        
        参数:
            data: DataFrame
            values: 值列名
            names: 名称列名
            title: 图表标题
            figsize: 图表大小 (宽, 高)
            hole: 中心孔大小 (0-1),0为饼图,>0为环形图
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 使用Plotly Express创建饼图/环形图
            fig = px.pie(
                data, 
                values=values, 
                names=names,
                hole=hole,
                **kwargs
            )
            
            # 更新布局
            fig.update_layout(
                title=title,
                width=figsize[0],
                height=figsize[1]
            )
            
            # 更新轨迹
            fig.update_traces(
                textposition='inside',
                textinfo='percent+label',
                hoverinfo='label+percent+value'
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式饼图时出错: {e}")
            return None
    
    def plot_histogram(self, data, column, bins=30, title=None, xlabel=None, ylabel='频率',
                      color=None, figsize=(900, 600), save_as=None, **kwargs):
        """绘制交互式直方图
        
        参数:
            data: DataFrame
            column: 列名
            bins: 分箱数量
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 分组变量
            figsize: 图表大小 (宽, 高)
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 使用Plotly Express创建直方图
            fig = px.histogram(
                data, 
                x=column,
                color=color,
                nbins=bins,
                marginal='rug',  # 添加边缘分布
                **kwargs
            )
            
            # 更新布局
            fig.update_layout(
                title=title,
                xaxis_title=xlabel,
                yaxis_title=ylabel,
                width=figsize[0],
                height=figsize[1],
                bargap=0.1  # 条形之间的间隙
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式直方图时出错: {e}")
            return None
    
    def plot_scatter(self, data, x, y, title=None, xlabel=None, ylabel=None, 
                    color=None, size=None, hover_name=None, figsize=(900, 600),
                    save_as=None, **kwargs):
        """绘制交互式散点图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 分组变量
            size: 点大小变量
            hover_name: 悬停显示的标识列
            figsize: 图表大小 (宽, 高)
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 使用Plotly Express创建散点图
            fig = px.scatter(
                data, 
                x=x, 
                y=y,
                color=color,
                size=size,
                hover_name=hover_name,
                **kwargs
            )
            
            # 更新布局
            fig.update_layout(
                title=title,
                xaxis_title=xlabel,
                yaxis_title=ylabel,
                width=figsize[0],
                height=figsize[1],
                hovermode='closest'
            )
            
            # 若未通过kwargs指定trendline,添加一条从(x_min, y_min)到(x_max, y_max)的参考对角线
            # 注意:这只是辅助参考线,并非回归趋势线;回归趋势线可用px.scatter的trendline='ols'参数实现(需安装statsmodels)
            if 'trendline' not in kwargs:
                fig.update_layout(
                    shapes=[{
                        'type': 'line',
                        'x0': data[x].min(),
                        'y0': data[y].min(),
                        'x1': data[x].max(),
                        'y1': data[y].max(),
                        'line': {
                            'color': 'rgba(0,0,0,0.2)',
                            'width': 2,
                            'dash': 'dash'
                        }
                    }]
                )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式散点图时出错: {e}")
            return None
    
    def plot_heatmap(self, data, title=None, figsize=(900, 700), 
                    colorscale='Viridis', save_as=None, **kwargs):
        """绘制交互式热力图
        
        参数:
            data: DataFrame或矩阵
            title: 图表标题
            figsize: 图表大小 (宽, 高)
            colorscale: 颜色映射
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 创建热力图
            fig = go.Figure(data=go.Heatmap(
                z=data.values,
                x=data.columns,
                y=data.index,
                colorscale=colorscale,
                **kwargs
            ))
            
            # 更新布局
            fig.update_layout(
                title=title,
                width=figsize[0],
                height=figsize[1]
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式热力图时出错: {e}")
            return None
    
    def plot_box(self, data, x=None, y=None, title=None, xlabel=None, ylabel=None,
               color=None, figsize=(900, 600), save_as=None, **kwargs):
        """绘制交互式箱线图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 分组变量
            figsize: 图表大小 (宽, 高)
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 使用Plotly Express创建箱线图
            fig = px.box(
                data, 
                x=x, 
                y=y,
                color=color,
                **kwargs
            )
            
            # 更新布局
            fig.update_layout(
                title=title,
                xaxis_title=xlabel,
                yaxis_title=ylabel,
                width=figsize[0],
                height=figsize[1]
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式箱线图时出错: {e}")
            return None
    
    def plot_bubble(self, data, x, y, size, title=None, xlabel=None, ylabel=None,
                   color=None, hover_name=None, figsize=(900, 600), save_as=None, **kwargs):
        """绘制交互式气泡图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名
            size: 气泡大小列名
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            color: 分组变量
            hover_name: 悬停显示的标识列
            figsize: 图表大小 (宽, 高)
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 使用Plotly Express创建气泡图
            fig = px.scatter(
                data, 
                x=x, 
                y=y,
                size=size,
                color=color,
                hover_name=hover_name,
                **kwargs
            )
            
            # 更新布局
            fig.update_layout(
                title=title,
                xaxis_title=xlabel,
                yaxis_title=ylabel,
                width=figsize[0],
                height=figsize[1],
                hovermode='closest'
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式气泡图时出错: {e}")
            return None
    
    def plot_3d_scatter(self, data, x, y, z, title=None, xlabel=None, ylabel=None, zlabel=None,
                       color=None, size=None, hover_name=None, figsize=(900, 700),
                       save_as=None, **kwargs):
        """绘制交互式3D散点图
        
        参数:
            data: DataFrame
            x: x轴列名
            y: y轴列名
            z: z轴列名
            title: 图表标题
            xlabel: x轴标签
            ylabel: y轴标签
            zlabel: z轴标签
            color: 分组变量
            size: 点大小变量
            hover_name: 悬停显示的标识列
            figsize: 图表大小 (宽, 高)
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 使用Plotly Express创建3D散点图
            fig = px.scatter_3d(
                data, 
                x=x, 
                y=y,
                z=z,
                color=color,
                size=size,
                hover_name=hover_name,
                **kwargs
            )
            
            # 更新布局
            fig.update_layout(
                title=title,
                scene=dict(
                    xaxis_title=xlabel,
                    yaxis_title=ylabel,
                    zaxis_title=zlabel
                ),
                width=figsize[0],
                height=figsize[1]
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式3D散点图时出错: {e}")
            return None
    
    def plot_choropleth_map(self, data, locations, color, title=None, 
                           location_mode='ISO-3', figsize=(900, 600),
                           colorscale='Viridis', save_as=None, **kwargs):
        """绘制交互式地理热力图
        
        参数:
            data: DataFrame
            locations: 地理位置列名
            color: 颜色值列名
            title: 图表标题
            location_mode: 地理位置模式 ('ISO-3', 'country names', 等)
            figsize: 图表大小 (宽, 高)
            colorscale: 颜色映射
            save_as: 保存文件名
            **kwargs: 其他参数
            
        返回:
            plotly图表对象
        """
        try:
            # 使用Plotly Express创建地理热力图
            fig = px.choropleth(
                data, 
                locations=locations,
                color=color,
                locationmode=location_mode,
                color_continuous_scale=colorscale,
                **kwargs
            )
            
            # 更新布局
            fig.update_layout(
                title=title,
                width=figsize[0],
                height=figsize[1],
                geo=dict(
                    showframe=False,
                    showcoastlines=True,
                    projection_type='equirectangular'
                )
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制交互式地理热力图时出错: {e}")
            return None

    def plot_multiple_charts(self, chart_configs, title=None, figsize=(1000, 800),
                           rows=None, cols=None, subplot_titles=None, save_as=None):
        """绘制多个子图
        
        参数:
            chart_configs: 子图配置列表,每个配置是一个字典,包含:
                - 'data': 数据
                - 'type': 图表类型 ('bar', 'line', 'scatter', 'pie', 'box', 'heatmap', 等)
                - 'x', 'y': 数据列名
                - 'row', 'col': 子图位置
                - 其他特定图表类型的参数
            title: 总标题
            figsize: 图表大小 (宽, 高)
            rows: 行数
            cols: 列数
            subplot_titles: 子图标题列表
            save_as: 保存文件名
            
        返回:
            plotly图表对象
        """
        try:
            # 确定子图布局
            if rows is None or cols is None:
                # 查找最大的row和col值
                max_row = max([config.get('row', 1) for config in chart_configs])
                max_col = max([config.get('col', 1) for config in chart_configs])
                rows = max(rows or 0, max_row)
                cols = max(cols or 0, max_col)
            
            # 创建子图(饼图必须放在'domain'类型的子图中,其余使用'xy'类型)
            specs = [[{"type": "xy"} for _ in range(cols)] for _ in range(rows)]
            for config in chart_configs:
                if config.get('type', '').lower() == 'pie':
                    specs[config.get('row', 1) - 1][config.get('col', 1) - 1] = {"type": "domain"}
            
            fig = make_subplots(
                rows=rows, 
                cols=cols,
                subplot_titles=subplot_titles,
                specs=specs
            )
            
            # 添加每个子图
            for config in chart_configs:
                data = config.get('data')
                chart_type = config.get('type', 'scatter').lower()
                row = config.get('row', 1)
                col = config.get('col', 1)
                
                if chart_type == 'bar':
                    trace = go.Bar(
                        x=data[config.get('x')],
                        y=data[config.get('y')],
                        name=config.get('name', config.get('y')),
                        marker_color=config.get('color')
                    )
                elif chart_type == 'line':
                    trace = go.Scatter(
                        x=data[config.get('x')],
                        y=data[config.get('y')],
                        mode='lines+markers',
                        name=config.get('name', config.get('y')),
                        line=dict(color=config.get('color'))
                    )
                elif chart_type == 'scatter':
                    trace = go.Scatter(
                        x=data[config.get('x')],
                        y=data[config.get('y')],
                        mode='markers',
                        name=config.get('name', config.get('y')),
                        marker=dict(
                            color=config.get('color'),
                            size=config.get('size', 10)
                        )
                    )
                elif chart_type == 'pie':
                    trace = go.Pie(
                        values=data[config.get('values')],
                        labels=data[config.get('names')],
                        name=config.get('name', '')
                    )
                elif chart_type == 'box':
                    trace = go.Box(
                        x=data[config.get('x')] if 'x' in config else None,
                        y=data[config.get('y')],
                        name=config.get('name', config.get('y'))
                    )
                elif chart_type == 'heatmap':
                    # 热力图需要特殊处理
                    if isinstance(data, pd.DataFrame):
                        z_data = data.values
                        x_data = data.columns
                        y_data = data.index
                    else:
                        z_data = data
                        x_data = config.get('x')
                        y_data = config.get('y')
                    
                    trace = go.Heatmap(
                        z=z_data,
                        x=x_data,
                        y=y_data,
                        colorscale=config.get('colorscale', 'Viridis')
                    )
                else:
                    self.logger.warning(f"未知的图表类型: {chart_type}")
                    continue
                
                fig.add_trace(trace, row=row, col=col)
                
                # 更新轴标签
                if 'xlabel' in config:
                    fig.update_xaxes(title_text=config['xlabel'], row=row, col=col)
                if 'ylabel' in config:
                    fig.update_yaxes(title_text=config['ylabel'], row=row, col=col)
            
            # 更新布局
            fig.update_layout(
                title=title,
                width=figsize[0],
                height=figsize[1],
                showlegend=True
            )
            
            # 保存图表
            if save_as:
                self.save_figure(fig, save_as)
            
            return fig
            
        except Exception as e:
            self.logger.error(f"绘制多个子图时出错: {e}")
            return None

# 使用示例
def interactive_visualization_example():
    """交互式可视化示例"""
    # 创建示例数据
    np.random.seed(42)
    n_samples = 200
    
    # 生成特征
    X = np.random.randn(n_samples, 3)  # 3个特征
    
    # 生成目标变量(回归)
    y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
    
    # 创建DataFrame
    data = pd.DataFrame(
        X, 
        columns=['feature_1', 'feature_2', 'feature_3']
    )
    data['target'] = y
    
    # 添加一些派生列
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
    data['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
    data['sales'] = data['target'] * 100 + 500
    data['profit'] = data['sales'] * np.random.uniform(0.1, 0.3, n_samples)
    data['customers'] = np.random.poisson(50, n_samples)
    data['sales_per_customer'] = data['sales'] / data['customers']
    
    # 创建一些国家数据
    countries = ['USA', 'CAN', 'MEX', 'BRA', 'ARG', 'GBR', 'FRA', 'DEU', 'ITA', 'ESP', 
                'RUS', 'CHN', 'JPN', 'IND', 'AUS']
    country_codes = ['USA', 'CAN', 'MEX', 'BRA', 'ARG', 'GBR', 'FRA', 'DEU', 'ITA', 'ESP', 
                    'RUS', 'CHN', 'JPN', 'IND', 'AUS']
    country_data = pd.DataFrame({
        'country': countries,
        'code': country_codes,
        'gdp': np.random.uniform(100, 1000, len(countries)),
        'population': np.random.uniform(10, 500, len(countries))
    })
    
    # 创建可视化器
    visualizer = InteractiveVisualizer(output_dir='visualizations/interactive')
    
    # 1. 绘制交互式条形图 - 按月份的销售额
    monthly_sales = data.groupby('month')['sales'].sum().reset_index()
    monthly_sales['month'] = pd.Categorical(monthly_sales['month'], 
                                           categories=['Jan', 'Feb', 'Mar', 'Apr'],
                                           ordered=True)
    monthly_sales = monthly_sales.sort_values('month')
    
    bar_fig = visualizer.plot_bar_chart(
        data=monthly_sales,
        x='month',
        y='sales',
        title='Monthly Sales',
        xlabel='Month',
        ylabel='Total Sales',
        save_as='monthly_sales_bar.html'
    )
    
    # 2. 绘制交互式折线图 - 销售额和利润趋势
    line_fig = visualizer.plot_line_chart(
        data=data.sort_values('feature_1').iloc[:50],  # 使用部分数据
        x='feature_1',
        y=['sales', 'profit'],
        title='Sales and Profit Trends',
        xlabel='Feature 1',
        ylabel='Amount',
        save_as='sales_profit_trend.html'
    )
    
    # 3. 绘制交互式饼图 - 按区域的销售额分布
    region_sales = data.groupby('region')['sales'].sum().reset_index()
    
    pie_fig = visualizer.plot_pie_chart(
        data=region_sales,
        values='sales',
        names='region',
        title='Sales Distribution by Region',
        save_as='region_sales_pie.html'
    )
    
    # 4. 绘制交互式环形图 - 按星期几的销售额分布
    day_sales = data.groupby('day_of_week')['sales'].sum().reset_index()
    
    donut_fig = visualizer.plot_pie_chart(
        data=day_sales,
        values='sales',
        names='day_of_week',
        title='Sales Distribution by Day of Week',
        hole=0.4,  # 环形图
        save_as='day_sales_donut.html'
    )
    
    # 5. 绘制交互式直方图 - 每位客户销售额分布
    hist_fig = visualizer.plot_histogram(
        data=data,
        column='sales_per_customer',
        bins=20,
        title='Distribution of Sales per Customer',
        xlabel='Sales per Customer',
        ylabel='Frequency',
        color='region',  # 按区域分组
        save_as='sales_per_customer_hist.html'
    )
    
    # 6. 绘制交互式散点图 - 客户数量与销售额的关系
    scatter_fig = visualizer.plot_scatter(
        data=data,
        x='customers',
        y='sales',
        title='Relationship between Number of Customers and Sales',
        xlabel='Number of Customers',
        ylabel='Sales',
        color='region',
        size='profit',  # 使用利润作为点大小
        hover_name='month',  # 悬停显示月份
        save_as='customers_sales_scatter.html'
    )
    
    # 7. 绘制交互式热力图 - 相关性矩阵
    correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']].corr()
    
    heatmap_fig = visualizer.plot_heatmap(
        data=correlation_matrix,
        title='Correlation Matrix',
        save_as='correlation_heatmap.html'
    )
    
    # 8. 绘制交互式箱线图 - 按区域的销售额分布
    box_fig = visualizer.plot_box(
        data=data,
        x='region',
        y='sales',
        title='Sales Distribution by Region',
        xlabel='Region',
        ylabel='Sales',
        color='region',
        save_as='region_sales_box.html'
    )
    
    # 9. 绘制交互式气泡图 - 特征与销售额和利润的关系
    bubble_fig = visualizer.plot_bubble(
        data=data,
        x='feature_1',
        y='feature_2',
        size='sales',
        color='region',
        title='Feature Relationships with Sales',
        xlabel='Feature 1',
        ylabel='Feature 2',
        hover_name='month',
        save_as='feature_sales_bubble.html'
    )
    
    # 10. 绘制交互式3D散点图 - 三个特征的关系
    scatter_3d_fig = visualizer.plot_3d_scatter(
        data=data,
        x='feature_1',
        y='feature_2',
        z='feature_3',
        color='sales',
        size='profit',
        title='3D Relationship between Features',
        xlabel='Feature 1',
        ylabel='Feature 2',
        zlabel='Feature 3',
        save_as='features_3d_scatter.html'
    )
    
    # 11. 绘制交互式地理热力图 - 国家GDP分布
    choropleth_fig = visualizer.plot_choropleth_map(
        data=country_data,
        locations='code',
        color='gdp',
        title='GDP by Country',
        location_mode='ISO-3',
        colorscale='Viridis',
        save_as='country_gdp_map.html'
    )
    
    # 12. 绘制多个子图 - 销售仪表盘
    chart_configs = [
        {
            'data': monthly_sales,
            'type': 'bar',
            'x': 'month',
            'y': 'sales',
            'row': 1,
            'col': 1,
            'name': 'Monthly Sales',
            'xlabel': 'Month',
            'ylabel': 'Sales'
        },
        {
            'data': data.sort_values('feature_1').iloc[:50],
            'type': 'line',
            'x': 'feature_1',
            'y': 'sales',
            'row': 1,
            'col': 2,
            'name': 'Sales Trend',
            'xlabel': 'Feature 1',
            'ylabel': 'Sales'
        },
        {
            'data': data,
            'type': 'scatter',
            'x': 'customers',
            'y': 'sales',
            'row': 2,
            'col': 1,
            'name': 'Customers vs Sales',
            'xlabel': 'Customers',
            'ylabel': 'Sales'
        },
        {
            'data': correlation_matrix,
            'type': 'heatmap',
            'row': 2,
            'col': 2,
            'name': 'Correlation'
        }
    ]
    
    subplot_titles = ['Monthly Sales', 'Sales Trend', 'Customers vs Sales', 'Correlation Matrix']
    
    dashboard_fig = visualizer.plot_multiple_charts(
        chart_configs=chart_configs,
        title='Sales Dashboard',
        rows=2,
        cols=2,
        subplot_titles=subplot_titles,
        save_as='sales_dashboard.html'
    )
    
    print("交互式可视化示例完成,图表已保存到 'visualizations/interactive' 目录")
    
    return {
        'data': data,
        'country_data': country_data,
        'visualizer': visualizer,
        'figures': {
            'bar': bar_fig,
            'line': line_fig,
            'pie': pie_fig,
            'donut': donut_fig,
            'hist': hist_fig,
            'scatter': scatter_fig,
            'heatmap': heatmap_fig,
            'box': box_fig,
            'bubble': bubble_fig,
            'scatter_3d': scatter_3d_fig,
            'choropleth': choropleth_fig,
            'dashboard': dashboard_fig
        }
    }

if __name__ == "__main__":
    interactive_visualization_example()

交互式仪表盘功能
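
前面生成的交互式图表都是相互独立的HTML文件,可以直接在浏览器中打开查看;如果希望把多个图表组织到同一个页面中,并通过下拉框、滑块等控件联动筛选数据,可以使用Dash构建交互式仪表盘。Dash的基本模式是布局(layout)加回调(callback):布局声明页面上的组件,回调把控件输入映射为图表更新。下面是一个最小可运行的示意(组件ID与示例数据均为演示用的假设):

import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.express as px
import pandas as pd

# 演示数据(假设)
df = pd.DataFrame({'month': ['Jan', 'Feb', 'Mar', 'Apr'],
                   'sales': [120, 150, 90, 180]})

app = dash.Dash(__name__)

# 布局:一个多选下拉框 + 一个图表
app.layout = html.Div([
    dcc.Dropdown(id='month-select',
                 options=[{'label': m, 'value': m} for m in df['month']],
                 value=list(df['month']), multi=True),
    dcc.Graph(id='sales-graph')
])

# 回调:下拉框取值变化时重新绘制条形图
@app.callback(Output('sales-graph', 'figure'), Input('month-select', 'value'))
def update_graph(months):
    filtered = df[df['month'].isin(months or [])]
    return px.bar(filtered, x='month', y='sales', title='Monthly Sales')

if __name__ == '__main__':
    app.run_server(debug=True)

在这个模式的基础上,下面实现一个可复用的DashboardBuilder,把组件定义、布局创建和回调注册封装起来: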

# 交互式仪表盘模块
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
import os
from pathlib import Path
import logging
import plotly.io as pio  # 用于设置Plotly默认主题

class DashboardBuilder:
    """交互式仪表盘构建器"""
    
    def __init__(self, title="数据分析仪表盘", theme="plotly_white"):
        """初始化仪表盘构建器
        
        参数:
            title: 仪表盘标题
            theme: 仪表盘主题
        """
        self.title = title
        self.theme = theme
        self.app = dash.Dash(__name__, suppress_callback_exceptions=True)
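        # 注意:create_layout中使用的row/col-*类依赖Bootstrap样式,需通过dash.Dash的external_stylesheets参数引入Bootstrap CSS才能生效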
        self.app.title = title
        self.visualizer = InteractiveVisualizer()
        self.logger = logging.getLogger(__name__)
        
        # 设置Plotly主题
        pio.templates.default = theme
    
    def create_layout(self, components):
        """创建仪表盘布局
        
        参数:
            components: 组件列表,每个组件是一个字典,包含:
                - 'type': 组件类型 ('graph', 'table', 'control', 等)
                - 'id': 组件ID
                - 'title': 组件标题
                - 'width': 组件宽度 (1-12)
                - 其他特定组件类型的参数
                
        返回:
            Dash应用布局
        """
        try:
            # 创建页面布局
            layout = html.Div([
                # 标题
                html.H1(self.title, style={'textAlign': 'center', 'marginBottom': 30}),
                
                # 内容容器
                html.Div([
                    # 为每个组件创建一个Div
                    html.Div([
                        # 组件标题
                        html.H3(component.get('title', f"Component {i+1}"), 
                              style={'marginBottom': 15}),
                        
                        # 根据组件类型创建不同的内容
                        self._create_component(component)
                    ], className=f"col-{component.get('width', 12)}",
                       style={'padding': '10px'})
                    
                    for i, component in enumerate(components)
                ], className='row')
            ], className='container-fluid')
            
            self.app.layout = layout
            return layout
            
        except Exception as e:
            self.logger.error(f"创建仪表盘布局时出错: {e}")
            return html.Div(f"创建仪表盘布局时出错: {e}")
    
    def _create_component(self, component):
        """根据组件类型创建组件
        
        参数:
            component: 组件配置字典
            
        返回:
            Dash组件
        """
        try:
            component_type = component.get('type', '').lower()
            component_id = component.get('id', f"component-{id(component)}")
            
            if component_type == 'graph':
                # 创建图表组件
                return dcc.Graph(
                    id=component_id,
                    figure=component.get('figure', {}),
                    style={'height': component.get('height', 400)}
                )
                
            elif component_type == 'table':
                # 创建表格组件
                data = component.get('data', pd.DataFrame())
                return html.Div([
                    dash.dash_table.DataTable(
                        id=component_id,
                        columns=[{"name": i, "id": i} for i in data.columns],
                        data=data.to_dict('records'),
                        page_size=component.get('page_size', 10),
                        style_table={'overflowX': 'auto'},
                        style_cell={
                            'textAlign': 'left',
                            'padding': '10px',
                            'minWidth': '100px', 'width': '150px', 'maxWidth': '300px',
                            'whiteSpace': 'normal',
                            'height': 'auto'
                        },
                        style_header={
                            'backgroundColor': 'rgb(230, 230, 230)',
                            'fontWeight': 'bold'
                        }
                    )
                ])
                
            elif component_type == 'control':
                # 创建控制组件
                control_subtype = component.get('control_type', '').lower()
                
                if control_subtype == 'dropdown':
                    return dcc.Dropdown(
                        id=component_id,
                        options=[{'label': str(opt), 'value': opt} 
                                for opt in component.get('options', [])],
                        value=component.get('value'),
                        multi=component.get('multi', False),
                        placeholder=component.get('placeholder', 'Select an option')
                    )
                    
                elif control_subtype == 'slider':
                    return dcc.Slider(
                        id=component_id,
                        min=component.get('min', 0),
                        max=component.get('max', 100),
                        step=component.get('step', 1),
                        value=component.get('value', 50),
                        marks={i: str(i) for i in range(
                            component.get('min', 0), 
                            component.get('max', 100) + 1, 
                            component.get('mark_step', 10)
                        )}
                    )
                    
                elif control_subtype == 'radio':
                    return dcc.RadioItems(
                        id=component_id,
                        options=[{'label': str(opt), 'value': opt} 
                                for opt in component.get('options', [])],
                        value=component.get('value'),
                        inline=component.get('inline', True)
                    )
                    
                elif control_subtype == 'checklist':
                    return dcc.Checklist(
                        id=component_id,
                        options=[{'label': str(opt), 'value': opt} 
                                for opt in component.get('options', [])],
                        value=component.get('value', []),
                        inline=component.get('inline', True)
                    )
                    
                elif control_subtype == 'date':
                    return dcc.DatePickerSingle(
                        id=component_id,
                        date=component.get('date'),
                        min_date_allowed=component.get('min_date'),
                        max_date_allowed=component.get('max_date')
                    )
                    
                elif control_subtype == 'daterange':
                    return dcc.DatePickerRange(
                        id=component_id,
                        start_date=component.get('start_date'),
                        end_date=component.get('end_date'),
                        min_date_allowed=component.get('min_date'),
                        max_date_allowed=component.get('max_date')
                    )
                
                else:
                    return html.Div(f"未知的控制类型: {control_subtype}")
                
            elif component_type == 'text':
                # 创建文本组件
                return html.Div([
                    html.P(component.get('text', ''), 
                          style={'fontSize': component.get('font_size', 16)})
                ])
                
            elif component_type == 'html':
                # 创建自定义HTML组件(html.Div不支持直接注入原始HTML,这里改用dcc.Markdown渲染)
                return dcc.Markdown(
                    component.get('html', ''),
                    dangerously_allow_html=True
                )
                
            else:
                return html.Div(f"未知的组件类型: {component_type}")
                
        except Exception as e:
            self.logger.error(f"创建组件时出错: {e}")
            return html.Div(f"创建组件时出错: {e}")
    
    def add_callback(self, outputs, inputs, state=None):
        """添加回调函数
        
        参数:
            outputs: 输出组件列表,每个元素是一个元组 (component_id, component_property)
            inputs: 输入组件列表,每个元素是一个元组 (component_id, component_property)
            state: 状态组件列表,每个元素是一个元组 (component_id, component_property)
            
        返回:
            装饰器函数
        """
        try:
            # 转换为Dash输出格式
            dash_outputs = [Output(component_id, component_property) 
                          for component_id, component_property in outputs]
            
            # 转换为Dash输入格式
            dash_inputs = [Input(component_id, component_property) 
                         for component_id, component_property in inputs]
            
            # 转换为Dash状态格式
            dash_state = []
            if state:
                dash_state = [dash.dependencies.State(component_id, component_property) 
                            for component_id, component_property in state]
            
            # 返回Dash回调装饰器
            return self.app.callback(dash_outputs, dash_inputs, dash_state)
            
        except Exception as e:
            self.logger.error(f"添加回调函数时出错: {e}")
            return None
    
    def run_server(self, debug=True, port=8050, host='0.0.0.0'):
        """运行仪表盘服务器
        
        参数:
            debug: 是否启用调试模式
            port: 服务器端口
            host: 服务器主机
        """
        try:
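            # 注:较新版本的Dash推荐使用self.app.run(...),run_server在新版本中已被标记为弃用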
            self.app.run_server(debug=debug, port=port, host=host)
        except Exception as e:
            self.logger.error(f"运行仪表盘服务器时出错: {e}")

# 使用示例
def interactive_dashboard_example():
    """交互式仪表盘示例"""
    # 创建示例数据
    np.random.seed(42)
    n_samples = 200
    
    # 生成特征
    X = np.random.randn(n_samples, 3)  # 3个特征
    
    # 生成目标变量(回归)
    y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
    
    # 创建DataFrame
    data = pd.DataFrame(
        X, 
        columns=['feature_1', 'feature_2', 'feature_3']
    )
    data['target'] = y
    
    # 添加一些派生列
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
    data['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
    data['sales'] = data['target'] * 100 + 500
    data['profit'] = data['sales'] * np.random.uniform(0.1, 0.3, n_samples)
    data['customers'] = np.random.poisson(50, n_samples)
    data['sales_per_customer'] = data['sales'] / data['customers']
    
    # 创建可视化器
    visualizer = InteractiveVisualizer()
    
    # 创建一些图表
    monthly_sales = data.groupby('month')['sales'].sum().reset_index()
    monthly_sales['month'] = pd.Categorical(monthly_sales['month'], 
                                           categories=['Jan', 'Feb', 'Mar', 'Apr'],
                                           ordered=True)
    monthly_sales = monthly_sales.sort_values('month')
    
    bar_fig = visualizer.plot_bar_chart(
        data=monthly_sales,
        x='month',
        y='sales',
        title='Monthly Sales',
        xlabel='Month',
        ylabel='Total Sales'
    )
    
    region_sales = data.groupby('region')['sales'].sum().reset_index()
    pie_fig = visualizer.plot_pie_chart(
        data=region_sales,
        values='sales',
        names='region',
        title='Sales Distribution by Region'
    )
    
    scatter_fig = visualizer.plot_scatter(
        data=data,
        x='customers',
        y='sales',
        title='Relationship between Number of Customers and Sales',
        xlabel='Number of Customers',
        ylabel='Sales',
        color='region',
        size='profit'
    )
    
    correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']].corr()
    heatmap_fig = visualizer.plot_heatmap(
        data=correlation_matrix,
        title='Correlation Matrix'
    )
    
    # 创建仪表盘构建器
    dashboard = DashboardBuilder(title="销售数据分析仪表盘")
    
    # 定义仪表盘组件
    components = [
        {
            'type': 'control',
            'id': 'region-filter',
            'title': '区域筛选',
            'control_type': 'dropdown',
            'options': ['All'] + list(data['region'].unique()),
            'value': 'All',
            'width': 3
        },
        {
            'type': 'control',
            'id': 'month-filter',
            'title': '月份筛选',
            'control_type': 'checklist',
            'options': list(data['month'].unique()),
            'value': list(data['month'].unique()),
            'width': 9
        },
        {
            'type': 'graph',
            'id': 'monthly-sales-chart',
            'title': '月度销售额',
            'figure': bar_fig,
            'width': 6,
            'height': 400
        },
        {
            'type': 'graph',
            'id': 'region-sales-chart',
            'title': '区域销售额分布',
            'figure': pie_fig,
            'width': 6,
            'height': 400
        },
        {
            'type': 'graph',
            'id': 'customer-sales-chart',
            'title': '客户数量与销售额关系',
            'figure': scatter_fig,
            'width': 6,
            'height': 400
        },
        {
            'type': 'graph',
            'id': 'correlation-matrix',
            'title': '相关性矩阵',
            'figure': heatmap_fig,
            'width': 6,
            'height': 400
        },
        {
            'type': 'table',
            'id': 'sales-table',
            'title': '销售数据表',
            'data': data[['month', 'region', 'sales', 'profit', 'customers']].head(10),
            'width': 12,
            'page_size': 10
        }
    ]
    
    # 创建仪表盘布局
    dashboard.create_layout(components)
    
    # 添加回调函数 - 区域筛选
    @dashboard.add_callback(
        outputs=[('sales-table', 'data')],
        inputs=[('region-filter', 'value'), ('month-filter', 'value')]
    )
    def update_table(region, months):
        filtered_data = data.copy()
        
        # 筛选区域
        if region != 'All':
            filtered_data = filtered_data[filtered_data['region'] == region]
        
        # 筛选月份
        if months:
            filtered_data = filtered_data[filtered_data['month'].isin(months)]
        
        return [filtered_data[['month', 'region', 'sales', 'profit', 'customers']].head(10).to_dict('records')]
    
    # 添加回调函数 - 更新图表
    @dashboard.add_callback(
        outputs=[
            ('monthly-sales-chart', 'figure'),
            ('region-sales-chart', 'figure'),
            ('customer-sales-chart', 'figure')
        ],
        inputs=[('region-filter', 'value'), ('month-filter', 'value')]
    )
    def update_charts(region, months):
        filtered_data = data.copy()
        
        # 筛选区域
        if region != 'All':
            filtered_data = filtered_data[filtered_data['region'] == region]
        
        # 筛选月份
        if months:
            filtered_data = filtered_data[filtered_data['month'].isin(months)]
        
        # 更新月度销售额图表
        monthly_sales = filtered_data.groupby('month')['sales'].sum().reset_index()
        monthly_sales['month'] = pd.Categorical(monthly_sales['month'], 
                                              categories=['Jan', 'Feb', 'Mar', 'Apr'],
                                              ordered=True)
        monthly_sales = monthly_sales.sort_values('month')
        
        bar_fig = visualizer.plot_bar_chart(
            data=monthly_sales,
            x='month',
            y='sales',
            title='Monthly Sales',
            xlabel='Month',
            ylabel='Total Sales'
        )
        
        # 更新区域销售额分布图表
        region_sales = filtered_data.groupby('region')['sales'].sum().reset_index()
        pie_fig = visualizer.plot_pie_chart(
            data=region_sales,
            values='sales',
            names='region',
            title='Sales Distribution by Region'
        )
        
        # 更新客户数量与销售额关系图表
        scatter_fig = visualizer.plot_scatter(
            data=filtered_data,
            x='customers',
            y='sales',
            title='Relationship between Number of Customers and Sales',
            xlabel='Number of Customers',
            ylabel='Sales',
            color='region',
            size='profit'
        )
        
        return [bar_fig, pie_fig, scatter_fig]
    
    # 运行仪表盘
    print("启动交互式仪表盘,请访问 http://127.0.0.1:8050/")
    dashboard.run_server(debug=True)

if __name__ == "__main__":
    interactive_dashboard_example()
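# --- Supplementary sketch (assumption-labelled) ---
# DashboardBuilder above wraps the Dash framework; the following is a minimal plain-Dash
# sketch of the same Output/Input callback wiring, so the filter logic is easier to follow.
# It is an illustration only: the component ids ('region-dd', 'sales-graph') and the layout
# are hypothetical and not part of DashboardBuilder. Assumes dash and plotly are installed.
def plain_dash_sketch(data):
    """Minimal standalone Dash app: one dropdown filter driving one scatter plot."""
    from dash import Dash, dcc, html, Input, Output
    import plotly.express as px

    app = Dash(__name__)
    app.layout = html.Div([
        dcc.Dropdown(
            id='region-dd',
            options=[{'label': r, 'value': r} for r in ['All'] + sorted(data['region'].unique())],
            value='All'
        ),
        dcc.Graph(id='sales-graph')
    ])

    @app.callback(Output('sales-graph', 'figure'), Input('region-dd', 'value'))
    def update_figure(region):
        # Filter the DataFrame by the selected region, then rebuild the figure
        d = data if region == 'All' else data[data['region'] == region]
        return px.scatter(d, x='customers', y='sales', color='region',
                          title='Sales vs Customers')

    # Dash >= 2.7 exposes app.run(); older releases use app.run_server(),
    # which appears to be what DashboardBuilder wraps above.
    app.run(debug=True)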

# 可视化模块整合
class VisualizationManager:
    """可视化管理器,整合静态和交互式可视化"""
    
    def __init__(self, output_dir='visualizations'):
        """初始化可视化管理器
        
        参数:
            output_dir: 输出目录
        """
        # 创建静态和交互式可视化器
        self.static_visualizer = StaticVisualizer(output_dir=os.path.join(output_dir, 'static'))
        self.interactive_visualizer = InteractiveVisualizer(output_dir=os.path.join(output_dir, 'interactive'))
        self.output_dir = output_dir
        self.logger = logging.getLogger(__name__)
        
        # 确保输出目录存在
        os.makedirs(output_dir, exist_ok=True)
    
    def create_visualization(self, data, chart_type, static=True, interactive=True, **kwargs):
        """创建可视化图表
        
        参数:
            data: 输入数据
            chart_type: chart type; must match a plot_<chart_type> method on the visualizers (e.g. 'bar_chart', 'pie_chart', 'scatter', 'heatmap')
            static: 是否创建静态图表
            interactive: 是否创建交互式图表
            **kwargs: 其他参数
            
        返回:
            字典,包含静态和交互式图表对象
        """
        try:
            result = {}
            
            # 根据图表类型选择相应的方法
            method_name = f"plot_{chart_type}"
            
            # 创建静态图表
            if static and hasattr(self.static_visualizer, method_name):
                static_method = getattr(self.static_visualizer, method_name)
                static_fig = static_method(data=data, **kwargs)
                result['static'] = static_fig
            
            # 创建交互式图表
            if interactive and hasattr(self.interactive_visualizer, method_name):
                interactive_method = getattr(self.interactive_visualizer, method_name)
                interactive_fig = interactive_method(data=data, **kwargs)
                result['interactive'] = interactive_fig
            
            return result
            
        except Exception as e:
            self.logger.error(f"创建可视化图表时出错: {e}")
            return {}
    
    def create_dashboard(self, data, config, title="数据可视化仪表盘"):
        """创建交互式仪表盘
        
        参数:
            data: 输入数据
            config: 仪表盘配置,包含组件列表
            title: 仪表盘标题
            
        返回:
            DashboardBuilder对象
        """
        try:
            # 创建仪表盘构建器
            dashboard = DashboardBuilder(title=title)
            
            # 创建组件
            components = []
            
            for component_config in config:
                component_type = component_config.get('type')
                
                if component_type == 'graph':
                    # 创建图表组件
                    chart_type = component_config.get('chart_type')
                    chart_params = component_config.get('params', {})
                    
                    # 创建图表
                    chart_result = self.create_visualization(
                        data=component_config.get('data', data),
                        chart_type=chart_type,
                        static=False,
                        interactive=True,
                        **chart_params
                    )
                    
                    # 添加到组件列表
                    if 'interactive' in chart_result:
                        components.append({
                            'type': 'graph',
                            'id': component_config.get('id', f"graph-{len(components)}"),
                            'title': component_config.get('title', f"{chart_type.capitalize()} Chart"),
                            'figure': chart_result['interactive'],
                            'width': component_config.get('width', 6),
                            'height': component_config.get('height', 400)
                        })
                
                elif component_type in ['control', 'table', 'text', 'html']:
                    # 直接添加其他类型的组件
                    components.append(component_config)
            
            # 创建仪表盘布局
            dashboard.create_layout(components)
            
            return dashboard
            
        except Exception as e:
            self.logger.error(f"创建仪表盘时出错: {e}")
            return None
    
    def export_visualizations(self, visualizations, format='html'):
        """导出可视化图表
        
        参数:
            visualizations: 可视化图表字典
            format: 导出格式 ('html', 'png', 'pdf', 等)
            
        返回:
            导出文件路径列表
        """
        try:
            export_paths = []
            
            for name, viz_dict in visualizations.items():
                # 导出静态图表
                if 'static' in viz_dict and viz_dict['static'] is not None:
                    static_path = os.path.join(self.output_dir, 'exports', 'static', f"{name}.{format}")
                    os.makedirs(os.path.dirname(static_path), exist_ok=True)
                    
                    if format == 'html':
                        # 对于Matplotlib图表,需要先保存为图像
                        temp_path = os.path.join(self.output_dir, 'exports', 'static', f"{name}.png")
                        viz_dict['static'].savefig(temp_path)
                        
                        # 创建HTML包装
                        with open(static_path, 'w', encoding='utf-8') as f:
                            f.write(f"""
                            <html>
                            <head><title>{name} - Static Visualization</title></head>
                            <body>
                                <h1>{name}</h1>
                                <img src="{os.path.basename(temp_path)}" alt="{name}">
                            </body>
                            </html>
                            """)
                    else:
                        viz_dict['static'].savefig(static_path)
                    
                    export_paths.append(static_path)
                
                # Export interactive charts (Plotly figures are always written as HTML
                # via write_html, regardless of the `format` argument)
                if 'interactive' in viz_dict and viz_dict['interactive'] is not None:
                    interactive_path = os.path.join(self.output_dir, 'exports', 'interactive', f"{name}.html")
                    os.makedirs(os.path.dirname(interactive_path), exist_ok=True)
                    
                    # 保存Plotly图表
                    viz_dict['interactive'].write_html(interactive_path)
                    export_paths.append(interactive_path)
            
            return export_paths
            
        except Exception as e:
            self.logger.error(f"导出可视化图表时出错: {e}")
            return []
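# Note on the name-based dispatch used by create_visualization(): chart_type is mapped to a
# plot_<chart_type> method, e.g. 'bar_chart' -> plot_bar_chart, 'scatter' -> plot_scatter.
# A small helper (illustrative, not part of the original class) to list which chart types a
# given visualizer supports under that convention:
def supported_chart_types(visualizer):
    """Return the chart_type strings for which visualizer has a plot_<chart_type> method."""
    return sorted(
        name[len('plot_'):]
        for name in dir(visualizer)
        if name.startswith('plot_') and callable(getattr(visualizer, name))
    )

# Example (output depends on the visualizer classes defined earlier):
# supported_chart_types(InteractiveVisualizer())  ->  ['bar_chart', 'heatmap', 'pie_chart', 'scatter', ...]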

# 使用示例
def visualization_manager_example():
    """可视化管理器示例"""
    # 创建示例数据
    np.random.seed(42)
    n_samples = 200
    
    # 生成特征
    X = np.random.randn(n_samples, 3)  # 3个特征
    
    # 生成目标变量(回归)
    y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
    
    # 创建DataFrame
    data = pd.DataFrame(
        X, 
        columns=['feature_1', 'feature_2', 'feature_3']
    )
    data['target'] = y
    
    # 添加一些派生列
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
    data['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
    data['sales'] = data['target'] * 100 + 500
    data['profit'] = data['sales'] * np.random.uniform(0.1, 0.3, n_samples)
    data['customers'] = np.random.poisson(50, n_samples)
    data['sales_per_customer'] = data['sales'] / data['customers']
    
    # 创建可视化管理器
    viz_manager = VisualizationManager(output_dir='visualizations')
    
    # 创建各种图表
    visualizations = {}
    
    # 1. 条形图
    monthly_sales = data.groupby('month')['sales'].sum().reset_index()
    monthly_sales['month'] = pd.Categorical(monthly_sales['month'], 
                                           categories=['Jan', 'Feb', 'Mar', 'Apr'],
                                           ordered=True)
    monthly_sales = monthly_sales.sort_values('month')
    
    bar_charts = viz_manager.create_visualization(
        data=monthly_sales,
        chart_type='bar_chart',
        x='month',
        y='sales',
        title='Monthly Sales',
        xlabel='Month',
        ylabel='Total Sales',
        save_as='monthly_sales'
    )
    visualizations['monthly_sales'] = bar_charts
    
    # 2. 散点图
    scatter_charts = viz_manager.create_visualization(
        data=data,
        chart_type='scatter',
        x='customers',
        y='sales',
        title='Relationship between Number of Customers and Sales',
        xlabel='Number of Customers',
        ylabel='Sales',
        color='region',
        size='profit',
        save_as='customers_sales'
    )
    visualizations['customers_sales'] = scatter_charts
    
    # 3. 热力图
    correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']].corr()
    heatmap_charts = viz_manager.create_visualization(
        data=correlation_matrix,
        chart_type='heatmap',
        title='Correlation Matrix',
        save_as='correlation_matrix'
    )
    visualizations['correlation_matrix'] = heatmap_charts
    
    # 导出可视化图表
    export_paths = viz_manager.export_visualizations(visualizations)
    print(f"导出的可视化图表: {export_paths}")
    
    # 创建仪表盘
    dashboard_config = [
        {
            'type': 'control',
            'id': 'region-filter',
            'title': '区域筛选',
            'control_type': 'dropdown',
            'options': ['All'] + list(data['region'].unique()),
            'value': 'All',
            'width': 3
        },
        {
            'type': 'graph',
            'id': 'monthly-sales-chart',
            'title': '月度销售额',
            'chart_type': 'bar_chart',
            'data': monthly_sales,
            'params': {
                'x': 'month',
                'y': 'sales',
                'title': 'Monthly Sales',
                'xlabel': 'Month',
                'ylabel': 'Total Sales'
            },
            'width': 6
        },
        {
            'type': 'graph',
            'id': 'customer-sales-chart',
            'title': '客户数量与销售额关系',
            'chart_type': 'scatter',
            'params': {
                'x': 'customers',
                'y': 'sales',
                'title': 'Relationship between Number of Customers and Sales',
                'xlabel': 'Number of Customers',
                'ylabel': 'Sales',
                'color': 'region',
                'size': 'profit'
            },
            'width': 6
        },
        {
            'type': 'graph',
            'id': 'correlation-matrix',
            'title': '相关性矩阵',
            'chart_type': 'heatmap',
            'data': correlation_matrix,
            'params': {
                'title': 'Correlation Matrix'
            },
            'width': 12
        }
    ]
    
    dashboard = viz_manager.create_dashboard(data, dashboard_config, title="销售数据分析仪表盘")
    
    if dashboard:
        # Note: the 'region-filter' control above is only laid out; to make it actually filter
        # the charts you still need to register callbacks with dashboard.add_callback(),
        # as shown in interactive_dashboard_example().
        print("创建仪表盘成功,运行 dashboard.run_server() 启动仪表盘")
    
    return {
        'data': data,
        'visualizations': visualizations,
        'viz_manager': viz_manager,
        'dashboard': dashboard
    }

if __name__ == "__main__":
    visualization_manager_example()
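# Hedged usage note: to actually serve the dashboard assembled by visualization_manager_example(),
# reuse its return value; DashboardBuilder.run_server() starts the underlying Dash server
# (by default at http://127.0.0.1:8050/, as in the earlier example):
#
#     result = visualization_manager_example()
#     if result['dashboard'] is not None:
#         result['dashboard'].run_server(debug=True)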