【软件开发】可复用的数据库导入工具类

发布于:2025-03-31 ⋅ 阅读:(23) ⋅ 点赞:(0)

在这里插入图片描述

一、使用说明


1.准备模型类(导入器的构造函数 `__init__` 接收数据库URL和 {名称: 模型类} 字典):

    def __init__(self, db_url: str, model_classes: Dict[str, Type[Base]]):
        """
        Initialize the importer.

        :param db_url: database connection URL (e.g. 'sqlite:///mydatabase.db')
        :param model_classes: mapping of lookup name -> model class, e.g.
            {'talent': Talent}; the keys are the names later passed as
            ``model_name`` to import_from_excel and need not equal the
            model's ``__tablename__``
        """
        self.db_url = db_url
        self.model_classes = model_classes
        # Engine built from db_url by the class's _create_engine helper.
        self.engine = self._create_engine()
        # Session factory bound to the engine; each import opens its own session.
        self.Session = sessionmaker(bind=self.engine)

建议将所有数据库模型类定义在单独的文件中;
确保每个模型类都继承自本模块中的 `Base`(`initialize_database()` 会为 `Base.metadata` 中注册的所有表建表)。
2.创建导入器实例:

# Create the importer. The keys of model_classes are the names you pass as
# model_name when importing; the values are the model classes themselves.
importer = DatabaseImporter(
    db_url='sqlite:///mydatabase.db',
    model_classes={
        'talent': Talent,
        'enterprise': Enterprise
    }
)

3.初始化数据库:

importer.initialize_database()  # create the tables registered on Base.metadata

4.导入Excel数据:

# Read the Excel file (pandas reads its first sheet by default) and insert its
# rows into the model registered as 'talent'. column_mapping maps an Excel
# header to a model attribute; only mapped columns are imported.
importer.import_from_excel(
    excel_path='path/to/file.xlsx',
    model_name='talent',
    column_mapping={
        'Excel列名': '模型属性名',
        # ...
    }
)

二、源代码


import os
import logging
from typing import Type, Dict, Optional
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import QueuePool
from sqlalchemy.exc import SQLAlchemyError

# Configure logging.
# NOTE(review): basicConfig runs at import time and configures the root logger
# of any application importing this module; consider moving it under __main__.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Declarative base for the models; initialize_database() creates the tables
# registered on Base.metadata.
Base = declarative_base()

