MindSpore check-in, day N: DDPM U-Net network walkthrough (markdown version)

Published: 2024-07-01

A:
Why do the channels in the downsampling half of the DDPM U-Net go up, e.g. 20 → 32 → 64 → 128? Looking at the U shape, shouldn't they be going down?
{Block1 → Block2 → Res(attn) → Downsample} × 3

B:
It is the width and height that go down; the channel count goes up.
The upsampling half reverses this: width and height grow back, and the channel count eventually returns to 3.
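
To make B's answer concrete, here is a minimal shape trace (a sketch assuming a 32×32 input and the dims [20, 32, 64, 128] derived later in this walkthrough): channels rise while width and height halve at each Downsample.

# Sketch: shape evolution through the down path (assumed 32x32 input).
dims, side = [20, 32, 64, 128], 32
for ind, (c_in, c_out) in enumerate(zip(dims[:-1], dims[1:])):
    is_last = ind == len(dims) - 2        # the last stage uses Identity, not Downsample
    nxt = side if is_last else side // 2  # Downsample halves H and W
    print(f"down stage {ind}: ({c_in} ch, {side}x{side}) -> ({c_out} ch, {nxt}x{nxt})")
    side = nxt
# down stage 0: (20 ch, 32x32) -> (32 ch, 16x16)
# down stage 1: (32 ch, 16x16) -> (64 ch, 8x8)
# down stage 2: (64 ch, 8x8) -> (128 ch, 8x8)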

The conditional U-Net

We have already defined all the building blocks (position embeddings, ResNet/ConvNeXT blocks, attention, and group normalization), so now we can define the full network. Recall that the job of the network $\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)$ is to take a batch of noisy images plus their noise levels, and output the noise that was added to the input.

More concretely:
the network takes a batch of noisy images of shape (batch_size, num_channels, height, width) and a batch of noise levels of shape (batch_size, 1) as input, and returns a tensor of shape (batch_size, num_channels, height, width).

The network is built up as follows:

  • First, a convolutional layer is applied to the batch of noisy images, and position embeddings are computed for the noise levels

  • Next, a sequence of downsampling stages is applied. Each downsampling stage consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + a residual connection + a downsample operation

  • In the middle of the network, ResNet or ConvNeXT blocks are applied again, interleaved with attention

  • Next, a sequence of upsampling stages is applied. Each upsampling stage consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + a residual connection + an upsample operation

  • Finally, a ResNet/ConvNeXT block is applied, followed by a convolutional layer

In the end, the network stacks these layers as if they were Lego bricks (but it is important to understand how they work).

class Unet(nn.Cell):
    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            convnext_mult=2,
    ):
        super().__init__()

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)  # default init_dim is 2/3 of dim (e.g. 32 -> 20)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]  # e.g. [20, 32, 64, 128]
        in_out = list(zip(dims[:-1], dims[1:]))  # (in, out) channel pairs per stage

        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.CellList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),  # *2: skip feature is concatenated
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def construct(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)  # concat the skip connection along channels
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
        return self.final_conv(x)
import mindspore as ms
from mindspore.common.initializer import Normal

# Parameter definitions
image_side_length = 32  # width/height of the image in pixels
channels = 3  # number of image channels (RGB assumed here)
batch_size = 2  # batch size

# Instantiate the Unet model
# Note: dim should be chosen according to the model design; following the code above we keep it as-is
unet_model = Unet(dim=image_side_length, channels=channels, dim_mults=(1, 2, 4,))

# Build the input data
x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal())
x.shape  # show the data shape
print(x)  # print the data (random values from the initializer)
[[[[ 1.22990236e-02  9.65940859e-03 -5.95777121e-04 ... -1.09354462e-02
     2.30002552e-02 -5.25823655e-03]
   [ 1.35805225e-02  1.16471937e-02 -1.20973922e-02 ... -1.13204606e-02
    -1.91520341e-02 -1.09745166e-03]
   [-4.65569133e-03  1.33861918e-02 -1.60518996e-02 ...  4.18792450e-04
     9.22567211e-03  4.44417645e-04]
   ...
   [ 3.40697076e-03  4.53335233e-03  5.73999388e-03 ...  4.67619160e-03
    -8.16432573e-03 -1.39179081e-02]
   [-9.07978602e-03 -6.43689744e-03  1.32928183e-02 ...  4.21820907e-03
    -1.05559649e-02  8.33686162e-03]
   [ 2.96656298e-03 -7.44550209e-03  5.52403228e-03 ... -2.09826510e-03
     2.17068940e-02  2.28530783e-02]]

  [[-2.34551495e-03  7.68061494e-03  8.63175746e-03 ... -5.62175177e-03
    -9.85390134e-03 -4.08322597e-03]
   [ 1.30044697e-02 -9.87336412e-03  2.55680992e-03 ...  1.21581517e-02
     1.10829184e-02 -1.09381862e-02]
   [-1.09032113e-02  1.25320591e-02 -9.15124733e-03 ... -8.42134352e-04
    -3.48115107e-03 -8.12307373e-03]
   ...
   [-1.22983279e-02  2.11556954e-03 -1.63072231e-03 ... -8.83890502e-03
     2.00234205e-02 -2.91514886e-03]
   [-4.95374482e-03 -1.51413877e-03  6.57585217e-03 ...  1.93616766e-02
    -3.65696964e-03 -1.76955778e-02]
   [ 8.47856048e-03  9.17020999e-03 -5.66793000e-03 ... -2.92802905e-03
    -5.98460436e-03  8.32138583e-03]]

  [[ 1.00378189e-02 -2.43024575e-03  2.11097375e-02 ... -6.47504721e-03
    -1.47426147e-02  7.38033140e-03]
   [-3.09416349e-03 -3.46184568e-03 -7.74018466e-03 ...  1.19950040e-03
     3.14799254e-04 -7.95779750e-03]
   [ 3.98837449e-03  2.33123749e-02  1.63442008e-02 ...  1.05365906e-02
    -1.44729228e-03  1.90633966e-03]
   ...
   [-1.76522471e-02  9.42215510e-03 -9.92319733e-03 ... -8.83952528e-03
    -1.18930812e-03 -8.53374321e-03]
   [ 2.51283534e-02 -1.38457380e-02  1.32035371e-02 ...  1.66724548e-02
    -9.26751085e-03  1.42328264e-02]
   [-3.69384699e-03  6.09130273e-03 -2.94976344e-04 ...  7.72336172e-03
    -3.75742209e-03 -3.17590404e-03]]]


 [[[-2.92081665e-03 -1.39991604e-02 -8.93703103e-03 ...  1.51352473e-02
     3.90937366e-03  2.66693830e-02]
   [-2.27847677e-02  3.63694108e-03  2.70780316e-03 ... -8.13330431e-03
    -4.17956570e-03  1.22072157e-02]
   [-1.24624427e-02  4.75015305e-03  2.68556597e-03 ...  6.48784591e-03
    -6.09957753e-03  4.85362951e-03]
   ...
   [-3.67846363e-03 -9.81856976e-03 -7.40657933e-03 ...  1.95454084e-03
     1.80558003e-02  4.30267537e-03]
   [-2.47061905e-02  1.53471017e-03 -2.55961739e-03 ... -6.16029697e-03
    -1.19128199e-02  7.23672146e-03]
   [-9.77169070e-03 -5.93968621e-03 -1.16010886e-02 ...  1.13449963e-02
     7.74116023e-03 -8.25872459e-03]]

  [[ 2.42574494e-02 -1.59016773e-02  4.60586976e-03 ... -1.27300173e-02
    -2.08083801e-02  1.20891845e-02]
   [ 4.98928130e-03  1.58587005e-02 -1.17553072e-02 ... -4.57813032e-03
     2.66204093e-04 -1.80527139e-02]
   [ 9.97055881e-03  2.07035127e-03 -7.31401029e-04 ...  1.80852767e-02
    -2.09929375e-03  4.49541025e-04]
   ...
   [-8.71989876e-04  7.75372284e-03  3.14102072e-05 ...  6.37980178e-04
    -1.68553423e-02 -4.13572555e-03]
   [ 6.12246012e-03 -1.88669516e-03  1.50548946e-02 ...  9.18534491e-03
     1.46157937e-02  5.96544426e-03]
   [-5.24167530e-03  2.64895801e-03  7.25612324e-03 ... -5.48065547e-03
    -2.98001780e-03 -7.99621455e-03]]

  [[ 1.18518099e-02  1.00414380e-02 -3.00463289e-03 ... -3.48429219e-03
     1.21912286e-02 -8.21612682e-03]
   [ 9.25556850e-03 -1.57560236e-04  7.71128759e-03 ...  3.91136715e-03
     1.56383701e-02  8.09505815e-04]
   [ 4.79864981e-03  1.88933630e-02  1.73798949e-02 ...  5.97322173e-03
     4.30198200e-03  1.52684944e-02]
   ...
   [-9.37487371e-03  5.54391975e-03  4.64118691e-03 ...  6.41342625e-03
     1.36971334e-03 -1.25444317e-02]
   [-4.26448090e-03  7.79700419e-03  2.39845295e-03 ... -1.18866842e-02
     3.74738523e-03  1.07039241e-02]
   [-1.02939839e-02  7.36899953e-03 -2.00587343e-02 ... -1.10042403e-02
    -1.42604960e-02 -1.37462756e-02]]]]
x.shape  # 显示数据形状
(2, 3, 32, 32)
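The constructor is now replayed step by step in a REPL, so each intermediate object can be inspected: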
dim=image_side_length
channels=channels
dim_mults=(1, 2, 4,)

init_dim=None
out_dim=None
# dim_mults=(1, 2, 4, 8)
channels=3
with_time_emb=True
convnext_mult=2
init_dim = default(init_dim, dim // 3 * 2)
dim,init_dim
(32, 20)
channels
3
init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)
init_conv 
Conv2d<input_channels=3, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2d30>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefedb2be0>, format=NCHW>
dim, dim_mults,init_dim
(32, (1, 2, 4), 20)
(lambda m: dim * m, dim_mults)
(<function __main__.<lambda>(m)>, (1, 2, 4))
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
dims
[20, 32, 64, 128]
zip(dims[:-1], dims[1:])
<zip at 0xfffefc367b80>
dims[:-1], dims[1:]
([20, 32, 64], [32, 64, 128])
in_out = list(zip(dims[:-1], dims[1:]))
in_out
[(20, 32), (32, 64), (64, 128)]
ConvNextBlock,convnext_mult
(__main__.ConvNextBlock, 2)
block_klass = partial(ConvNextBlock, mult=convnext_mult)  ## pin ConvNextBlock's keyword argument mult=convnext_mult via partial
block_klass
functools.partial(<class '__main__.ConvNextBlock'>, mult=2)
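As a sanity check on what partial is doing here, a tiny standalone sketch (make is a hypothetical stand-in for ConvNextBlock's signature): partial pins the keyword argument mult=2 so it does not have to be repeated for every block.

from functools import partial

def make(a, b, mult=1):        # hypothetical stand-in for ConvNextBlock
    return (a, b, mult)

klass = partial(make, mult=2)
print(klass(20, 32))           # (20, 32, 2) -- same as make(20, 32, mult=2)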
 with_time_emb
True

time_dim = dim * 4
time_dim,dim
(128, 32)
time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
time_mlp
SequentialCell<
  (0): SinusoidalPositionEmbeddings<>
  (1): Dense<input_channels=32, output_channels=128, has_bias=True>
  (2): GELU<>
  (3): Dense<input_channels=128, output_channels=128, has_bias=True>
  >
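SinusoidalPositionEmbeddings itself is not shown in this section; for reference, here is a minimal NumPy sketch of the standard DDPM sinusoidal timestep embedding it is assumed to implement (the actual cell may differ in details):

import math
import numpy as np

def sinusoidal_embedding(t, dim=32):
    # Map integer timesteps of shape (n,) to (n, dim) sin/cos features.
    half = dim // 2
    freqs = np.exp(np.arange(half) * -(math.log(10000.0) / (half - 1)))
    args = t[:, None].astype(np.float32) * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

print(sinusoidal_embedding(np.arange(5)).shape)  # (5, 32)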
downs = nn.CellList([])
ups = nn.CellList([])
ups
CellList<>
num_resolutions = len(in_out)
num_resolutions
3
for ind, (dim_in, dim_out) in enumerate(in_out):
        is_last = ind >= (num_resolutions - 1)
        print(ind,":",is_last)
0 : False
1 : False
2 : True
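Only the last stage (ind == 2) gets nn.Identity in place of Downsample, so with a 32×32 input the feature map is halved twice (32 → 16 → 8) and the bottleneck runs at 8×8 with 128 channels.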
dim_in, dim_out, time_dim  ### each timestep is encoded into a 128-dim embedding
(64, 128, 128)
in_out
[(20, 32), (32, 64), (64, 128)]
for ind, (dim_in, dim_out) in enumerate(in_out):
    print(dim_in, dim_out)
    is_last = ind >= (num_resolutions - 1)

    downs.append(
        nn.CellList(
            [
                block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity(),
            ]
        )
    )
20 32
32 64
64 128
downs
CellList<
  (0): CellList<
    (0): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=20, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c130>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=20>
        (1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc070>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=64>
        (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354a90>, bias_init=None, format=NCHW>
        >
      (res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3d7cd0>, bias_init=None, format=NCHW>
      >
    (1): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=32, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3c17f0>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=32>
        (1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3062b0>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=64>
        (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c190>, bias_init=None, format=NCHW>
        >
      (res_conv): Identity<>
      >
    (2): Residual<
      (fn): PreNorm<
        (fn): LinearAttention<
          (to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d90>, bias_init=None, format=NCHW>
          (to_out): SequentialCell<
            (0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc306a60>, format=NCHW>
            (1): LayerNorm<>
            >
          >
        (norm): GroupNorm<num_groups=1, num_channels=32>
        >
      >
    (3): Conv2d<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf1f0>, bias_init=None, format=NCHW>
    >
  (1): CellList<
    (0): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=32, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306fd0>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=32>
        (1): Conv2d<input_channels=32, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf580>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=128>
        (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf190>, bias_init=None, format=NCHW>
        >
      (res_conv): Conv2d<input_channels=32, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf670>, bias_init=None, format=NCHW>
      >
    (1): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=64, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf940>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=64>
        (1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf910>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=128>
        (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf760>, bias_init=None, format=NCHW>
        >
      (res_conv): Identity<>
      >
    (2): Residual<
      (fn): PreNorm<
        (fn): LinearAttention<
          (to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfa30>, bias_init=None, format=NCHW>
          (to_out): SequentialCell<
            (0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfac0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2cfb80>, format=NCHW>
            (1): LayerNorm<>
            >
          >
        (norm): GroupNorm<num_groups=1, num_channels=64>
        >
      >
    (3): Conv2d<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfdc0>, bias_init=None, format=NCHW>
    >
  (2): CellList<
    (0): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=64, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfe50>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=64>
        (1): Conv2d<input_channels=64, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cff40>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=256>
        (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac130>, bias_init=None, format=NCHW>
        >
      (res_conv): Conv2d<input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac250>, bias_init=None, format=NCHW>
      >
    (1): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=128, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac1f0>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=128>
        (1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac370>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=256>
        (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac4c0>, bias_init=None, format=NCHW>
        >
      (res_conv): Identity<>
      >
    (2): Residual<
      (fn): PreNorm<
        (fn): LinearAttention<
          (to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
          (to_out): SequentialCell<
            (0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
            (1): LayerNorm<>
            >
          >
        (norm): GroupNorm<num_groups=1, num_channels=128>
        >
      >
    (3): Identity<>
    >
  >
mid_dim = dims[-1]
mid_dim 
128
mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
mid_block1
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=128, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354b50>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=128>
    (1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17e50>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=256>
    (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac36db20>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  >
mid_attn
Residual<
  (fn): PreNorm<
    (fn): Attention<
      (to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac36d9a0>, bias_init=None, format=NCHW>
      (to_out): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac36d880>, bias_init=<mindspore.common.initializer.Uniform object at 0xffffac36d730>, format=NCHW>
      >
    (norm): GroupNorm<num_groups=1, num_channels=128>
    >
  >
mid_block2 
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=128, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cee0>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=128>
    (1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c850>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=256>
    (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c910>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  >
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
    print(dim_in, dim_out)
    is_last = ind >= (num_resolutions - 1)
    print(is_last)
64 128
False
32 64
False
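Note that reversed(in_out[1:]) yields only num_resolutions - 1 = 2 up stages, so is_last never becomes True here and both up stages end with a real Upsample (the Conv2dTranspose layers visible in the ups printout below).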
dim_in
32
LinearAttention(dim_in)
LinearAttention<
  (to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306eb0>, bias_init=None, format=NCHW>
  (to_out): SequentialCell<
    (0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac348c40>, bias_init=<mindspore.common.initializer.Uniform object at 0xffffac2bc4c0>, format=NCHW>
    (1): LayerNorm<>
    >
  >
class LinearAttention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)

        self.to_out = nn.SequentialCell(
            nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
            LayerNorm(dim)
        )

        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = ops.softmax(q, -2)
        k = ops.softmax(k, -1)

        q = q * self.scale
        v = v / (h * w)

        # 'b h d n, b h e n -> b h d e'
        context = ops.bmm(k, v.swapaxes(2, 3))
        # 'b h d e, b h d n -> b h e n'
        out = ops.bmm(context.swapaxes(2, 3), q)

        out = out.reshape((b, -1, h, w))
        return self.to_out(out)
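A quick smoke test of LinearAttention (a sketch; it assumes the class above and its rearrange/LayerNorm helpers are already in scope): the block maps an NCHW feature map back to a tensor of the same shape.

import numpy as np
import mindspore as ms

attn = LinearAttention(dim=32)   # heads=4, dim_head=32 -> hidden_dim = 128
feat = ms.Tensor(np.random.randn(2, 32, 16, 16), ms.float32)
print(attn(feat).shape)          # (2, 32, 16, 16)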
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
    is_last = ind >= (num_resolutions - 1)

    ups.append(
        nn.CellList(
            [
                block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity(),
            ]
        )
    )
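The dim_out * 2 in the first block of each up stage is the skip connection: in construct, the feature saved in h on the way down is concatenated with the current feature along the channel axis (ops.concat((x, h[len_h]), 1)), doubling the input channels. The first up stage therefore receives 128 (from the bottleneck) + 128 (skip) = 256 channels, matching input_channels=256 in the printout below.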
ups
CellList<
  (0): CellList<
    (0): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=256, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=256, output_channels=256, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=256, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2b80>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=256>
        (1): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee15430>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=128>
        (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee15d30>, bias_init=None, format=NCHW>
        >
      (res_conv): Conv2d<input_channels=256, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb24c0>, bias_init=None, format=NCHW>
      >
    (1): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=64, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbcd00>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=64>
        (1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc550>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=128>
        (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c4f0>, bias_init=None, format=NCHW>
        >
      (res_conv): Identity<>
      >
    (2): Residual<
      (fn): PreNorm<
        (fn): LinearAttention<
          (to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354580>, bias_init=None, format=NCHW>
          (to_out): SequentialCell<
            (0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cc10>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc30c9a0>, format=NCHW>
            (1): LayerNorm<>
            >
          >
        (norm): GroupNorm<num_groups=1, num_channels=64>
        >
      >
    (3): Conv2dTranspose<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cac0>, bias_init=None, format=NCHW>
    >
  (1): CellList<
    (0): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=128, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c2e0>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=128>
        (1): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c5e0>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=64>
        (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17f10>, bias_init=None, format=NCHW>
        >
      (res_conv): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17760>, bias_init=None, format=NCHW>
      >
    (1): ConvNextBlock<
      (mlp): SequentialCell<
        (0): GELU<>
        (1): Dense<input_channels=128, output_channels=32, has_bias=True>
        >
      (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17a00>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=32>
        (1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17580>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=64>
        (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17dc0>, bias_init=None, format=NCHW>
        >
      (res_conv): Identity<>
      >
    (2): Residual<
      (fn): PreNorm<
        (fn): LinearAttention<
          (to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3064f0>, bias_init=None, format=NCHW>
          (to_out): SequentialCell<
            (0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306b50>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc3065b0>, format=NCHW>
            (1): LayerNorm<>
            >
          >
        (norm): GroupNorm<num_groups=1, num_channels=32>
        >
      >
    (3): Conv2dTranspose<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306f70>, bias_init=None, format=NCHW>
    >
  >
out_dim = default(out_dim, channels)

out_dim 
3
final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )
final_conv 
SequentialCell<
  (0): ConvNextBlock<
    (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2acb50>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=32>
      (1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17bb0>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=64>
      (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2100>, bias_init=None, format=NCHW>
      >
    (res_conv): Identity<>
    >
  (1): Conv2d<input_channels=32, output_channels=3, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2eb0>, bias_init=None, format=NCHW>
  >
x
time=5
x.shape
(2, 3, 32, 32)
print(x)
[output identical to the print(x) above]
init_conv
Conv2d<input_channels=3, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2d30>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefedb2be0>, format=NCHW>
x = init_conv(x)
print(x)    
[[[[ 0.04106301  0.04308875  0.03753628 ...  0.03978373  0.04020362
     0.03793241]
   [ 0.04131693  0.04225756  0.03624208 ...  0.04158005  0.04384746
     0.0374498 ]
   [ 0.03251923  0.04382189  0.03682814 ...  0.04343156  0.03790728
     0.03667009]
   ...
   [ 0.03987011  0.04261249  0.03721504 ...  0.03754282  0.03530194
     0.04190454]
   [ 0.03995614  0.04259038  0.04231969 ...  0.03937387  0.03802945
     0.03542861]
   [ 0.03724413  0.03895703  0.03808391 ...  0.04210365  0.03843816
     0.03887339]]

  [[-0.07880677 -0.081793   -0.08021648 ... -0.07915598 -0.08803446
    -0.07824855]
   [-0.07975532 -0.0827108  -0.08153103 ... -0.08920732 -0.08202183
    -0.07717112]
   [-0.080375   -0.08304221 -0.07943083 ... -0.08371484 -0.07717931
    -0.07678773]
   ...
   [-0.07925861 -0.07035945 -0.07607639 ... -0.08380341 -0.08219168
    -0.08388805]
   [-0.07702561 -0.07861231 -0.08642116 ... -0.08342467 -0.07647635
    -0.08471077]
   [-0.08218312 -0.08206419 -0.0820056  ... -0.0710914  -0.08050337
    -0.08665174]]

  [[-0.05770465 -0.05971097 -0.06042907 ... -0.05981689 -0.05351909
    -0.06045758]
   [-0.06280089 -0.06072729 -0.06125656 ... -0.06167236 -0.05607811
    -0.06504007]
   [-0.05974955 -0.06224146 -0.05134789 ... -0.06194806 -0.05703649
    -0.05661972]
   ...
   [-0.0587245  -0.06006888 -0.06369887 ... -0.0509633  -0.05987025
    -0.05689852]
   [-0.05888586 -0.06178844 -0.06245932 ... -0.06076533 -0.05802548
    -0.06169396]
   [-0.05935856 -0.05726556 -0.05836396 ... -0.06468105 -0.05601557
    -0.05411654]]

  ...

  [[-0.06885257 -0.06496602 -0.07227325 ... -0.06768468 -0.07973982
    -0.06684067]
   [-0.06921483 -0.07310341 -0.07145415 ... -0.07373261 -0.06769554
    -0.06564213]
   [-0.07235637 -0.08390911 -0.06977317 ... -0.06690352 -0.06286541
    -0.06959118]
   ...
   [-0.07000594 -0.06508094 -0.06877656 ... -0.07407243 -0.07690564
    -0.06396648]
   [-0.07082649 -0.07268029 -0.07315704 ... -0.06758922 -0.06662212
    -0.06855071]
   [-0.07199463 -0.06999994 -0.07014568 ... -0.06523817 -0.07094447
    -0.07466151]]

  [[ 0.06986891  0.06941634  0.06439675 ...  0.06187101  0.0675493
     0.07306495]
   [ 0.07319107  0.07266034  0.05997508 ...  0.06689761  0.06815154
     0.0660945 ]
   [ 0.07065641  0.05923657  0.06411441 ...  0.06652149  0.07088953
     0.07194202]
   ...
   [ 0.06848673  0.07591817  0.0726023  ...  0.06602401  0.06890585
     0.07259338]
   [ 0.07433689  0.06679939  0.06691605 ...  0.0667197   0.07184143
     0.06983658]
   [ 0.06431621  0.07212089  0.06723586 ...  0.06868842  0.07140361
     0.06901537]]

  [[ 0.02097934  0.01210438  0.01431934 ...  0.01505992  0.01852277
     0.01381299]
   [ 0.02296696  0.0177606   0.01976403 ...  0.02147305  0.02210259
     0.02313221]
   [ 0.02124698  0.02709681  0.02910981 ...  0.01016832  0.02212639
     0.01957588]
   ...
   [ 0.02289374  0.01311012  0.01578637 ...  0.01931083  0.01555186
     0.0208313 ]
   [ 0.01390727  0.02096656  0.01745579 ...  0.01781181  0.02211875
     0.01568411]
   [ 0.02439262  0.01495296  0.01968778 ...  0.02193322  0.01783368
     0.0176824 ]]]


 [[[ 0.04223164  0.03378314  0.03601065 ...  0.03959855  0.03485664
     0.04071919]
   [ 0.0318558   0.05363872  0.03783617 ...  0.04385335  0.0496259
     0.03691863]
   [ 0.03818982  0.03180957  0.04072122 ...  0.03430039  0.03384047
     0.03837577]
   ...
   [ 0.0398777   0.03721025  0.03533046 ...  0.04020133  0.03928016
     0.04710523]
   [ 0.04118172  0.03496882  0.03100736 ...  0.03642647  0.03914004
     0.0371574 ]
   [ 0.04037436  0.04040184  0.04165599 ...  0.04403537  0.03254044
     0.04335065]]

  [[-0.08552387 -0.07319534 -0.08021338 ... -0.07858572 -0.07166487
    -0.08406518]
   [-0.07923919 -0.08566054 -0.08015955 ... -0.08471547 -0.0847266
    -0.08085599]
   [-0.08489675 -0.09258271 -0.08831957 ... -0.09042192 -0.08426952
    -0.0808774 ]
   ...
   [-0.08036023 -0.07413588 -0.07989521 ... -0.07935498 -0.08571334
    -0.08329107]
   [-0.07644836 -0.07608277 -0.08767064 ... -0.08434241 -0.08071237
    -0.0839122 ]
   [-0.07979399 -0.08087463 -0.08673595 ... -0.08414597 -0.08045428
    -0.07299927]]

  [[-0.06060546 -0.05453672 -0.06102112 ... -0.05194974 -0.0567053
    -0.06273571]
   [-0.06276039 -0.05693425 -0.04725159 ... -0.06214722 -0.06443968
    -0.05762123]
   [-0.05252658 -0.06019294 -0.06137866 ... -0.04910715 -0.06131132
    -0.06036767]
   ...
   [-0.06173272 -0.05464447 -0.05099018 ... -0.06136036 -0.06400239
    -0.06106843]
   [-0.05803053 -0.05994222 -0.06404369 ... -0.04949801 -0.05738675
    -0.06158596]
   [-0.05899998 -0.06198164 -0.05937162 ... -0.06379396 -0.06430338
    -0.06287489]]

  ...

  [[-0.07173873 -0.0707745  -0.06975999 ... -0.07155637 -0.06534318
    -0.07189398]
   [-0.06730746 -0.07013785 -0.06751848 ... -0.07264671 -0.07705939
    -0.07342067]
   [-0.07058413 -0.07025788 -0.06871852 ... -0.06887744 -0.06563742
    -0.07028291]
   ...
   [-0.07809374 -0.06778216 -0.06392691 ... -0.06867532 -0.07118014
    -0.07647338]
   [-0.07219965 -0.07040192 -0.0732589  ... -0.07633238 -0.0752567
    -0.0702922 ]
   [-0.06984755 -0.07723872 -0.06846898 ... -0.06786713 -0.06702175
    -0.07062964]]

  [[ 0.06747851  0.06883495  0.06797507 ...  0.06853593  0.06575806
     0.06841848]
   [ 0.06514458  0.06994057  0.06866109 ...  0.06339982  0.06309478
     0.06588745]
   [ 0.06701669  0.0691862   0.06725767 ...  0.06696404  0.07045414
     0.07060774]
   ...
   [ 0.07085218  0.0809648   0.06841429 ...  0.06838602  0.06918488
     0.07014886]
   [ 0.07304276  0.07134987  0.07214254 ...  0.07656243  0.07136226
     0.06578355]
   [ 0.06968872  0.07193028  0.06518821 ...  0.07004035  0.06891351
     0.06959624]]

  [[ 0.02295301  0.01347421  0.02212771 ...  0.02214386  0.01323562
     0.02334489]
   [ 0.01578862  0.01825874  0.01307945 ...  0.0216907   0.02719616
     0.02306023]
   [ 0.01491401  0.01406     0.02918804 ...  0.02165697  0.01733657
     0.01930147]
   ...
   [ 0.01614039  0.019646    0.02148937 ...  0.00664111  0.01888491
     0.02413018]
   [ 0.01757734  0.01567486  0.01912338 ...  0.02099028  0.01717271
     0.01547725]
   [ 0.01756483  0.020161    0.01650484 ...  0.01933268  0.0167334
     0.01855144]]]]
x.shape
(2, 20, 32, 32)
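Note that the 7×7 init_conv with padding=3 preserves the 32×32 spatial size; only the channel count changes, from 3 to 20.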
#unet_model.init_conv(x)  #### calling the method on the instantiated model -- this seemed to fail
import numpy as np
from mindspore import Tensor

# Define the start, number of steps, and step size for the timesteps (step defaults to 1, i.e. each timestep increases by 1)
start = 0
num_steps = 10
step = 1

# Generate a linearly increasing sequence of timesteps
t = Tensor(np.arange(start, start + num_steps * step, step), dtype=ms.int32)

print("Linearly increasing timestep sequence:", t)


# time_mlp
# SequentialCell<
#   (0): SinusoidalPositionEmbeddings<>
#   (1): Dense<input_channels=32, output_channels=128, has_bias=True>
#   (2): GELU<>
#   (3): Dense<input_channels=128, output_channels=128, has_bias=True>
#   >
time_mlp(t)  ### this works correctly now
Linearly increasing timestep sequence: [0 1 2 3 4 5 6 7 8 9]

Tensor(shape=[10, 128], dtype=Float32, value=
[[ 1.89717814e-01,  1.14008449e-02,  3.33061777e-02 ...  1.43985003e-01,  3.92933972e-02, -1.06829256e-02],
 [ 1.93240538e-01,  1.78442001e-02,  6.77158684e-02 ...  1.36301309e-01,  7.64560923e-02, -1.50307640e-02],
 [ 1.83035284e-01,  2.44393535e-02,  8.70461762e-02 ...  1.38745904e-01,  1.33171901e-01, -3.85175534e-02],
 ...
 [ 1.28773689e-01,  1.91335917e-01, -9.48226005e-02 ...  8.54851380e-02,  1.52098373e-01,  2.03581899e-02],
 [ 1.22549936e-01,  1.48201510e-01, -8.17623138e-02 ...  4.44053262e-02,  9.75183249e-02,  3.97774130e-02],
 [ 1.07752994e-01,  1.08763084e-01, -7.05250949e-02 ...  4.34711799e-02,  6.16942756e-02,  1.67786255e-02]])
import numpy as np
from mindspore import Tensor

# Define the start, number of steps, and step size for the timesteps (step defaults to 1, i.e. each timestep increases by 1)
start = 0
num_steps = 5
step = 1

# Generate a linearly increasing sequence of timesteps
t = Tensor(np.arange(start, start + num_steps * step, step), dtype=ms.int32)

print("Linearly increasing timestep sequence:", t)


# time_mlp
# SequentialCell<
#   (0): SinusoidalPositionEmbeddings<>
#   (1): Dense<input_channels=32, output_channels=128, has_bias=True>
#   (2): GELU<>
#   (3): Dense<input_channels=128, output_channels=128, has_bias=True>
#   >
time_mlp(t)  ### this works correctly now
Linearly increasing timestep sequence: [0 1 2 3 4]

Tensor(shape=[5, 128], dtype=Float32, value=
[[ 1.89717814e-01,  1.14008449e-02,  3.33061777e-02 ...  1.43985003e-01,  3.92933972e-02, -1.06829256e-02],
 [ 1.93240538e-01,  1.78442001e-02,  6.77158684e-02 ...  1.36301309e-01,  7.64560923e-02, -1.50307640e-02],
 [ 1.83035284e-01,  2.44393535e-02,  8.70461762e-02 ...  1.38745904e-01,  1.33171901e-01, -3.85175534e-02],
 [ 1.60673216e-01,  5.66724911e-02,  6.34887069e-02 ...  1.58708930e-01,  1.91956162e-01, -7.78784081e-02],
 [ 1.33557051e-01,  1.20526701e-01,  5.27272746e-03 ...  1.78237736e-01,  2.34055102e-01, -9.55123529e-02]])
t.shape
(5,)

It looks like you want to construct a tensor t of shape (5,), i.e. a one-dimensional vector with 5 elements. With the MindSpore library in Python this is easy. Here are several ways to create a tensor of this shape, initialized to specific or random values:

Initialized to zeros

import numpy as np
import mindspore as ms

# Create an all-zeros tensor of shape (5,)
t = ms.Tensor(np.zeros(5), dtype=ms.float32)
print(t.shape)  # output: (5,)

Initialized to ones

t = ms.Tensor(np.ones(5), dtype=ms.float32)
print(t.shape)  # output: (5,)

Initialized to a unit sequence (e.g. [0, 1, 2, 3, 4])

t = ms.Tensor(np.arange(5), dtype=ms.float32)
print(t.shape)  # output: (5,)

Initialized to random values

t = ms.Tensor(np.random.rand(5), dtype=ms.float32)
print(t.shape)  # output: (5,)

Filled with a specific value

value = 3.14  # e.g. use pi as the fill value
t = ms.Tensor(np.full(5, value), dtype=ms.float32)
print(t.shape)  # output: (5,)

Any of the code blocks above will create a tensor t of shape (5,); choose the initialization that fits your needs.

h = []
downs  # a CellList holding 3 CellList stages
[CellList printout identical to the downs output shown earlier]
        >
      (res_conv): Identity<>
      >
    (2): Residual<
      (fn): PreNorm<
        (fn): LinearAttention<
          (to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
          (to_out): SequentialCell<
            (0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
            (1): LayerNorm<>
            >
          >
        (norm): GroupNorm<num_groups=1, num_channels=128>
        >
      >
    (3): Identity<>
    >
  >
for downsample in downs:
    print(downsample)
CellList<
  (0): ConvNextBlock<
    (mlp): SequentialCell<
      (0): GELU<>
      (1): Dense<input_channels=128, output_channels=20, has_bias=True>
      >
    (ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c130>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=20>
      (1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc070>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=64>
      (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354a90>, bias_init=None, format=NCHW>
      >
    (res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3d7cd0>, bias_init=None, format=NCHW>
    >
  (1): ConvNextBlock<
    (mlp): SequentialCell<
      (0): GELU<>
      (1): Dense<input_channels=128, output_channels=32, has_bias=True>
      >
    (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3c17f0>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=32>
      (1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3062b0>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=64>
      (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c190>, bias_init=None, format=NCHW>
      >
    (res_conv): Identity<>
    >
  (2): Residual<
    (fn): PreNorm<
      (fn): LinearAttention<
        (to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d90>, bias_init=None, format=NCHW>
        (to_out): SequentialCell<
          (0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc306a60>, format=NCHW>
          (1): LayerNorm<>
          >
        >
      (norm): GroupNorm<num_groups=1, num_channels=32>
      >
    >
  (3): Conv2d<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf1f0>, bias_init=None, format=NCHW>
  >
CellList<
  (0): ConvNextBlock<
    (mlp): SequentialCell<
      (0): GELU<>
      (1): Dense<input_channels=128, output_channels=32, has_bias=True>
      >
    (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306fd0>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=32>
      (1): Conv2d<input_channels=32, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf580>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=128>
      (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf190>, bias_init=None, format=NCHW>
      >
    (res_conv): Conv2d<input_channels=32, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf670>, bias_init=None, format=NCHW>
    >
  (1): ConvNextBlock<
    (mlp): SequentialCell<
      (0): GELU<>
      (1): Dense<input_channels=128, output_channels=64, has_bias=True>
      >
    (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf940>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=64>
      (1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf910>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=128>
      (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf760>, bias_init=None, format=NCHW>
      >
    (res_conv): Identity<>
    >
  (2): Residual<
    (fn): PreNorm<
      (fn): LinearAttention<
        (to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfa30>, bias_init=None, format=NCHW>
        (to_out): SequentialCell<
          (0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfac0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2cfb80>, format=NCHW>
          (1): LayerNorm<>
          >
        >
      (norm): GroupNorm<num_groups=1, num_channels=64>
      >
    >
  (3): Conv2d<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfdc0>, bias_init=None, format=NCHW>
  >
CellList<
  (0): ConvNextBlock<
    (mlp): SequentialCell<
      (0): GELU<>
      (1): Dense<input_channels=128, output_channels=64, has_bias=True>
      >
    (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfe50>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=64>
      (1): Conv2d<input_channels=64, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cff40>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=256>
      (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac130>, bias_init=None, format=NCHW>
      >
    (res_conv): Conv2d<input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac250>, bias_init=None, format=NCHW>
    >
  (1): ConvNextBlock<
    (mlp): SequentialCell<
      (0): GELU<>
      (1): Dense<input_channels=128, output_channels=128, has_bias=True>
      >
    (ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac1f0>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=128>
      (1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac370>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=256>
      (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac4c0>, bias_init=None, format=NCHW>
      >
    (res_conv): Identity<>
    >
  (2): Residual<
    (fn): PreNorm<
      (fn): LinearAttention<
        (to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
        (to_out): SequentialCell<
          (0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
          (1): LayerNorm<>
          >
        >
      (norm): GroupNorm<num_groups=1, num_channels=128>
      >
    >
  (3): Identity<>
  >
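Each element of downs is itself a CellList holding exactly four cells (block1, block2, the attention residual, and the downsample), which is why the loop below can tuple-unpack each entry directly. A plain-Python analogue of that unpacking (hypothetical names, just to show the mechanics):

downs_like = [["block1", "block2", "attn", "down"]] * 3  # stand-in for the real CellList
for block1, block2, attn, downsample in downs_like:
    print(block1, block2, attn, downsample)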
for block1, block2, attn, downsample in downs:
    print("---- Block1 ----")
    print(block1)
    print("---- Block2 ----")
    print(block2)
    print("---- Residual attention ----")
    print(attn)
    print("---- Downsample ----")
    print(downsample)
    
---- Block1 ----
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=20, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c130>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=20>
    (1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc070>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=64>
    (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354a90>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3d7cd0>, bias_init=None, format=NCHW>
  >
---- Block2 ----
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=32, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3c17f0>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=32>
    (1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3062b0>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=64>
    (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c190>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  >
---- Residual attention ----
Residual<
  (fn): PreNorm<
    (fn): LinearAttention<
      (to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d90>, bias_init=None, format=NCHW>
      (to_out): SequentialCell<
        (0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc306a60>, format=NCHW>
        (1): LayerNorm<>
        >
      >
    (norm): GroupNorm<num_groups=1, num_channels=32>
    >
  >
---- Downsample ----
Conv2d<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf1f0>, bias_init=None, format=NCHW>
---- Block1 ----
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=32, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306fd0>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=32>
    (1): Conv2d<input_channels=32, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf580>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=128>
    (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf190>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=32, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf670>, bias_init=None, format=NCHW>
  >
---- Block2 ----
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=64, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf940>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=64>
    (1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf910>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=128>
    (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf760>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  >
---- Residual attention ----
Residual<
  (fn): PreNorm<
    (fn): LinearAttention<
      (to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfa30>, bias_init=None, format=NCHW>
      (to_out): SequentialCell<
        (0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfac0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2cfb80>, format=NCHW>
        (1): LayerNorm<>
        >
      >
    (norm): GroupNorm<num_groups=1, num_channels=64>
    >
  >
---- Downsample ----
Conv2d<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfdc0>, bias_init=None, format=NCHW>
---- Block1 ----
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=64, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfe50>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=64>
    (1): Conv2d<input_channels=64, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cff40>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=256>
    (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac130>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac250>, bias_init=None, format=NCHW>
  >
---- Block2 ----
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=128, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac1f0>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=128>
    (1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac370>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=256>
    (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac4c0>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  >
---- Residual attention ----
Residual<
  (fn): PreNorm<
    (fn): LinearAttention<
      (to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
      (to_out): SequentialCell<
        (0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
        (1): LayerNorm<>
        >
      >
    (norm): GroupNorm<num_groups=1, num_channels=128>
    >
  >
---- Downsample ----
Identity<>
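The 4×4, stride-2, padding-1 convolutions printed above are what halve the spatial size at each stage. A quick sanity check with the usual convolution output-size formula (a sketch, not from the notebook):

def conv_out(h, k=4, s=2, p=1):
    # standard conv arithmetic: floor((h + 2p - k) / s) + 1
    return (h + 2 * p - k) // s + 1

print(conv_out(32), conv_out(16))  # 16 8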

dim_in, dim_out, dim
(32, 64, 32)

Because the loop iterates over [(20, 32), (32, 64), (64, 128)], downs holds 3 elements, each an nn.CellList.

convnext_mult
2
#dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
dims
[20, 32, 64, 128]
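How those numbers arise, as a minimal sketch (assuming dim=32 and dim_mults=(1, 2, 4), which matches the dumps above):

dim, dim_mults = 32, (1, 2, 4)
init_dim = dim // 3 * 2                                # 32 // 3 * 2 = 20
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]  # [20, 32, 64, 128]
in_out = list(zip(dims[:-1], dims[1:]))                # [(20, 32), (32, 64), (64, 128)]
print(in_out)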
for i, (block1, block2, attn, downsample) in enumerate(downs, start=1):
    print("--------", i)
    print("Block1:", block1, "Block2:", block2, "Residual attention:", attn, "Downsample:", downsample)
    
#                 block_klass(dim_in, dim_out, time_emb_dim=time_dim),  #time_dim=128 dim=128
#                 block_klass(dim_out, dim_out, time_emb_dim=time_dim),
#                 Residual(PreNorm(dim_out, LinearAttention(dim_out))),
#                 Downsample(dim_out) if not is_last else nn.Identity(),    

block_klass = partial(ConvNextBlock, mult=convnext_mult)  ## binds the keyword argument mult=convnext_mult to ConvNextBlock
block_klass
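For readers new to functools.partial: it returns a callable with some arguments pre-bound. A self-contained toy example (the scale function is hypothetical, not part of the notebook):

from functools import partial

def scale(x, factor=1):
    return x * factor

double = partial(scale, factor=2)  # factor is now fixed to 2
print(double(5))  # 10

So block_klass(dim_in, dim_out, time_emb_dim=time_dim) is simply ConvNextBlock(dim_in, dim_out, time_emb_dim=time_dim, mult=2).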


# class ConvNextBlock(nn.Cell):
#     def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
#         super().__init__()
#         self.mlp = (
#             nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
#             if exists(time_emb_dim)
#             else None
#         )

#         self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
#         self.net = nn.SequentialCell(
#             nn.GroupNorm(1, dim) if norm else nn.Identity(),
#             nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
#             nn.GELU(),
#             nn.GroupNorm(1, dim_out * mult),
#             nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
#         )

#         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

#     def construct(self, x, time_emb=None):
#         h = self.ds_conv(x)
#         if exists(self.mlp) and exists(time_emb):
#             assert exists(time_emb), "time embedding must be passed in"
#             condition = self.mlp(time_emb)
#             condition = condition.expand_dims(-1).expand_dims(-1)
#             h = h + condition

#         h = self.net(h)
#         return h + self.res_conv(x)


### Down stage 1, Block1
#input_channels=20, output_channels=32  (20,32)

### Down stage 1, Block2
#Conv2d<input_channels=64 [32*2], output_channels=32, (res_conv): Identity<>  (32,32)

### Down stage 1, residual attention
#Conv2d<input_channels=128, output_channels=32,  (32,32)

### Down stage 1, downsample
#Conv2d<input_channels=32, output_channels=32  (32,32)

### Down stage 2, Block1
#Conv2d<input_channels=32, output_channels=64  (32,64)

### Down stage 2, Block2
# Conv2d<input_channels=128 [64*2], output_channels=64,  (res_conv): Identity<>  (64,64)

### Down stage 2, residual attention
#Conv2d<input_channels=128, output_channels=64 (64,64)

### Down stage 2, downsample
# Conv2d<input_channels=64, output_channels=64  (64,64)


### Down stage 3, Block1
#Conv2d<input_channels=64, output_channels=128 (64,128)

### Down stage 3, Block2
#Conv2d<input_channels=256 [128*2], output_channels=128, (res_conv): Identity<>  (128,128)

### Down stage 3, residual attention
#Conv2d<input_channels=128, output_channels=128, (res_conv): Identity<> (128,128)

### Down stage 3, downsample
#  Identity<> (128,128)
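Putting the three stages together, here is a hedged sketch of how the feature-map shape should evolve through the down path, assuming the (2, 20, 32, 32) input used later in this notebook:

shape = (2, 20, 32, 32)                       # after init_conv
stages = [(20, 32), (32, 64), (64, 128)]
for i, (c_in, c_out) in enumerate(stages):
    n, _, hgt, wid = shape
    is_last = i == len(stages) - 1            # last stage uses Identity, no halving
    shape = (n, c_out, hgt if is_last else hgt // 2, wid if is_last else wid // 2)
    print(f"stage {i + 1}: {shape}")
# stage 1: (2, 32, 16, 16)  stage 2: (2, 64, 8, 8)  stage 3: (2, 128, 8, 8)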
-------- 1
Block1: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=20, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c130>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=20>
    (1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc070>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=64>
    (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354a90>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3d7cd0>, bias_init=None, format=NCHW>
  > Block2: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=32, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3c17f0>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=32>
    (1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3062b0>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=64>
    (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c190>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  > Residual attention: Residual<
  (fn): PreNorm<
    (fn): LinearAttention<
      (to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d90>, bias_init=None, format=NCHW>
      (to_out): SequentialCell<
        (0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc306a60>, format=NCHW>
        (1): LayerNorm<>
        >
      >
    (norm): GroupNorm<num_groups=1, num_channels=32>
    >
  > Downsample: Conv2d<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf1f0>, bias_init=None, format=NCHW>
-------- 2
Block1: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=32, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306fd0>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=32>
    (1): Conv2d<input_channels=32, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf580>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=128>
    (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf190>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=32, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf670>, bias_init=None, format=NCHW>
  > Block2: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=64, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf940>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=64>
    (1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf910>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=128>
    (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf760>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  > Residual attention: Residual<
  (fn): PreNorm<
    (fn): LinearAttention<
      (to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfa30>, bias_init=None, format=NCHW>
      (to_out): SequentialCell<
        (0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfac0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2cfb80>, format=NCHW>
        (1): LayerNorm<>
        >
      >
    (norm): GroupNorm<num_groups=1, num_channels=64>
    >
  > Downsample: Conv2d<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfdc0>, bias_init=None, format=NCHW>
-------- 3
Block1: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=64, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfe50>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=64>
    (1): Conv2d<input_channels=64, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cff40>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=256>
    (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac130>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac250>, bias_init=None, format=NCHW>
  > Block2: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=128, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac1f0>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=128>
    (1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac370>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=256>
    (4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac4c0>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  > Residual attention: Residual<
  (fn): PreNorm<
    (fn): LinearAttention<
      (to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
      (to_out): SequentialCell<
        (0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
        (1): LayerNorm<>
        >
      >
    (norm): GroupNorm<num_groups=1, num_channels=128>
    >
  > Downsample: Identity<>

functools.partial(<class '__main__.ConvNextBlock'>, mult=2)
x.shape
(2, 20, 32, 32)
t
Tensor(shape=[5], dtype=Int32, value= [0, 1, 2, 3, 4])
t = time_mlp(t)  ### now t has the expected (batch, time_dim) shape
t
Tensor(shape=[5, 128], dtype=Float32, value=
[[ 1.89717814e-01,  1.14008449e-02,  3.33061777e-02 ...  1.43985003e-01,  3.92933972e-02, -1.06829256e-02],
 [ 1.93240538e-01,  1.78442001e-02,  6.77158684e-02 ...  1.36301309e-01,  7.64560923e-02, -1.50307640e-02],
 [ 1.83035284e-01,  2.44393535e-02,  8.70461762e-02 ...  1.38745904e-01,  1.33171901e-01, -3.85175534e-02],
 [ 1.60673216e-01,  5.66724911e-02,  6.34887069e-02 ...  1.58708930e-01,  1.91956162e-01, -7.78784081e-02],
 [ 1.33557051e-01,  1.20526701e-01,  5.27272746e-03 ...  1.78237736e-01,  2.34055102e-01, -9.55123529e-02]])
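The (5,) → (5, 128) jump comes from the time_mlp pipeline: SinusoidalPositionEmbeddings(dim) first maps each scalar timestep to a dim-dimensional vector, and the two Dense layers then lift it to time_dim = 4 * dim = 128. A hedged numpy sketch of just the sinusoidal part, assuming dim=32:

import numpy as np

def sinusoidal_emb(t, dim=32):
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / (half - 1))
    args = t[:, None] * freqs[None, :]           # (batch, half)
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

print(sinusoidal_emb(np.arange(5)).shape)  # (5, 32); the Dense layers then give (5, 128)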
# Select the first row, keeping the batch axis
new_t = t[0:1, :]  # t[0:1] works just as well
new_t.shape
(1, 128)
t=new_t
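Slicing with 0:1 rather than plain 0 matters: it keeps the leading batch axis, which the Dense layer inside each block's mlp expects. A quick numpy comparison:

import numpy as np

t_demo = np.zeros((5, 128))
print(t_demo[0].shape, t_demo[0:1].shape)  # (128,) vs (1, 128)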


class ConvNextBlock(nn.Cell):
    def __init__(self, dim=20, dim_out=32, *, time_emb_dim=128, mult=2, norm=True):
        super().__init__()
        # Projects the time embedding down to `dim` channels so it can be
        # added onto the feature map
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # Depthwise 7x7 convolution (group=dim)
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        # Inverted bottleneck: dim -> dim_out * mult -> dim_out
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )

        # 1x1 shortcut when the channel count changes, identity otherwise
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            # Broadcast the (batch, dim) condition over H and W
            condition = self.mlp(time_emb)
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition

        h = self.net(h)
        return h + self.res_conv(x)
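Inside construct, the time embedding is injected by broadcasting: the (1, 128) embedding passes through the mlp to (1, 20), gains two trailing axes, and is then added to the (2, 20, 32, 32) feature map. A numpy sketch of just that broadcast (shapes taken from the BL1 example below):

import numpy as np

condition = np.zeros((1, 20))            # output of self.mlp(time_emb)
condition = condition[:, :, None, None]  # expand_dims(-1) twice -> (1, 20, 1, 1)
h_demo = np.zeros((2, 20, 32, 32))       # output of self.ds_conv(x)
print((h_demo + condition).shape)        # (2, 20, 32, 32)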
BL1=ConvNextBlock()
BL1
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=20, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffddded2b80>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=20>
    (1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffddded2a00>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=64>
    (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffddded2f70>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdde372970>, bias_init=None, format=NCHW>
  >
x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal())
x.shape  # show the data shape
x = init_conv(x)
t=new_t
x.shape,t.shape
((2, 20, 32, 32), (1, 128))
BL1(x, t)
Tensor(shape=[2, 32, 32, 32], dtype=Float32, value=
[[[[ 4.79678512e-01,  6.18093789e-01,  4.11911160e-01 ...  2.45554969e-01,  2.87007272e-01,  5.84303178e-02],
   [ 3.35714668e-01,  5.81144929e-01,  3.02089810e-01 ...  7.34893382e-02, -2.16056317e-01, -2.40196183e-01],
   [ 6.84010506e-01,  1.10967433e+00,  7.40820885e-01 ...  4.42677915e-01, -5.60586452e-02, -1.65627971e-01],
   ...
   [ 7.63499856e-01,  1.13718486e+00,  9.87868309e-01 ...  6.83883190e-01,  1.12375900e-01, -3.52420285e-02],
   [ 7.06175983e-01,  1.14337587e+00,  9.23162937e-01 ...  6.31385088e-01,  1.49176538e-01, -3.98113094e-02],
   [ 1.76896825e-01,  4.82230633e-01,  4.89957243e-01 ...  4.00457889e-01, -2.55417023e-02,  1.21514685e-01]],
  [[ 2.81717628e-01, -9.01857391e-04, -8.76490697e-02 ... -7.88121223e-02, -1.26668364e-01, -1.73759460e-01],
   [ 1.92831263e-01, -1.43926665e-01, -1.48099199e-01 ... -2.68793881e-01, -1.00307748e-01, -1.20211102e-01],
   [ 1.39493421e-01,  2.17211112e-01,  1.45897210e-01 ...  1.56540543e-01,  1.65525198e-01,  5.83062395e-02],
   ...
   [ 6.92536831e-02,  5.68503514e-02,  3.68858390e-02 ...  8.93135369e-02,  1.07637540e-01,  4.47027944e-02],
   [ 1.12979397e-01,  2.12710798e-01,  5.37276417e-02 ...  1.01731792e-01, -4.49074507e-02, -2.13617496e-02],
   [ 7.63944685e-02,  1.35763273e-01,  1.16834700e-01 ...  1.85339376e-01,  1.34029865e-01,  2.19782546e-01]],
  [[ 1.30130202e-01,  3.29991281e-01,  4.25871283e-01 ...  3.08504313e-01,  4.61269379e-01,  1.57225683e-01],
   [ 4.05929983e-01,  7.34413862e-01,  8.77515614e-01 ...  7.71579146e-01,  9.44401443e-01,  5.34572124e-01],
   [-3.62766758e-02,  3.63330275e-01,  4.08758491e-01 ...  3.54813248e-01,  5.34208298e-01,  2.87856668e-01],
   ...
   [-1.26003683e-01,  2.56464094e-01,  3.78679991e-01 ...  4.59082156e-01,  5.85478425e-01,  2.87693620e-01],
   [-4.88909632e-02,  2.99566031e-01,  3.99350137e-01 ...  4.61422205e-01,  4.17674005e-01,  7.26606846e-02],
   [-2.39267662e-01, -3.09484214e-01, -2.71493912e-01 ... -7.24366307e-02, -1.12498447e-01, -1.38472736e-01]],
  ...
  [[-1.59212112e-01,  7.95098245e-02, -1.85586754e-02 ... -2.23550811e-01, -2.70033002e-01, -2.44036630e-01],
   [-1.60980105e-01,  4.35432374e-01,  5.99099815e-01 ...  4.49325353e-01,  4.23938036e-01,  3.12254220e-01],
   [-3.05827290e-01,  7.60348588e-02,  2.39996284e-01 ...  2.72248350e-02, -1.65684037e-02, -1.06293596e-01],
   ...
   [-2.98552692e-01,  5.39370701e-02,  2.43080735e-01 ...  1.28992423e-01,  6.57526404e-02, -7.48077184e-02],
   [-2.99177408e-01, -1.83835357e-01, -3.95203456e-02 ... -4.43453342e-02, -1.39597371e-01, -2.18513533e-01],
   [ 1.12659251e-02,  2.71181725e-02,  8.25900063e-02 ...  1.92151055e-01,  2.09751755e-01,  4.28373404e-02]],
  [[-1.12641305e-01,  2.13623658e-01,  7.53313601e-02 ... -1.21324155e-02, -1.53158829e-01, -4.77597833e-01],
   [-3.84328097e-01,  6.15204051e-02,  1.25263743e-02 ... -7.10279793e-02, -2.77535737e-01, -3.76104653e-01],
   [-4.07613993e-01,  1.12187594e-01,  3.72527242e-02 ... -2.06387043e-02, -1.46990225e-01, -2.87585199e-01],
   ...
   [-3.61820847e-01,  9.99522805e-02, -9.61808674e-03 ... -4.89163473e-02, -1.65467933e-01, -3.17837149e-01],
   [-6.31257832e-01, -2.93515027e-01, -3.12220454e-01 ... -1.29600003e-01, -1.40924498e-01, -1.52635127e-01],
   [-5.44437885e-01, -2.92856932e-01, -2.64693975e-01 ... -1.66876107e-01, -1.02364674e-01,  1.51942633e-02]],
  [[-4.31133687e-01, -6.47886336e-01, -8.31129193e-01 ... -8.14953923e-01, -8.53494108e-01, -5.33654928e-01],
   [-1.00464332e+00, -1.09428477e+00, -1.38921976e+00 ... -1.53568864e+00, -1.31930661e+00, -4.56116676e-01],
   [-1.13990593e+00, -1.03481460e+00, -1.49022770e+00 ... -1.62232399e+00, -1.40410137e+00, -5.24185598e-01],
   ...
   [-1.26637757e+00, -1.16348124e+00, -1.51403701e+00 ... -1.62057757e+00, -1.49686539e+00, -5.62209725e-01],
   [-1.11411476e+00, -9.28978443e-01, -1.17068827e+00 ... -1.24776149e+00, -1.03560197e+00, -3.00330132e-01],
   [-8.52251232e-01, -7.45468676e-01, -9.34467912e-01 ... -1.00492966e+00, -8.28293085e-01, -2.36020580e-01]]],
 [[[ 4.66426373e-01,  6.25464499e-01,  3.93321365e-01 ...  2.30096430e-01,  3.03178400e-01,  5.51086180e-02],
   [ 3.32810491e-01,  6.17641032e-01,  3.06311995e-01 ...  1.02875866e-01, -1.90033853e-01, -2.67078996e-01],
   [ 6.84333742e-01,  1.09874713e+00,  7.69674480e-01 ...  4.62564647e-01, -6.67065307e-02, -2.36097589e-01],
   ...
   [ 7.33127177e-01,  1.17721725e+00,  9.89053011e-01 ...  7.15648472e-01,  1.17136240e-01, -4.71793935e-02],
   [ 6.87817514e-01,  1.09633350e+00,  8.85757685e-01 ...  5.86604059e-01,  9.45525989e-02, -3.36224921e-02],
   [ 1.95047542e-01,  4.68445808e-01,  4.64000225e-01 ...  3.75145793e-01,  1.80484354e-03,  1.13696203e-01]],
  [[ 2.65101492e-01,  2.46687196e-02, -1.07584536e-01 ... -1.03970490e-01, -8.17846432e-02, -1.53097644e-01],
   [ 1.52385280e-01, -8.83764774e-02, -1.62100300e-01 ... -2.18001902e-01, -9.41001922e-02, -1.19305871e-01],
   [ 1.61403462e-01,  2.30408147e-01,  1.57331020e-01 ...  1.96940184e-01,  1.30461589e-01,  6.52605519e-02],
   ...
   [ 5.33300415e-02,  1.22396260e-01, -1.88096687e-02 ...  1.05915010e-01,  1.53571054e-01,  2.45359484e-02],
   [ 9.76279452e-02,  1.82655305e-01,  1.09691672e-01 ...  1.40925452e-01,  8.01324844e-04,  1.88996084e-03],
   [ 7.78941736e-02,  1.48156703e-01,  1.39126211e-01 ...  2.34847367e-01,  1.08238310e-01,  2.09336147e-01]],
  [[ 1.17656320e-01,  3.43433738e-01,  4.39827234e-01 ...  3.09850901e-01,  4.53984410e-01,  1.49862975e-01],
   [ 3.95844698e-01,  7.22729802e-01,  8.56524229e-01 ...  8.03788304e-01,  9.29986835e-01,  5.35356879e-01],
   [-5.63383885e-02,  3.74548256e-01,  3.96855712e-01 ...  3.82491171e-01,  5.69522381e-01,  3.37262630e-01],
   ...
   [-1.52335018e-01,  2.06072912e-01,  3.62504959e-01 ...  4.47174758e-01,  5.80285132e-01,  2.62391001e-01],
   [-3.93561423e-02,  3.46813500e-01,  3.57039988e-01 ...  4.35864031e-01,  4.56840277e-01,  7.46745914e-02],
   [-2.53383636e-01, -2.85494059e-01, -2.50772089e-01 ... -1.19937055e-01, -9.26215500e-02, -1.42144099e-01]],
  ...
  [[-1.72060549e-01,  7.35142902e-02,  1.05164722e-02 ... -2.21164092e-01, -2.66059518e-01, -2.46203467e-01],
   [-1.51937515e-01,  4.78927851e-01,  5.69894075e-01 ...  4.43681806e-01,  4.60492224e-01,  2.69292653e-01],
   [-3.13104421e-01,  1.40309155e-01,  2.25837916e-01 ...  3.81425694e-02,  7.96409026e-02, -1.00592285e-01],
   ...
   [-3.12582165e-01,  3.55643854e-02,  1.99504092e-01 ...  1.73697099e-01,  5.84969185e-02, -7.21051544e-02],
   [-2.97579527e-01, -1.40844122e-01, -5.89616746e-02 ... -2.86213737e-02, -1.22039340e-01, -2.13227138e-01],
   [ 3.26608960e-03,  4.80151782e-03,  5.54511398e-02 ...  1.92409024e-01,  1.99357480e-01,  3.34331095e-02]],
  [[-1.12870112e-01,  2.09549189e-01,  9.65655223e-02 ... -4.27889600e-02, -1.44986391e-01, -4.36559677e-01],
   [-3.58288437e-01,  9.11962241e-02, -8.71371478e-04 ... -4.59572896e-02, -2.30747938e-01, -4.01585191e-01],
   [-3.93885374e-01,  1.43090501e-01, -1.07427724e-02 ... -2.74238884e-02, -1.51127338e-01, -3.17271531e-01],
   ...
   [-3.39975446e-01,  6.32377267e-02, -4.78150323e-02 ... -9.10668075e-02, -1.39780402e-01, -2.90815294e-01],
   [-6.53750360e-01, -2.34141201e-01, -2.83103675e-01 ... -9.91634205e-02, -1.61574319e-01, -1.63588241e-01],
   [-5.44618607e-01, -2.84289837e-01, -2.75803030e-01 ... -1.71445966e-01, -1.29518926e-01,  9.64298844e-04]],
  [[-4.26712006e-01, -6.10979497e-01, -8.42772007e-01 ... -8.18627119e-01, -8.19367886e-01, -5.47874212e-01],
   [-1.00363958e+00, -1.09864676e+00, -1.45736146e+00 ... -1.57554007e+00, -1.27784061e+00, -5.02054691e-01],
   [-1.13520634e+00, -1.06120992e+00, -1.46519005e+00 ... -1.58347833e+00, -1.37640107e+00, -5.30683100e-01],
   ...
   [-1.25580835e+00, -1.18095815e+00, -1.52926147e+00 ... -1.63596940e+00, -1.44870484e+00, -5.40451765e-01],
   [-1.12258959e+00, -9.22193348e-01, -1.14803100e+00 ... -1.24472821e+00, -1.02205288e+00, -3.08310509e-01],
   [-8.60808074e-01, -7.58719683e-01, -9.61336493e-01 ... -9.92724419e-01, -8.70862603e-01, -2.35298872e-01]]]])
x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal())
x.shape  # show the data shape
x = init_conv(x)
t = new_t
x.shape, t.shape  ## these statements must live in one cell to run correctly

h = []  # skip-connection features, one per down stage
for block1, block2, attn, downsample in downs:
    ### the loop runs 3 times; each stage contributes these four cells
    x = block1(x, t)
    x = block2(x, t)
    x = attn(x)
    h.append(x)

    x = downsample(x)
x.shape
(2, 128, 8, 8)
len(h)
3
h  ## [x0: (2, 32, 32, 32), x1: (2, 64, 16, 16), x2: (2, 128, 8, 8)]
[Tensor(shape=[2, 32, 32, 32], dtype=Float32, value=
 Tensor(shape=[2, 64, 16, 16], dtype=Float32, value=
 Tensor(shape=[2, 128, 8, 8], dtype=Float32, value=
len_h = len(h) - 1
len_h
2
h[len_h].shape, x.shape  ## the last downsample is Identity, so the shape is unchanged
((2, 128, 8, 8), (2, 128, 8, 8))
ops.concat((x, h[len_h]), 1)
Tensor(shape=[2, 256, 8, 8], dtype=Float32, 

This snippet uses the ops.concat function from the MindSpore framework to perform tensor concatenation. A breakdown of the call:

  • ops.concat: a function in MindSpore's operator library that concatenates a list of tensors along a specified axis. Here "ops" is short for operations and exposes the framework's math and array primitives.

  • (x, h[len_h]): a tuple of the two tensors to be concatenated. x is the current feature map, and h[len_h] picks the element of h at index len_h; since len_h = len(h) - 1 here, that is the last stored skip feature.

  • , 1): the axis along which to concatenate. Axes are 0-indexed, so 1 is the second dimension, i.e. the channel dimension in NCHW layout. x and h[len_h] are joined channel-wise, producing a tensor whose channel count is the sum of the two inputs' channels.

In short, this concatenates x with the most recent skip feature along the channel axis: exactly the skip-connection concatenation a U-Net performs at the start of each upsampling stage, combining encoder features with decoder features.
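A minimal, self-contained check of that channel-wise concat (shapes match the run above):

import numpy as np
import mindspore as ms
from mindspore import ops

a = ms.Tensor(np.zeros((2, 128, 8, 8)), ms.float32)
b = ms.Tensor(np.zeros((2, 128, 8, 8)), ms.float32)
print(ops.concat((a, b), 1).shape)  # (2, 256, 8, 8)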

x = mid_block1(x, t)
x = mid_attn(x)
x = mid_block2(x, t)
x.shape
(2, 128, 8, 8)
len_h = len(h) - 1
len_h
2
for i, (block1, block2, attn, upsample) in enumerate(ups, start=1):  ## ups has only 2 elements
    print("--------", i)
    print("Block1:", block1, "Block2:", block2, "Residual attention:", attn, "Upsample:", upsample)
    
    
                        # block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        # block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        # Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        # Upsample(dim_in) if not is_last else nn.Identity(),
    
    
    
    
    
    
### Up stage 1, Block1
#(res_conv): Conv2d<input_channels=256, output_channels=64  (128,64)

### Up stage 1, Block2
#(4): Conv2d<input_channels=128 [64*2], output_channels=64, (res_conv): Identity<>  (128,64)

### Up stage 1, residual attention
#Conv2d<input_channels=128, output_channels=64,  (128,64)

### Up stage 1, upsample
#Conv2dTranspose<input_channels=64, output_channels=64, (64,64)

### Up stage 2, Block1
#(res_conv): Conv2d<input_channels=128, output_channels=32  (64,32)

### Up stage 2, Block2
# Conv2d<input_channels=64 [32*2], output_channels=32,  (res_conv): Identity<>  (64,32)

### Up stage 2, residual attention
#Conv2d<input_channels=128, output_channels=32, (64,32)

### Up stage 2, upsample
#  Conv2dTranspose<input_channels=32, output_channels=32,  (32,32)
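Mirroring the down path, here is a hedged sketch of the shapes through the up path, starting from the bottleneck output (2, 128, 8, 8) and concatenating a stored skip before each Block1:

shape = (2, 128, 8, 8)                      # after the mid blocks
skips = [(2, 128, 8, 8), (2, 64, 16, 16)]   # entries of h, consumed back to front
for (c_out, c_in), skip in zip([(128, 64), (64, 32)], skips):
    n, c, hgt, wid = shape
    c_cat = c + skip[1]                     # ops.concat along channels: 256, then 128
    shape = (n, c_in, hgt * 2, wid * 2)     # blocks reduce c_cat -> c_in; Conv2dTranspose doubles H, W
    print(c_cat, "->", shape)
# 256 -> (2, 64, 16, 16), then 128 -> (2, 32, 32, 32)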
-------- 1
Block1: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=256, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=256, output_channels=256, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=256, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2b80>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=256>
    (1): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee15430>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=128>
    (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee15d30>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=256, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb24c0>, bias_init=None, format=NCHW>
  > Block2: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=64, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbcd00>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=64>
    (1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc550>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=128>
    (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c4f0>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  > Residual attention: Residual<
  (fn): PreNorm<
    (fn): LinearAttention<
      (to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354580>, bias_init=None, format=NCHW>
      (to_out): SequentialCell<
        (0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cc10>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc30c9a0>, format=NCHW>
        (1): LayerNorm<>
        >
      >
    (norm): GroupNorm<num_groups=1, num_channels=64>
    >
  > UP upsample: Conv2dTranspose<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cac0>, bias_init=None, format=NCHW>
-------- 2
BL1 block 1: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=128, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c2e0>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=128>
    (1): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c5e0>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=64>
    (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17f10>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17760>, bias_init=None, format=NCHW>
  > BL2 block 2: ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=32, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17a00>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=32>
    (1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17580>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=64>
    (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17dc0>, bias_init=None, format=NCHW>
    >
  (res_conv): Identity<>
  > ATT residual attention: Residual<
  (fn): PreNorm<
    (fn): LinearAttention<
      (to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3064f0>, bias_init=None, format=NCHW>
      (to_out): SequentialCell<
        (0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306b50>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc3065b0>, format=NCHW>
        (1): LayerNorm<>
        >
      >
    (norm): GroupNorm<num_groups=1, num_channels=32>
    >
  > UP upsample: Conv2dTranspose<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306f70>, bias_init=None, format=NCHW>
class ConvNextBlock(nn.Cell):
    def __init__(self, dim=256, dim_out=64, *, time_emb_dim=128, mult=2, norm=True):
        super().__init__()
        # Projects the time embedding onto the input channels.
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # 7x7 depthwise convolution (group=dim), the ConvNeXt spatial mixer.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        # Inverted bottleneck: expand to dim_out * mult, then project down to dim_out.
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )

        # 1x1 projection so the residual branch matches dim_out when channels change.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            # Broadcast the (batch, dim) time embedding over H and W.
            condition = self.mlp(time_emb)
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition

        h = self.net(h)
        return h + self.res_conv(x)
BL1 = ConvNextBlock()
BL1
ConvNextBlock<
  (mlp): SequentialCell<
    (0): GELU<>
    (1): Dense<input_channels=128, output_channels=256, has_bias=True>
    >
  (ds_conv): Conv2d<input_channels=256, output_channels=256, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=256, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdddd94880>, bias_init=None, format=NCHW>
  (net): SequentialCell<
    (0): GroupNorm<num_groups=1, num_channels=256>
    (1): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdb504af10>, bias_init=None, format=NCHW>
    (2): GELU<>
    (3): GroupNorm<num_groups=1, num_channels=128>
    (4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdb53c3d60>, bias_init=None, format=NCHW>
    >
  (res_conv): Conv2d<input_channels=256, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdcc3e9970>, bias_init=None, format=NCHW>
  >
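A quick forward-pass smoke test on BL1 (my addition; it assumes mindspore is available as ms and the functional ops module is imported, as elsewhere in this post):

import mindspore as ms
from mindspore import ops

x = ops.ones((2, 256, 16, 16), ms.float32)   # (batch, dim, H, W)
t_emb = ops.ones((2, 128), ms.float32)       # (batch, time_emb_dim)
print(BL1(x, t_emb).shape)                   # expected (2, 64, 16, 16): res_conv projects 256 -> 64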

Stepping through the up path by hand (REPL style, reusing the ups, x, h, and len_h objects built earlier in the notebook):

for block1, block2, attn, upsample in ups:
    x = ops.concat((x, h[len_h]), 1)
    len_h -= 1
    x = block1(x, t)
    x = block2(x, t)
    x = attn(x)

    x = upsample(x)
x.shape
(2, 32, 32, 32)
rx=final_conv(x)
rx.shape
(2, 3, 32, 32)
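An end-to-end shape check (a sketch under my assumptions: the Unet class and its helpers are defined as earlier in this post, and dim=32 with dim_mults=(1, 2, 4) reproduces the 20/32/64/128 channel trace, since init_dim = 32 // 3 * 2 = 20):

import mindspore as ms
from mindspore import ops

unet = Unet(dim=32, channels=3, dim_mults=(1, 2, 4))
img = ops.ones((2, 3, 32, 32), ms.float32)   # a batch of 2 noisy 32x32 RGB images
time = ops.ones((2,), ms.float32)            # one noise level per image
out = unet(img, time)
print(out.shape)                             # expected (2, 3, 32, 32), matching rx.shape above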
    def construct(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:  # ups has 2 stages while downs has 3; this loop runs twice and never consumes h[0], the output of the very first (20, 32) down stage
            x = ops.concat((x, h[len_h]), 1)  # the skip connection: concatenate the matching encoder feature map along the channel axis (concatenation, not residual addition)
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
        return self.final_conv(x)
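Bookkeeping sketch (my annotation) of which encoder features the up path consumes, for dims = [20, 32, 64, 128]:

h_channels = [32, 64, 128]                      # channels of h[0], h[1], h[2] from the 3 down stages
len_h = len(h_channels) - 1
for dim_in, dim_out in [(64, 128), (32, 64)]:   # reversed(in_out[1:]), i.e. the 2 up stages
    print(f"concat x ({dim_out}ch) + h[{len_h}] ({h_channels[len_h]}ch) -> {dim_out * 2}ch -> block1 -> {dim_in}ch")
    len_h -= 1
# h[0] (32ch, from the first (20, 32) down stage) is never concatenated,
# exactly as the comment in construct notes.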

This code defines a model based on the U-Net architecture, used mainly in image processing, image generation, and segmentation tasks, especially where fine detail must be preserved while contextual features are also captured. The model achieves this with an encoder-decoder structure combined with skip connections. A part-by-part walkthrough:

Initialization and time embedding

  • self.init_conv(x): applies the initial convolution to the input x to begin feature extraction.
  • t = self.time_mlp(time): if the model includes time-dependent processing (common for sequential data, or for injecting the diffusion timestep as a condition in generative models), the time signal time is passed through an MLP to obtain the time embedding t (a small shape check follows below).
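A small check of the time-embedding shape (my sketch; it reuses the hypothetical unet instance from the shape check above, where dim=32 gives time_dim = dim * 4 = 128):

import mindspore as ms
from mindspore import ops

time = ops.ones((2,), ms.float32)   # a batch of 2 timesteps
t = unet.time_mlp(time)             # SinusoidalPositionEmbeddings -> Dense -> GELU -> Dense
print(t.shape)                      # expected (2, 128), broadcast later into every ConvNextBlock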

Encoder path (downsampling)

  • Loop over the modules in self.downs (each stage holds two conv blocks block1, block2, an attention module attn, and a downsample module downsample):
    • The two conv blocks transform the features, optionally injecting the time embedding t.
    • The attention module attn enriches the feature representation.
    • The current feature map x is appended to the list h, to be reused later as a skip connection.
    • The downsample module halves the spatial size while the channel depth grows.

Middle block (bottleneck)

  • A sequence of middle layers, two conv blocks with an attention module in between, further refines the features.

Decoder path (upsampling)

  • Loop over the modules in self.ups, mirroring the encoder but with upsampling:
    • ops.concat((x, h[len_h]), 1): the key skip-connection step. The current decoder output x is concatenated with the matching encoder feature map h[len_h] along the channel axis (axis 1), carrying local detail across.
    • len_h is decremented so the next iteration picks up the feature map one level higher.
    • Two conv blocks, an attention module, and an upsample operation then progressively restore the spatial size while integrating information.

Output

  • Finally, self.final_conv(x) applies the last convolution stage to produce the output feature map, here a direct pixel-level prediction.

What the skip connections do

  • Skip connections (realized here as feature-map concatenation) mitigate vanishing gradients and let the network learn detail more effectively.
  • They let low-level features (which retain more detail) fuse with high-level features (which carry more context) during decoding, which is crucial for recovering the fine structure of the input, especially in image generation and segmentation.
  • This design therefore helps both the precise reconstruction of input detail and the fidelity and sharpness of the generated content. A minimal concat sketch follows.
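A minimal sketch of that concat step (my example, with shapes taken from up stage 2 of this network):

import mindspore as ms
from mindspore import ops

x = ops.zeros((2, 64, 16, 16), ms.float32)     # decoder feature entering up stage 2
skip = ops.zeros((2, 64, 16, 16), ms.float32)  # h[1], the matching encoder feature
merged = ops.concat((x, skip), 1)              # concatenate along the channel axis
print(merged.shape)                            # (2, 128, 16, 16): channels double, which is
                                               # why block1 takes dim_out * 2 input channels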
final_conv  ## maps the features back to 3 channels
SequentialCell<
  (0): ConvNextBlock<
    (ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2acb50>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=32>
      (1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17bb0>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=64>
      (4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2100>, bias_init=None, format=NCHW>
      >
    (res_conv): Identity<>
    >
  (1): Conv2d<input_channels=32, output_channels=3, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2eb0>, bias_init=None, format=NCHW>
  >


