Training and Testing a UNet with MindSpore


1. MindSpore Installation and Configuration:

Refer to "MindSpore安装 | 昇思MindSpore社区" (the official installation guide), which is quite detailed. Note that on Windows only the CPU backend is available; on Linux the GPU backend is supported, but only with specific CUDA versions.
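After installation, a quick sanity check (a minimal sketch; set device_target to match your platform, e.g. "CPU" on Windows or "GPU" on a supported Linux/CUDA setup):

import mindspore

mindspore.set_context(device_target="CPU")  # or "GPU" on Linux with a supported CUDA version
mindspore.run_check()                       # prints the installed version and verifies a small computation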

2. UNet Network Definition:

  • The overall structure is similar to PyTorch; nn.Cell plays the role of nn.Module in PyTorch.
  • Note that for MindSpore convolutions, explicit padding requires pad_mode='pad' (see the short check below).
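A quick way to check this behaviour (a minimal sketch; shapes are illustrative):

import mindspore
import mindspore.nn as nn
from mindspore import ops

x = ops.ones((1, 3, 32, 32), mindspore.float32)
# With pad_mode='pad' the explicit padding is applied, so a 3x3 conv keeps the spatial size,
# matching the PyTorch nn.Conv2d(..., padding=1) idiom; with the default pad_mode the
# padding argument must stay 0.
conv = nn.Conv2d(3, 16, kernel_size=3, pad_mode='pad', padding=1)
print(conv(x).shape)  # (1, 16, 32, 32)

The full UNet definition:
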
import mindspore
import mindspore.nn as nn
from mindspore import ops

# Upsampling block: a 2x2 transposed convolution (stride 2) followed by BatchNorm and ReLU
class TransBlock(nn.Cell):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.TransConv = nn.Conv2dTranspose(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
    def construct(self, x):
        out = self.TransConv(x)
        out = self.bn(out)
        out = self.relu(out)
        return out

# Double 3x3 conv + BN + ReLU, the basic encoder/decoder block of the UNet
class VGGBlock(nn.Cell):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, pad_mode='pad', padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, pad_mode='pad', padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def construct(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out

# UNet with transposed-convolution upsampling; skip connections are concatenated along the channel axis
class TransConv_UNet(nn.Cell):
    def __init__(self, input_channels=3, num_classes=1, **kwargs):
        super().__init__()
        self.n_channels = input_channels
        self.n_classes = num_classes
        nb_filter = [64, 128, 256, 512, 1024]

        self.pool = nn.MaxPool2d(2, 2)

        self.up_4 = TransBlock(nb_filter[4], nb_filter[4])
        self.up_3 = TransBlock(nb_filter[3], nb_filter[3])
        self.up_2 = TransBlock(nb_filter[2], nb_filter[2])
        self.up_1 = TransBlock(nb_filter[1], nb_filter[1])

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def construct(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x4_up = self.up_4(x4_0)
        x3_1 = self.conv3_1(ops.cat([x3_0, x4_up], 1))
        x2_2 = self.conv2_2(ops.cat([x2_0, self.up_3(x3_1)], 1))
        x1_3 = self.conv1_3(ops.cat([x1_0, self.up_2(x2_2)], 1))
        x0_4 = self.conv0_4(ops.cat([x0_0, self.up_1(x1_3)], 1))

        output = self.final(x0_4)
        return output

if __name__ == '__main__':
    model = TransConv_UNet(input_channels=3, num_classes=6)
    X = ops.ones((1, 3, 384, 384), mindspore.float32)
    logits = model(X)
    print(logits.shape)
    for name, param in model.parameters_and_names():
        print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")

3. Model Training:

Data augmentation is handled with the third-party library albumentations; the imports the training code assumes are sketched below.
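The function below also relies on names the post does not show (an args object from an argument parser, a create_dataset helper, and the model itself). The following imports are a sketch of what the script assumes, using an older albumentations release where transforms.RandomBrightness and transforms.RandomContrast are still available:

import os

import albumentations as albu
from albumentations import Compose, OneOf
from albumentations.augmentations import transforms
from tqdm import tqdm

import mindspore
import mindspore.nn as nn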

def train_model(
        model,
        device,
        save_path,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-2,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        weight_decay: float = 1e-4,
        momentum: float = 0.90,
        data_type:str='turbine',
) -> None:
    """模型训练函数

    Args:
        model: 待训练的模型
        device: 训练设备 (CPU/GPU)
        save_path: 模型保存路径
        epochs: 训练轮次
        batch_size: 批大小
        learning_rate: 学习率
        save_checkpoint: 是否保存检查点
        img_scale: 图像缩放比例
        weight_decay: 权重衰减系数
        momentum: 动量系数
        data_type: 数据类型 ('turbine'/'voc')
    """
    # Track the best model performance (reserved; mIoU evaluation is not shown in this snippet)
    best_miou = 0
    # Data augmentation pipeline (albumentations)
    train_transform = Compose([
        albu.RandomRotate90(),
        albu.Flip(),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),
        ], p=0.5),
        albu.Rotate(limit=(-65, 65), p=0.8),  # random rotation within ±65 degrees
        albu.RandomCrop(height=400, width=400, p=0.8),  # random crop before the final resize
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        albu.Resize(width=384, height=384),
    ])

    # Build the training dataset
    train_set = create_dataset(
        args.train_image_dir,
        args.train_mask_dir,
        batch_size=batch_size,
        transform=train_transform
    )

    # Number of batches per epoch
    n_train = train_set.get_dataset_size()

    # Per-step cosine-decayed learning-rate list (note: max_lr here, not the learning_rate argument, sets the peak)
    lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=n_train * epochs,
                            step_per_epoch=n_train, decay_epoch=epochs)

    if args.optimizer == 'SGD':
        optimizer = nn.SGD(
            model.trainable_params(),
            learning_rate=lr,
            momentum=momentum,
            nesterov=False,
            weight_decay=weight_decay
        )
    elif args.optimizer == 'Adam':
        optimizer = nn.Adam(
            model.trainable_params(),
            learning_rate=lr,
            weight_decay=weight_decay
        )
    else:
        raise ValueError(f'Unsupported optimizer: {args.optimizer}')


    # Loss function for multi-class segmentation (integer class-index masks)
    loss_fn = nn.CrossEntropyLoss()

    def forward_fn(inputs, targets):
        logits = model(inputs)
        loss = loss_fn(logits, targets)
        return loss

    # value_and_grad builds a function returning the loss and the gradients w.r.t. the optimizer's parameters
    grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters)

    def train_step(inputs, targets):
        loss, grads = grad_fn(inputs, targets)
        optimizer(grads)
        return loss

    iter_num = 0
    for epoch in range(1, epochs + 1):
        # Set the model to training mode (enables BatchNorm statistics updates)
        model.set_train(True)
        with tqdm(total=n_train, dynamic_ncols=True, desc=f'Epoch {epoch} / {epochs}') as pbar:
            for data in train_set.create_dict_iterator():
                images = data["image"]
                true_masks = data["mask"]
                # One optimization step: forward pass, gradient computation, parameter update
                loss = train_step(images, true_masks)

                # Update the progress bar
                pbar.update(1)
                pbar.set_postfix({"loss (batch)": loss.item()})
                iter_num += 1
    # Save the trained model
    if save_checkpoint:
        # Create the output directory if it does not exist
        os.makedirs(save_path, exist_ok=True)
        # MindSpore checkpoints conventionally use the .ckpt extension
        model_save_file = os.path.join(save_path, 'transform_seg_unet.ckpt')
        # Save the checkpoint to the path reported below
        mindspore.save_checkpoint(model, model_save_file)
        print(f'Model saved to {model_save_file}')
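A typical invocation might look like the sketch below; the paths and hyper-parameters are placeholders, and train_model also reads a global args object from the author's argument parsing, which the post does not show. Device selection in MindSpore is done through set_context rather than the unused device argument.

if __name__ == '__main__':
    # Choose the execution backend up front (GPU only on supported Linux/CUDA setups).
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="CPU")

    model = TransConv_UNet(input_channels=3, num_classes=6)
    train_model(
        model,
        device=None,                 # unused; kept for signature compatibility
        save_path='./checkpoints',   # placeholder output directory
        epochs=50,
        batch_size=4,
        data_type='turbine',
    )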

4. Model Testing:

# Assumes the imports and the TransConv_UNet definition from the previous sections, plus:
import cv2
import numpy as np

def test_one_image():
    model_path = './model.ckpt'
    model = TransConv_UNet(input_channels=3, num_classes=6)
    param_dict = mindspore.load_checkpoint(model_path)
    param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
    # Switch to inference mode so that BatchNorm uses its moving statistics
    model.set_train(False)
    test_image = cv2.imread('*.png')  # replace with the path to a test image
    test_transform = Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        albu.Resize(width=384, height=384),
    ])
    transformed = test_transform(image=test_image)
    image = transformed['image']
    image = image.transpose(2, 0, 1)       # HWC -> CHW
    image = np.expand_dims(image, axis=0)  # add the batch dimension
    image = mindspore.Tensor(image, dtype=mindspore.float32)
    out = model(image)
    output = ops.softmax(out, axis=1)
    output = ops.argmax(output, dim=1)     # (1, H, W) class-index map
    output = ops.squeeze(output, axis=0)   # (H, W)
    prediction = output.asnumpy()
    base_dir = os.path.dirname(model_path)
    save_file = os.path.join(base_dir, "ckpt_prediction.png")
    # result_to_image colorizes the class-index map and writes it to disk (helper not shown in the post; sketched below)
    result_to_image(prediction, save_file)

if __name__ == '__main__':
    test_one_image()
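The post does not include result_to_image. A minimal sketch of such a helper, assuming one fixed color per class index (the palette values are arbitrary placeholders, not the author's actual colors):

# Hypothetical helper: map each class index to an RGB color and save the result as a PNG.
import cv2
import numpy as np

PALETTE = np.array([
    [0, 0, 0],        # class 0: background
    [255, 0, 0],
    [0, 255, 0],
    [0, 0, 255],
    [255, 255, 0],
    [255, 0, 255],
], dtype=np.uint8)

def result_to_image(prediction, save_file):
    """Colorize an (H, W) array of class indices and write it to save_file."""
    color = PALETTE[prediction.astype(np.int64)]          # (H, W, 3), RGB order
    cv2.imwrite(save_file, cv2.cvtColor(color, cv2.COLOR_RGB2BGR))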