class DatabaseImporter:
    """
    Reusable helper that imports Excel data into SQLAlchemy models.

    Features:
    - creates the database engine and a session factory
    - forward-fills blank cells (e.g. those left by merged cells), optionally
      restricted to selected columns
    - converts missing values to None and numpy scalars to plain Python objects
    - inserts rows in batches inside a single transaction
    """

    def __init__(
        self,
        db_url: str,
        model_classes: Dict[str, Type["Base"]],
        engine_options: Optional[Dict[str, object]] = None,
    ):
        """
        Initialize the importer.

        :param db_url: database connection URL (e.g. 'sqlite:///mydatabase.db')
        :param model_classes: mapping of lookup name -> model class; the keys are
            the values accepted as ``model_name`` by ``import_from_excel`` and
            need not equal the model's ``__tablename__``
        :param engine_options: extra keyword arguments for ``create_engine``;
            they override the default connection-pool settings
        """
        self.db_url = db_url
        self.model_classes = model_classes
        # Copied so later changes to the caller's dict cannot affect this instance.
        self.engine_options = dict(engine_options or {})
        self.engine = self._create_engine()
        self.Session = sessionmaker(bind=self.engine)

    def _create_engine(self):
        """Create the SQLAlchemy engine from ``db_url`` and ``engine_options``."""
        overrides = getattr(self, 'engine_options', None) or {}
        options = {'pool_recycle': 1800}
        pool_class = overrides.get('poolclass', QueuePool)
        if isinstance(pool_class, type) and issubclass(pool_class, QueuePool):
            # Default sizing, used unless the caller picked a non-queue pool.
            options.update(
                poolclass=QueuePool,
                pool_size=5,
                max_overflow=10,
                pool_timeout=30,
            )
        # Otherwise pool_size / max_overflow / pool_timeout are left out: they are
        # QueuePool-specific and other pool classes (e.g. StaticPool) reject them.
        options.update(overrides)
        return create_engine(self.db_url, **options)

    def initialize_database(self):
        """
        Create the tables of ``Base.metadata`` and of every registered model.

        Models declared on their own declarative base (not this module's
        ``Base``) are created too, via their class's ``metadata``.

        :raises Exception: whatever table creation raised, after logging it
        """
        metadatas = [Base.metadata]
        for model_class in self.model_classes.values():
            metadata = getattr(model_class, 'metadata', None)
            # Identity check so each distinct MetaData is created only once.
            if metadata is not None and all(metadata is not m for m in metadatas):
                metadatas.append(metadata)
        try:
            for metadata in metadatas:
                metadata.create_all(self.engine)
            logger.info("数据库表创建成功!")
        except Exception as e:
            logger.error("创建数据库表失败: %s", e)
            raise

    @staticmethod
    def _clean_dataframe(
        df: pd.DataFrame, ffill_columns: Optional[list] = None
    ) -> pd.DataFrame:
        """
        Clean a freshly read sheet.

        1. Forward-fill blank cells so rows below a vertically merged cell
           inherit its value (when reading .xlsx, only the top-left cell of a
           merged range carries the value; the others come back empty). With
           ``ffill_columns=None`` every column is filled, which also fills
           genuinely empty cells from the row above; pass a list of column
           names to restrict filling to the merged columns.
        2. Convert to object dtype and replace NaN/NaT with None, so the values
           handed to SQLAlchemy are None / plain Python objects instead of
           float NaN or numpy scalars.

        :param df: DataFrame as returned by ``pd.read_excel``
        :param ffill_columns: columns to forward-fill; None means all columns
        :return: a new DataFrame (``df`` itself is not modified)
        """
        if ffill_columns is None:
            df = df.ffill()
        else:
            requested = list(ffill_columns)
            unknown = [col for col in requested if col not in df.columns]
            if unknown:
                logger.warning("ffill_columns 中以下列在Excel中不存在: %s", unknown)
            # dict.fromkeys de-duplicates while keeping the caller's order.
            present = list(dict.fromkeys(col for col in requested if col in df.columns))
            df = df.copy()
            if present:
                df[present] = df[present].ffill()
        # Cast to object first: writing None into a float column would be
        # coerced straight back to NaN, and numeric columns would keep numpy scalars.
        return df.astype(object).where(df.notna(), None)

    def import_from_excel(
        self,
        excel_path: str,
        model_name: str,
        column_mapping: Optional[Dict[str, str]] = None,
        batch_size: int = 100,
        ffill_columns: Optional[list] = None,
    ) -> int:
        """
        Import the rows of an Excel sheet into the table of a registered model.

        Objects are passed to ``bulk_save_objects`` in batches of ``batch_size``
        and committed once at the end, so a failure on any row rolls back the
        whole import.

        :param excel_path: path of the Excel file (its first sheet is read)
        :param model_name: key of the target model in ``model_classes``
        :param column_mapping: {Excel column name: model column name}; when None,
            every Excel column maps to the model column of the same name
        :param batch_size: number of objects handed to ``bulk_save_objects`` at once
        :param ffill_columns: columns to forward-fill (see ``_clean_dataframe``);
            None forward-fills every column
        :return: number of records imported
        :raises ValueError: if ``model_name`` is not registered
        """
        if model_name not in self.model_classes:
            raise ValueError(f"未定义的模型名称: {model_name}")

        model_class = self.model_classes[model_name]
        session = self.Session()
        records_imported = 0

        try:
            # 1. Read and clean the sheet.
            logger.info("开始读取Excel文件: %s", excel_path)
            df = pd.read_excel(excel_path)
            df = self._clean_dataframe(df, ffill_columns)

            # 2. Default mapping: identical column names.
            if column_mapping is None:
                column_mapping = {col: col for col in df.columns}

            # 3. Validate the mapping against the table.
            # NOTE: matched by column name, which equals the attribute name
            # unless the model declares Column('other_name', ...).
            table_columns = list(model_class.__table__.columns)
            # 'id' is never populated from the sheet, leaving it to the database.
            model_columns = {c.name for c in table_columns if c.name != 'id'}
            # Primary keys are not reported as missing, whatever their name.
            primary_keys = {c.name for c in table_columns if c.primary_key}
            mapped_targets = set(column_mapping.values())

            missing_columns = model_columns - primary_keys - mapped_targets
            if missing_columns:
                logger.warning("模型需要但Excel中缺少的列: %s", missing_columns)
            absent = [col for col in column_mapping if col not in df.columns]
            if absent:
                logger.warning("列名映射中的以下Excel列不存在, 将导入为空值: %s", absent)
            ignored = [attr for attr in column_mapping.values() if attr not in model_columns]
            if ignored:
                logger.warning("以下映射目标不是可导入的模型列, 将被忽略: %s", ignored)

            # Resolve the usable (Excel column, model column) pairs once.
            field_pairs = [
                (excel_col, model_col)
                for excel_col, model_col in column_mapping.items()
                if model_col in model_columns
            ]

            # 4. Build model objects and insert them in batches.
            objects = []
            for record in df.to_dict('records'):
                data = {model_col: record.get(excel_col) for excel_col, model_col in field_pairs}
                objects.append(model_class(**data))

                if len(objects) >= batch_size:
                    session.bulk_save_objects(objects)
                    records_imported += len(objects)
                    objects = []
                    logger.info("已导入 %d 条记录...", records_imported)

            # Remaining partial batch.
            if objects:
                session.bulk_save_objects(objects)
                records_imported += len(objects)

            session.commit()
            logger.info("成功导入 %d 条记录到 %s 表", records_imported, model_name)
            return records_imported

        except Exception as e:
            session.rollback()
            logger.error("导入数据失败: %s", e)
            raise
        finally:
            session.close()

