一、使用说明
1. 准备模型类(导入器的构造函数接收一个模型类字典,签名如下):
def __init__(self, db_url: str, model_classes: Dict[str, Type[Base]]):
    """
    Initialize the importer.

    :param db_url: database connection URL (e.g. 'sqlite:///mydatabase.db')
    :param model_classes: dict {name: model class}; the keys are the names later
        passed as ``model_name`` to ``import_from_excel``
    """
    self.db_url = db_url
    self.model_classes = model_classes
    # Engine (with connection pool) is built from db_url; Session is the factory
    # used to open a session per import.
    self.engine = self._create_engine()
    self.Session = sessionmaker(bind=self.engine)
- 建议将所有数据库模型类定义在单独的文件中
- 确保每个模型类都继承自本模块中定义的 Base
2.创建导入器实例:
importer = DatabaseImporter(
db_url='sqlite:///mydatabase.db',
model_classes={
'talent': Talent,
'enterprise': Enterprise
}
)
3.初始化数据库:
importer.initialize_database()
4.导入Excel数据:
importer.import_from_excel(
excel_path='path/to/file.xlsx',
model_name='talent',
column_mapping={
'Excel列名': '模型属性名',
# ...
}
)
二、源代码
import logging
import os
from typing import Dict, List, Optional, Type

import pandas as pd
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.pool import QueuePool
# Declarative base for the ORM models; tables registered on Base.metadata are
# created by DatabaseImporter.initialize_database().
Base = declarative_base()

