TensorFlow + CNN Garbage Classification: An End-to-End Deep Learning Tutorial

Published 2025-06-21

A concise walkthrough of the pain points that TensorFlow + convolutional neural networks (CNNs) solve

Project Overview

Garbage sorting is a key part of sustainable development. Using TensorFlow and a classic convolutional neural network (CNN) example, this tutorial takes you through the whole pipeline, from environment setup to single-image inference: no lengthy background, only the key steps, so you can quickly build an efficient and interpretable automated classification system. If you want to reproduce my environment with a single command, set up Conda first; if you have questions, see my earlier article 👉 Getting Started with Conda from Scratch: Installation, Creating Environments, and Managing Dependencies.

  1. Environment management: one-command reproduction with environment.yml
  2. Dataset preparation: download link and directory structure
  3. Data cleaning: automatically remove corrupted images
  4. Data augmentation: improve model robustness
  5. Model construction and training: the CNN architecture in detail
  6. Visualizing the training process: loss/accuracy curves
  7. Single-image inference: real-time classification and interpretability analysis
  8. CNN vs. Transformer: an architecture selection guide

1. Environment Management

Create a file named environment.yml in the project root. Example contents:

name: tf_gpu
channels:
  - defaults
  - conda-forge
dependencies:
  - _openmp_mutex=4.5
  - blas=1.0
  - brotli-python=1.0.9
  - bzip2=1.0.8
  - ca-certificates=2025.4.26
  - contourpy=1.3.1
  - cudatoolkit=11.2.2
  - cudnn=8.1.0.77
  - cycler=0.11.0
  - expat=2.7.1
  - fonttools=4.55.3
  - freetype=2.13.3
  - glib=2.84.0
  - glib-tools=2.84.0
  - gst-plugins-base=1.24.7
  - gstreamer=1.24.7
  - icc_rt=2022.1.0
  - icu=75.1
  - intel-openmp=2023.2.0
  - joblib=1.4.2
  - kiwisolver=1.4.8
  - krb5=1.21.3
  - lcms2=2.17
  - lerc=4.0.0
  - libblas=3.9.0
  - libcblas=3.9.0
  - libclang13=20.1.7
  - libdeflate=1.24
  - libffi=3.4.4
  - libfreetype=2.13.3
  - libfreetype6=2.13.3
  - libgcc=15.1.0
  - libglib=2.84.0
  - libgomp=15.1.0
  - libhwloc=2.11.2
  - libiconv=1.18
  - libintl=0.22.5
  - libintl-devel=0.22.5
  - libjpeg-turbo=3.1.0
  - liblapack=3.9.0
  - liblzma=5.8.1
  - liblzma-devel=5.8.1
  - libogg=1.3.5
  - libpng=1.6.47
  - libsqlite=3.50.1
  - libtiff=4.7.0
  - libvorbis=1.3.7
  - libwebp-base=1.5.0
  - libwinpthread=12.0.0.r4.gg4f2fc60ca
  - libxcb=1.17.0
  - libxml2=2.13.8
  - libzlib=1.3.1
  - matplotlib=3.10.0
  - matplotlib-base=3.10.0
  - mkl=2023.2.0
  - mkl-service=2.4.1
  - openjpeg=2.5.3
  - openssl=3.5.0
  - pcre2=10.44
  - pillow=11.2.1
  - pip=25.1
  - ply=3.11
  - pthread-stubs=0.4
  - pyparsing=3.2.0
  - pyqt=5.15.10
  - pyqt5-sip=12.13.0
  - python=3.10.16
  - python-dateutil=2.9.0post0
  - python_abi=3.10
  - qt-main=5.15.15
  - scikit-learn=1.6.1
  - setuptools=78.1.1
  - sip=6.7.12
  - six=1.17.0
  - sqlite=3.45.3
  - tbb=2021.13.0
  - threadpoolctl=3.5.0
  - tk=8.6.13
  - tomli=2.0.1
  - tornado=6.5.1
  - tzdata=2025b
  - ucrt=10.0.22621.0
  - unicodedata2=15.1.0
  - vc=14.42
  - vc14_runtime=14.42.34438
  - vs2015_runtime=14.42.34438
  - wheel=0.45.1
  - xorg-libxau=1.0.12
  - xorg-libxdmcp=1.1.5
  - xz=5.8.1
  - xz-tools=5.8.1
  - zlib=1.3.1
  - zstd=1.5.7
  - pip:
      - absl-py==2.3.0
      - astunparse==1.6.3
      - cachetools==5.5.2
      - certifi==2025.4.26
      - charset-normalizer==3.4.2
      - flatbuffers==25.2.10
      - gast==0.4.0
      - google-auth==2.40.3
      - google-auth-oauthlib==0.4.6
      - google-pasta==0.2.0
      - grpcio==1.73.0
      - h5py==3.14.0
      - idna==3.10
      - keras==2.10.0
      - keras-preprocessing==1.1.2
      - libclang==18.1.1
      - markdown==3.8
      - markupsafe==3.0.2
      - numpy==1.23.5
      - oauthlib==3.2.2
      - opt-einsum==3.4.0
      - packaging==25.0
      - protobuf==3.19.6
      - pyasn1==0.6.1
      - pyasn1-modules==0.4.2
      - requests==2.32.4
      - requests-oauthlib==2.0.0
      - rsa==4.9.1
      - scipy==1.15.3
      - tensorboard==2.10.1
      - tensorboard-data-server==0.6.1
      - tensorboard-plugin-wit==1.8.1
      - tensorflow==2.10.0
      - tensorflow-estimator==2.10.0
      - tensorflow-io-gcs-filesystem==0.31.0
      - termcolor==3.1.0
      - typing-extensions==4.14.0
      - urllib3==2.4.0
      - werkzeug==3.1.3
      - wrapt==1.17.2
prefix: C:\Users\Wenhao\.conda\envs\tf_gpu
  • Create and activate the environment with one command

    conda env create -f environment.yml
    conda activate tf_gpu
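
After activating the environment, it is worth confirming that this TensorFlow build can actually see your GPU before going further. A quick sanity check from the command line (no project code required):

    python -c "import tensorflow as tf; print(tf.__version__, tf.config.list_physical_devices('GPU'))"

If the printed list is empty, TensorFlow will silently fall back to the CPU; the usual causes are a missing cudatoolkit/cudnn installation or a GPU driver that is too old for CUDA 11.2.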
    

2. Dataset Preparation

2.1 Download and Extraction
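
The dataset is distributed as a compressed archive. After downloading it, extract it so that the four class folders land under dataset/ as shown in 2.2. A minimal helper sketch, assuming the archive was saved locally as garbage_dataset.zip (a placeholder name; substitute whatever the download gives you):

# prepare_dataset.py (optional helper)
import zipfile
from pathlib import Path

ARCHIVE = "garbage_dataset.zip"   # placeholder; replace with your downloaded archive
TARGET  = Path("dataset")

TARGET.mkdir(exist_ok=True)
with zipfile.ZipFile(ARCHIVE) as zf:
    zf.extractall(TARGET)          # unpack the class folders into dataset/
print("Extracted to", TARGET.resolve())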

2.2 Directory Structure

project-root/
├── dataset/
│   ├── Harmful/       # hazardous waste
│   ├── Kitchen/       # kitchen / food waste
│   ├── Other/         # other waste
│   └── Recyclable/    # recyclable waste
├── clean_data.py
├── train.py
├── visualize.py
├── predict.py
└── environment.yml
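
Before cleaning and training, it is worth checking how many images each class folder actually contains, since a heavily skewed class distribution will bias the classifier. A small sketch that counts files per class (it assumes the directory layout above):

# count_classes.py (optional sanity check)
import os

DATA_DIR = "dataset/"
for cls in sorted(os.listdir(DATA_DIR)):
    cls_dir = os.path.join(DATA_DIR, cls)
    if os.path.isdir(cls_dir):
        print(f"{cls:<12}: {len(os.listdir(cls_dir))} images")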

3. Data Cleaning

3.1 Purpose

  • Automatically remove images that cannot be opened or are truncated, so that training is not interrupted.

3.2 Implementation

# clean_data.py
import os
from PIL import Image, ImageFile

# Allow Pillow to load truncated images instead of raising immediately
ImageFile.LOAD_TRUNCATED_IMAGES = True
DATA_DIR = "dataset/"
bad_images = []

# Walk the dataset and try to verify every image file
for root, _, files in os.walk(DATA_DIR):
    for fname in files:
        path = os.path.join(root, fname)
        try:
            with Image.open(path) as img:
                img.verify()
        except Exception:
            bad_images.append(path)

if bad_images:
    print(f"Removing {len(bad_images)} corrupted images:")
    for p in bad_images:
        os.remove(p)
        print("  ✔", p)
else:
    print("✅ No corrupted images found")

Run it from the project root:

python clean_data.py

4. Data Augmentation

4.1 Augmentation Techniques

  • Geometric transforms: rotation, translation, shear, zoom
  • Color transforms: brightness and channel jitter
  • Flipping and filling: horizontal flip + border reflection

4.2 Code Example

# train.py — data generator section
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32

datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=10,
    zoom_range=0.2,
    brightness_range=[0.8,1.2],
    channel_shift_range=15,
    horizontal_flip=True,
    fill_mode='reflect'
)

train_gen = datagen.flow_from_directory(
    "dataset/",
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training'
)
val_gen = datagen.flow_from_directory(
    "dataset/",
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation'
)
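
Note that because a single ImageDataGenerator serves both subsets, the random augmentations above are applied to the validation images as well, which makes the validation metrics a little noisy. A common refinement (a sketch, not part of the original script) is to keep augmentation on the training generator only and use a rescale-only generator with the same split for validation:

# Optional refinement: augment the training subset only
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    rescale=1./255, validation_split=0.2,
    rotation_range=20, width_shift_range=0.1, height_shift_range=0.1,
    zoom_range=0.2, horizontal_flip=True
)
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_gen = train_datagen.flow_from_directory(
    "dataset/", target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
    class_mode='categorical', subset='training'
)
val_gen = val_datagen.flow_from_directory(
    "dataset/", target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
    class_mode='categorical', subset='validation', shuffle=False
)

Both generators use the same validation_split, and flow_from_directory partitions the files per class deterministically, so the two subsets do not overlap.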

5. Model Construction and Training

5.1 Model Architecture

  1. Convolution + pooling layers: extract features at multiple levels
  2. Batch normalization: stabilizes and speeds up training
  3. Global average pooling: few parameters, helps prevent overfitting (see the parameter-count sketch below)
  4. Fully connected layer + Dropout: classification output
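
To make point 3 concrete, the sketch below (illustrative only) builds the same convolutional stack twice and swaps just the head, then prints the parameter counts of a GlobalAveragePooling2D head versus a Flatten head:

# head_comparison.py — optional parameter-count comparison
from tensorflow.keras import Sequential
from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, GlobalAveragePooling2D, Flatten, Dense
)

def build(head_layer):
    return Sequential([
        Conv2D(32, 3, activation='relu', input_shape=(128, 128, 3)),
        MaxPooling2D(),
        Conv2D(64, 3, activation='relu'),
        MaxPooling2D(),
        Conv2D(128, 3, activation='relu'),
        MaxPooling2D(),
        head_layer,
        Dense(128, activation='relu'),
        Dense(4, activation='softmax'),   # 4 classes as in this project
    ])

print("GAP head    :", f"{build(GlobalAveragePooling2D()).count_params():,}", "parameters")
print("Flatten head:", f"{build(Flatten()).count_params():,}", "parameters")

The Flatten head has to connect every unit of the last feature map to the dense layer, so it carries millions of extra weights that the GAP head avoids.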

5.2 Training Script

# train.py
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, MaxPooling2D,
    GlobalAveragePooling2D, Dense, Dropout
)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

print("✅ GPU:", tf.config.list_physical_devices('GPU'))

num_classes = train_gen.num_classes
model = Sequential([
    Conv2D(32,3,activation='relu',input_shape=(128,128,3)),
    BatchNormalization(), MaxPooling2D(),

    Conv2D(64,3,activation='relu'),
    BatchNormalization(), MaxPooling2D(),

    Conv2D(128,3,activation='relu'),
    BatchNormalization(), MaxPooling2D(),

    GlobalAveragePooling2D(),
    Dense(128,activation='relu'),
    Dropout(0.5),
    Dense(num_classes,activation='softmax'),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
model.summary()

callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
]

history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=50,
    callbacks=callbacks
)

model.save("custom_garbage_classifier.h5")
print("✅ 模型保存至 custom_garbage_classifier.h5")

6. Visualizing the Training Process

# visualize.py
# Note: `history` is the object returned by model.fit() in train.py;
# run this code in the same session (or see the sketch below for saving
# the history to disk and loading it back here).
import matplotlib.pyplot as plt

# Loss curves
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.title("Loss Curves")
plt.legend()
plt.show()

# Accuracy curves
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='val_acc')
plt.title("Accuracy Curves")
plt.legend()
plt.show()
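
Because history only exists in the Python session that called model.fit(), visualize.py cannot run on its own as written. A small workaround (a sketch; the file name history.json is my own choice) is to dump history.history to JSON at the end of train.py and load it back here:

# At the end of train.py: persist the training curves
import json
with open("history.json", "w") as f:
    json.dump({k: [float(v) for v in vals] for k, vals in history.history.items()}, f)

# At the top of visualize.py: load them back
import json
with open("history.json") as f:
    hist = json.load(f)
# then plot hist['loss'], hist['val_loss'], hist['accuracy'], hist['val_accuracy'] as above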

Complete Training Code

import os
from PIL import Image, ImageFile

# Allow Pillow to load truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Dataset path
DATA_DIR = "dataset/"

# Step 1: automatically remove all corrupted or truncated images
bad_images = []
for root, _, files in os.walk(DATA_DIR):
    for fname in files:
        path = os.path.join(root, fname)
        try:
            with Image.open(path) as img:
                img.verify()
        except Exception:
            bad_images.append(path)

if bad_images:
    print(f"Found {len(bad_images)} bad images. Removing…")
    for p in bad_images:
        os.remove(p)
        print("  Removed", p)
else:
    print("No corrupted images found.")

# —— Training script below —— #

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

# Check whether a GPU is available
print("✅ GPU devices:", tf.config.list_physical_devices('GPU'))

# Hyperparameters
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 15

# Data augmentation + preprocessing
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True
)

train_gen = datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training'
)

val_gen = datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation'
)

# Custom CNN architecture
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
    MaxPooling2D(2, 2),

    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),

    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),

    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(train_gen.num_classes, activation='softmax')
])

# Compile the model
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Print the model summary
model.summary()

# Add EarlyStopping to prevent overfitting
from tensorflow.keras.callbacks import EarlyStopping
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Train the model
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=[early_stop]
)

# Save the model
model.save("custom_garbage_classifier.h5")
print("✅ Training finished; model saved as custom_garbage_classifier.h5")

Training screenshot


7. Single-Image Inference and Explainable AI

# predict.py
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

model = load_model("custom_garbage_classifier.h5")
img_path = "evalImageSet/5.jpg"
IMG_SIZE = (128, 128)

img = load_img(img_path, target_size=IMG_SIZE)
x   = img_to_array(img)/255.0
x   = np.expand_dims(x,0)

probs     = model.predict(x)[0]
class_idx = np.argmax(probs)

# Class mapping — should match train_gen.class_indices from training
class_indices = {'Harmful':0,'Kitchen':1,'Other':2,'Recyclable':3}
labels = {v:k for k,v in class_indices.items()}

print(f"▶ {img_path}: {labels[class_idx]} ({probs[class_idx]:.1%})")
print("Per-class probabilities:")
for i,p in enumerate(probs):
    print(f"  {labels[i]:<12}: {p:.2%}")

plt.imshow(img)
plt.title(f"{labels[class_idx]} ({probs[class_idx]:.1%})")
plt.axis('off')
plt.show()
  • Optional: use Grad-CAM to visualize the regions the model focuses on (see the sketch below).
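
A minimal Grad-CAM sketch is shown below. It is not part of predict.py; it reuses model, x, img, and IMG_SIZE defined above, picks the last Conv2D layer automatically, and overlays the class-activation heatmap on the input image:

# gradcam.py — minimal Grad-CAM sketch (optional)
import tensorflow as tf
import matplotlib.pyplot as plt

# Target the last convolutional layer for the heatmap
last_conv = [l for l in model.layers if isinstance(l, tf.keras.layers.Conv2D)][-1]
grad_model = tf.keras.models.Model(model.inputs, [last_conv.output, model.output])

with tf.GradientTape() as tape:
    conv_out, preds = grad_model(x)
    top_class = tf.argmax(preds[0])
    top_score = preds[:, top_class]

# Gradient of the winning class score w.r.t. the conv feature maps
grads   = tape.gradient(top_score, conv_out)
weights = tf.reduce_mean(grads, axis=(0, 1, 2))          # channel-wise importance
cam     = tf.reduce_sum(conv_out[0] * weights, axis=-1)   # weighted feature maps
cam     = tf.nn.relu(cam) / (tf.reduce_max(cam) + 1e-8)   # normalize to [0, 1]

plt.imshow(img)
plt.imshow(tf.image.resize(cam[..., None], IMG_SIZE)[..., 0], cmap='jet', alpha=0.4)
plt.axis('off')
plt.show()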

Complete Inference Code

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# 1. Load the model
model = load_model("custom_garbage_classifier.h5")

# 2. Path of the image to analyze
img_path = "evalImageSet/5.jpg"  # change to your own image

# 3. Load and preprocess
IMG_SIZE = (128, 128)
img = load_img(img_path, target_size=IMG_SIZE)
x   = img_to_array(img) / 255.0       # normalize to [0, 1]
x   = np.expand_dims(x, axis=0)       # reshape to (1, 128, 128, 3)

# 4. Predict
probs = model.predict(x)[0]           # vector with one probability per class
class_idx = np.argmax(probs)          # index of the predicted class

# 5. Map the index back to a class name
#    This assumes a class_indices dict taken from the training generator,
#    e.g. {'glass': 0, 'paper': 1, 'plastic': 2, ...}
#    Replace it with your actual mapping
class_indices = {'Harmful': 0, 'Kitchen': 1, 'Other': 2, 'Recyclable': 3}
labels = {v:k for k,v in class_indices.items()}

pred_label = labels[class_idx]
pred_prob  = probs[class_idx]

# 6. Print the results
print(f"▶ Analyzing image: {img_path}")
print(f"Predicted class: {pred_label}, confidence: {pred_prob:.4%}")
print("\nPer-class probabilities:")
for idx, p in enumerate(probs):
    print(f"  {labels[idx]:<8}: {p:.2%}")

# 7. (Optional) display the image
plt.imshow(img)
plt.title(f"Pred: {pred_label} ({pred_prob:.1%})")
plt.axis('off')
plt.show()
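
To evaluate more than one image at a time, the same model can be looped over every file in a folder. A short sketch, assuming the evalImageSet/ directory and the labels mapping defined above:

# Optional: classify every image in a folder
import os

EVAL_DIR = "evalImageSet"
for fname in sorted(os.listdir(EVAL_DIR)):
    if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    arr = img_to_array(load_img(os.path.join(EVAL_DIR, fname), target_size=IMG_SIZE)) / 255.0
    p = model.predict(np.expand_dims(arr, 0), verbose=0)[0]
    print(f"{fname:<20} -> {labels[int(np.argmax(p))]} ({p.max():.1%})")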

8. CNN vs. Transformer

Dimension            | CNN                                  | Transformer
Core module          | Conv2D + pooling                     | Self-attention + feed-forward
Receptive field      | Grows as layers are stacked          | Global within a single layer
Parameter sharing    | Kernels reused across space/time     | Attention weights shared across all token pairs
Position handling    | Translation-invariant by design      | No inherent order; needs explicit positional encoding
Parallelism          | High (local)                         | Very high (global)
Compute complexity   | O(N·K²·C_out)                        | O(N²·D)
  • In common: the same low-level tensor operations, backpropagation, optimizers, and regularization methods.
  • Selection: choose based on whether the task is dominated by local or global dependencies; hybrids (ViT, Conformer, CLIP) are also an option. A rough numeric comparison of the two cost formulas follows below.
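
To make the complexity row concrete, the rough arithmetic below plugs illustrative numbers into the two per-layer cost formulas from the table (ballpark operation counts, not a benchmark):

# Back-of-the-envelope layer cost, using the formulas from the table
# CNN:            O(N · K² · C_out), N = spatial positions (per input channel)
# Self-attention: O(N² · D),         N = tokens, D = embedding dimension

N_conv, K, C_out = 128 * 128, 3, 64     # e.g. an early conv layer on a 128x128 map
N_tok, D         = 256, 384             # e.g. a 16x16 grid of patch tokens

print(f"Conv layer      ~ {N_conv * K**2 * C_out:>12,} ops")
print(f"Attention layer ~ {N_tok**2 * D:>12,} ops")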

Results showcase


With everything above, you now have a solid grasp of this approach, and all of the basic usage has been demonstrated. If you can make it your own, I believe you will go far.

Best
Wenhao (楠博万)

