《昇思 (MindSpore) 25-Day Learning Camp, Day 02 | Quick Start》

Published: 2024-06-28

Quick Start

  1. Prepare the MindSpore environment
    • Install the environment
       !pip uninstall mindspore -y
       !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
    
    • Import the modules
    import mindspore
    from mindspore import nn #neural network module
    from mindspore.dataset import vision, transforms #image transforms (vision) & data augmentation/preprocessing (transforms)
    from mindspore.dataset import MnistDataset #MNIST dataset module
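
    • A quick check that the pinned version is the one actually loaded (a minimal sketch added here, not part of the original notes; it assumes the notebook kernel was restarted after the reinstall):
    print(mindspore.__version__) #expected: 2.2.14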
    
  2. Dataset processing
    #MNIST is an open dataset; download it and unzip the archive
    from download import download
    url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
          "notebook/datasets/MNIST_Data.zip"
    path = download(url, "./", kind="zip", replace=True)
    The MNIST dataset directory structure is as follows:
    	MNIST_Data
    	├── train
    	│   ├── train-images-idx3-ubyte (60,000 training images)
    	│   └── train-labels-idx1-ubyte (60,000 training labels)
    	└── test
    	    ├── t10k-images-idx3-ubyte (10,000 test images)
    	    └── t10k-labels-idx1-ubyte (10,000 test labels)
    	Once the download completes, create the dataset objects.
    
    #load the training & test datasets
    train_dataset = MnistDataset('MNIST_Data/train')
    test_dataset = MnistDataset('MNIST_Data/test')
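
    A quick way to confirm the column names that the pipeline below maps over (a small check mirroring the official quick start, not in the original notes):
    print(train_dataset.get_col_names()) #['image', 'label']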
    
    #dataset preprocessing
    def datapipe(dataset, batch_size):
        #define the image transform operations
        image_transforms = [
            #rescale pixel values to the range [0, 1]
            vision.Rescale(1.0 / 255.0, 0),
            #standardize with the MNIST mean/std (roughly zero mean, unit variance)
            vision.Normalize(mean=(0.1307,), std=(0.3081,)),
            #HWC (Height, Width, Channel) -> CHW (Channel, Height, Width)
            vision.HWC2CHW()
        ]
        #cast labels to int32, the dtype the loss computation expects
        label_transform = transforms.TypeCast(mindspore.int32)

        #apply the transforms to the 'image' and 'label' columns
        dataset = dataset.map(image_transforms, 'image')
        dataset = dataset.map(label_transform, 'label')

        #batch the samples for training
        dataset = dataset.batch(batch_size)
        return dataset
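
    The pipeline above is only defined; it still has to be applied to both datasets before training. Following the official quick start, a batch size of 64 is assumed here, and one batch is inspected to confirm the shapes:
    train_dataset = datapipe(train_dataset, 64)
    test_dataset = datapipe(test_dataset, 64)
    for image, label in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
        print(f"Shape of label: {label.shape} {label.dtype}")
        break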
    
  3. Building the network model
    #define the model
    class Network(nn.Cell):
       def __init__(self):
           super().__init__()
           self.flatten = nn.Flatten()
           self.dense_relu_sequential = nn.SequentialCell(
               nn.Dense(28*28, 512),
               nn.ReLU(),
               nn.Dense(512, 512),
               nn.ReLU(),
               nn.Dense(512, 10)
           )
    
       def construct(self, x):
           x = self.flatten(x)
           logits = self.dense_relu_sequential(x)
           return logits
    
    model = Network()
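
    Printing the instantiated model is a quick way to confirm the layer stack (an optional check, mirroring the official quick start):
    print(model)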
    
  4. Model training
    #define the loss function & optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = nn.SGD(model.trainable_params(), 1e-2)
    
    #forward pass: compute the logits and the loss
    def forward_fn(data, label):
       logits = model(data)
       loss = loss_fn(logits, label)
       return loss, logits
    
    #backpropagation: value_and_grad builds a function that returns (loss, logits) and the gradients w.r.t. the trainable parameters
    grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
    
    #single training step: compute the gradients, then let the optimizer update the parameters
    def train_step(data, label):
       (loss, _), grads = grad_fn(data, label)
       optimizer(grads)
       return loss
       
    #training loop over one epoch
    def train(model, dataset):
       size = dataset.get_dataset_size()
       model.set_train()
       for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
           loss = train_step(data, label)
    
           if batch % 100 == 0:
               loss, current = loss.asnumpy(), batch
               print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
    
  5. Test function (model evaluation)
    def test(model, dataset, loss_fn):
        num_batches = dataset.get_dataset_size()
        model.set_train(False)
        total, test_loss, correct = 0, 0, 0
        for data, label in dataset.create_tuple_iterator():
            pred = model(data)
            total += len(data)
            test_loss += loss_fn(pred, label).asnumpy()
            correct += (pred.argmax(1) == label).asnumpy().sum()
        test_loss /= num_batches
        correct /= total
        print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
  6. Saving, loading & inference
    #save the trained model's parameters
    mindspore.save_checkpoint(model, "model.ckpt")
    
    #instantiate a fresh model
    model = Network()
    #load the checkpoint and copy its parameters into the network
    param_dict = mindspore.load_checkpoint("model.ckpt")
    param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
    #param_not_load lists parameters that were not loaded; an empty list means loading succeeded
    #inference
    model.set_train(False)
    for data, label in test_dataset:
        pred = model(data)
        predicted = pred.argmax(1)
        print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
        break
    