# Logging: root logger at INFO level, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class DatabaseImporter:
    """Reusable helper for importing Excel data into a database via SQLAlchemy.

    Features:
    - builds the database engine and session factory on construction
    - forward-fills blank cells (e.g. the empty cells of merged Excel ranges)
    - converts missing values to ``None`` so they are written as NULL
    - inserts rows in batches within a single transaction
    """

    def __init__(self, db_url: str, model_classes: Dict[str, Type[Base]]):
        """Create the importer.

        :param db_url: SQLAlchemy database URL (e.g. 'sqlite:///mydatabase.db')
        :param model_classes: {name: model class}; the keys are the names
            accepted by ``import_from_excel(model_name=...)``
        """
        self.db_url = db_url
        self.model_classes = model_classes
        self.engine = self._create_engine()
        self.Session = sessionmaker(bind=self.engine)

    def _create_engine(self):
        """Create the engine backed by a bounded ``QueuePool``."""
        return create_engine(
            self.db_url,
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
            pool_recycle=1800,
        )

    def initialize_database(self):
        """Create the tables of the module-level ``Base`` and of every registered model.

        Besides ``Base.metadata``, the metadata of each class in
        ``model_classes`` is included, so models declared on a different
        declarative base also get their tables created.

        :raises Exception: any error from table creation is logged and re-raised
        """
        try:
            metadatas = [Base.metadata]
            for model_class in self.model_classes.values():
                metadata = model_class.__table__.metadata
                # Deduplicate by identity so each MetaData is processed once.
                if not any(metadata is seen for seen in metadatas):
                    metadatas.append(metadata)
            for metadata in metadatas:
                metadata.create_all(self.engine)
            logger.info("数据库表创建成功!")
        except Exception as e:
            logger.error(f"创建数据库表失败: {str(e)}")
            raise

    @staticmethod
    def _clean_dataframe(
        df: pd.DataFrame,
        ffill_columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """Forward-fill blank cells and replace missing values with ``None``.

        Forward-filling is meant to restore merged cells, which come back
        blank except for their first cell. It cannot tell a merged cell
        from a genuinely empty one, so a truly empty cell inherits the value
        above it. Pass ``ffill_columns`` to restrict the fill to the columns
        that actually contain merged cells.

        :param df: frame as read from the Excel file (not modified)
        :param ffill_columns: columns to forward-fill; ``None`` fills all
            columns, an empty list disables forward-filling
        :return: object-dtype frame in which every missing entry is ``None``
        """
        if ffill_columns is None:
            df = df.ffill()
        else:
            df = df.copy()
            columns = list(ffill_columns)
            if columns:
                df[columns] = df[columns].ffill()
        # Cast to object before replacing: on numeric/datetime columns pandas
        # may coerce ``None`` back to NaN/NaT. The cast also turns numpy
        # scalars into plain Python values that DB drivers can bind.
        return df.astype(object).where(df.notna(), None)

    def import_from_excel(
        self,
        excel_path: str,
        model_name: str,
        column_mapping: Optional[Dict[str, str]] = None,
        batch_size: int = 100,
        ffill_columns: Optional[List[str]] = None,
    ) -> int:
        """Import the rows of an Excel file into the table of a registered model.

        All rows go into one transaction. Objects are sent to the database in
        batches of ``batch_size`` and committed once at the end. On any error
        the session is rolled back and the exception is re-raised.

        :param excel_path: path of the Excel file (its first sheet is read)
        :param model_name: key of the model class in ``model_classes``
        :param column_mapping: {Excel column name: model attribute name};
            ``None`` maps every Excel column to the attribute of the same name
        :param batch_size: number of objects per ``bulk_save_objects`` call
        :param ffill_columns: columns to forward-fill before importing;
            ``None`` (default) forward-fills every column
        :return: number of records imported
        :raises ValueError: if ``model_name`` is not in ``model_classes``
        """
        if model_name not in self.model_classes:
            raise ValueError(f"未定义的模型名称: {model_name}")
        model_class = self.model_classes[model_name]
        session = self.Session()
        records_imported = 0
        try:
            # 1. Read and clean the Excel file.
            logger.info(f"开始读取Excel文件: {excel_path}")
            df = self._clean_dataframe(pd.read_excel(excel_path), ffill_columns)

            # 2. Default mapping: Excel column name == model attribute name.
            if column_mapping is None:
                column_mapping = {col: col for col in df.columns}

            # 3. Validate the mapping against the model and the sheet.
            table_columns = model_class.__table__.columns
            # A column named 'id' is never imported (assumed auto-generated).
            model_columns = {column.name for column in table_columns
                             if column.name != 'id'}
            primary_keys = {column.name for column in table_columns
                            if column.primary_key}
            missing_columns = (model_columns - primary_keys
                               - set(column_mapping.values()))
            if missing_columns:
                logger.warning(f"模型需要但Excel中缺少的列: {missing_columns}")
            absent_excel_columns = set(column_mapping) - set(df.columns)
            if absent_excel_columns:
                logger.warning(
                    f"映射中的Excel列在文件中不存在(将导入为空值): {absent_excel_columns}"
                )
            ignored_targets = set(column_mapping.values()) - model_columns
            if ignored_targets:
                logger.warning(f"映射目标不是可导入的模型列(将被忽略): {ignored_targets}")

            # Only mapping entries that target an importable model column are used.
            active_mapping = [
                (excel_col, model_col)
                for excel_col, model_col in column_mapping.items()
                if model_col in model_columns
            ]

            # 4. Build objects and send them in batches.
            pending = []
            for record in df.to_dict(orient="records"):
                data = {model_col: record.get(excel_col)
                        for excel_col, model_col in active_mapping}
                pending.append(model_class(**data))
                if len(pending) >= batch_size:
                    session.bulk_save_objects(pending)
                    records_imported += len(pending)
                    pending = []
                    logger.info(f"已导入 {records_imported} 条记录...")

            # Send the remaining objects, then commit everything at once.
            if pending:
                session.bulk_save_objects(pending)
                records_imported += len(pending)
            session.commit()
            logger.info(f"成功导入 {records_imported} 条记录到 {model_name} 表")
            return records_imported
        except Exception as e:
            session.rollback()
            logger.error(f"导入数据失败: {str(e)}")
            raise
        finally:
            session.close()
# Example usage
if __name__ == '__main__':
    import sys

    # 1. Define the model classes (they may also live in a separate module).
    class Talent(Base):
        __tablename__ = 'talent_person'
        id = Column(Integer, primary_key=True, autoincrement=True, unique=True)
        name = Column(String(255), nullable=False, comment='姓名')
        workplace = Column(String(255), comment='工作单位')
        # further columns...

    class Enterprise(Base):
        __tablename__ = 'enterprise'
        enterprise_id = Column(Integer, primary_key=True, autoincrement=True)
        supplier_name = Column(String(255), nullable=False)
        # further columns...

    # 2. Configure the importer; the SQLite file sits next to this script.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    db_path = os.path.join(current_dir, 'person_enterprise.db')
    importer = DatabaseImporter(
        db_url=f'sqlite:///{db_path}',
        model_classes={
            'talent': Talent,
            'enterprise': Enterprise,
            # add more models...
        },
    )

    # 3. Create the tables.
    importer.initialize_database()

    # 4. Import Excel data. The workbook path may be given as the first
    # command-line argument; the fallback below is only a placeholder.
    talent_excel = sys.argv[1] if len(sys.argv) > 1 else r"C:\Users\lenovo\Desktop\.xlsx"
    try:
        # Import talent data.
        importer.import_from_excel(
            excel_path=talent_excel,
            model_name='talent',
            column_mapping={
                '姓名': 'name',
                '工作单位': 'workplace',
                # further column mappings...
            },
        )
        # Further workbooks can be imported into other tables:
        # importer.import_from_excel(...)
    except Exception as e:
        logger.error(f"导入过程出错: {str(e)}")
        # Report the failure to the caller/shell instead of exiting with 0.
        sys.exit(1)
三、注释
@staticmethod 是 Python 的内置装饰器,用于定义静态方法。
静态方法属于类而非实例:它不会自动接收实例 (self) 或类 (cls) 作为第一个参数,本质上是一个放在类命名空间内的普通函数,可以直接通过类名调用,例如 DatabaseImporter._clean_dataframe(df)。