# 示例用法
if __name__ == '__main__':
    # Column types used by the example models; the module header imports only
    # create_engine from sqlalchemy, so without this the models raise NameError.
    from sqlalchemy import Column, Integer, String

    # 1. Define your model classes (they may live in a separate file; they
    #    inherit from this module's Base).
    class Talent(Base):
        __tablename__ = 'talent_person'
        id = Column(Integer, primary_key=True, autoincrement=True, unique=True)
        name = Column(String(255), nullable=False, comment='姓名')
        workplace = Column(String(255), comment='工作单位')
        # Other fields...

    class Enterprise(Base):
        __tablename__ = 'enterprise'
        enterprise_id = Column(Integer, primary_key=True, autoincrement=True)
        supplier_name = Column(String(255), nullable=False)
        # Other fields...

    # 2. Configure the importer with a SQLite file next to this script.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    db_path = os.path.join(current_dir, 'person_enterprise.db')

    importer = DatabaseImporter(
        db_url=f'sqlite:///{db_path}',
        model_classes={
            'talent': Talent,
            'enterprise': Enterprise
            # Add more models...
        }
    )

    # 3. Create the tables.
    importer.initialize_database()

    # 4. Import the Excel data.
    try:
        # Import talent data.
        importer.import_from_excel(
            # TODO: replace with the real Excel file path (the file name is missing here).
            excel_path=r"C:\Users\lenovo\Desktop\.xlsx",
            model_name='talent',
            column_mapping={
                '姓名': 'name',
                '工作单位': 'workplace',
                # Other column mappings...
            }
        )

        # Further Excel files can be imported into other tables:
        # importer.import_from_excel(...)

    except Exception as e:
        logger.error("导入过程出错: %s", e)

三、注释

@staticmethod 是 Python 中的一个装饰器,用于定义静态方法。
静态方法属于类而非实例,不需要访问实例属性 (self) 或类属性 (cls),本质上是一个放在类命名空间内的普通函数。例如本文的 `_clean_dataframe` 只处理传入的 DataFrame,因此可以直接通过 `DatabaseImporter._clean_dataframe(df)` 调用。