深度学习领域的经典案例:ResNet50图像分类研究(二)

发布于:2025-08-07 ⋅ 阅读:(32) ⋅ 点赞:(0)

目录

研究内容

步骤一、实验准备

步骤二、数据准备

步骤三、导入Python库&模块并配置运行信息

步骤四、 定义参数变量


研究内容

步骤一、实验准备

同时希望你拥有Python编码基础和概率、矩阵等基础数学知识。

推荐环境:

版本:MindSpore 2.2及以上

编程语言:Python 3.7及以上

步骤二、数据准备

示例中用到的图像花卉数据集,该数据集是开源数据集,总共包括5种花的类型:分别是daisy(雏菊,633张),dandelion(蒲公英,898张),roses(玫瑰,641张),sunflowers(向日葵,699张),tulips(郁金香,799张),保存在5个文件夹当中,总共3670张,大小大概在230M左右。为了在模型部署上线之后进行测试,数据集在这里分成了flower_photos_train和flower_photos_test两部分。

目录结构如下(数据集中flowers文件夹下):

flowers

├── flower_photos_train

├── daisy

├── dandelion

├── roses

├── sunflowers

├── tulips

├── LICENSE.txt

├── flower_photos_test

├── daisy

├── dandelion

├── roses

├── sunflowers

├── tulips

├── LICENSE.txt

步骤三、导入Python库&模块并配置运行信息

可以通过context.set_context来配置运行需要的信息,譬如运行模式、后端信息、硬件等信息。导入context模块,配置运行需要的信息。

from easydict import EasyDict as edict
# 字典访问,用来存储超参数
import os
# os模块主要用于处理文件和目录
import numpy as np
# 科学计算库
import matplotlib.pyplot as plt
# 绘图库

import mindspore
# MindSpore库
import mindspore.dataset as ds
# 数据集处理模块
from mindspore.dataset.vision import c_transforms as vision
# 图像增强模块
from mindspore import context
# 环境设置模块
import mindspore.nn as nn
# 神经网络模块
from mindspore.train import Model
# 模型编译
from mindspore.nn.optim.momentum import Momentum
# 动量优化器
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
# 模型保存设置
from mindspore import Tensor
# 张量
from mindspore.train.serialization import export
# 模型导出
from mindspore.train.loss_scale_manager import FixedLossScaleManager
# 损失值平滑处理
from mindspore.train.serialization import load_checkpoint, load_param_into_net
# 模型加载
import mindspore.ops as ops
# 常见算子操作

# 设置MindSpore的执行模式和设备
context.set_context(mode=context.GRAPH_MODE, device_target="CPU"

步骤四、 定义参数变量

edict中存放的是模型训练和测试中所需要的各种参数配置。

cfg = edict({
    'data_path': 'flowers/flower_photos_train',   #训练数据集路径
    'test_path':'flowers/flower_photos_test',     #测试数据集路径
    'data_size': 3616,
    'HEIGHT': 224,  # 图片高度
    'WIDTH': 224,  # 图片宽度
    '_R_MEAN': 123.68, # CIFAR10的均值
    '_G_MEAN': 116.78,
    '_B_MEAN': 103.94,
    '_R_STD': 1, # 自定义的标准差
    '_G_STD': 1,
    '_B_STD':1,
    '_RESIZE_SIDE_MIN': 256, # 图像增强resize最小值
    '_RESIZE_SIDE_MAX': 512,
    
    'batch_size': 32, # 批次大小
    'num_class': 5,     # 分类类别
    'epoch_size': 5,  # 训练次数
    'loss_scale_num':1024,
    
    'prefix': 'resnet-ai',  # 模型保存的名称
    'directory': './model_resnet',  # 模型保存的路径
    'save_checkpoint_steps': 10, # 每隔10步保存ckpt
})


网站公告

今日签到

点亮在社区的每一天
去签到