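A Concise Explanation of the Pain Points TensorFlow + Convolutional Neural Networks (CNN) Solve
Project Overview
Garbage classification is a key part of sustainable development. This tutorial uses TensorFlow and a classic convolutional neural network (CNN) to walk through the whole pipeline, from environment setup to single-image inference. There is no lengthy background, only the key steps for quickly building an efficient, interpretable automated classification system. If you want an environment identical to mine in one step, set up Conda first; if anything about that is unclear, see my earlier article 👉 Conda from Zero: A Complete Guide to Installation, Creating Environments, and Managing Dependencies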
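- Environment management: reproduce the environment in one step with `environment.yml`
- Dataset preparation: download link and directory structure
- Data cleaning: automatically delete corrupted images
- Data augmentation: improve model robustness
- Model building and training: the CNN architecture in detail
- Training visualization: loss/accuracy curves
- Single-image inference: real-time classification and interpretability analysis
- CNN vs. Transformer: a guide to choosing an architecture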
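1. Environment Management
Create a file named `environment.yml` in the project root with contents like the following. This file was exported on Windows: packages such as `ucrt`, `vc14_runtime`, and `libwinpthread` exist only for Windows, and TensorFlow 2.10 with CUDA 11.2 (`cudatoolkit`) and cuDNN 8.1 (`cudnn`) is the last TensorFlow release with native GPU support on Windows.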
```yaml
name: tf_gpu
channels:
  - defaults
  - conda-forge
dependencies:
  - _openmp_mutex=4.5
  - blas=1.0
  - brotli-python=1.0.9
  - bzip2=1.0.8
  - ca-certificates=2025.4.26
  - contourpy=1.3.1
  - cudatoolkit=11.2.2
  - cudnn=8.1.0.77
  - cycler=0.11.0
  - expat=2.7.1
  - fonttools=4.55.3
  - freetype=2.13.3
  - glib=2.84.0
  - glib-tools=2.84.0
  - gst-plugins-base=1.24.7
  - gstreamer=1.24.7
  - icc_rt=2022.1.0
  - icu=75.1
  - intel-openmp=2023.2.0
  - joblib=1.4.2
  - kiwisolver=1.4.8
  - krb5=1.21.3
  - lcms2=2.17
  - lerc=4.0.0
  - libblas=3.9.0
  - libcblas=3.9.0
  - libclang13=20.1.7
  - libdeflate=1.24
  - libffi=3.4.4
  - libfreetype=2.13.3
  - libfreetype6=2.13.3
  - libgcc=15.1.0
  - libglib=2.84.0
  - libgomp=15.1.0
  - libhwloc=2.11.2
  - libiconv=1.18
  - libintl=0.22.5
  - libintl-devel=0.22.5
  - libjpeg-turbo=3.1.0
  - liblapack=3.9.0
  - liblzma=5.8.1
  - liblzma-devel=5.8.1
  - libogg=1.3.5
  - libpng=1.6.47
  - libsqlite=3.50.1
  - libtiff=4.7.0
  - libvorbis=1.3.7
  - libwebp-base=1.5.0
  - libwinpthread=12.0.0.r4.gg4f2fc60ca
  - libxcb=1.17.0
  - libxml2=2.13.8
  - libzlib=1.3.1
  - matplotlib=3.10.0
  - matplotlib-base=3.10.0
  - mkl=2023.2.0
  - mkl-service=2.4.1
  - openjpeg=2.5.3
  - openssl=3.5.0
  - pcre2=10.44
  - pillow=11.2.1
  - pip=25.1
  - ply=3.11
  - pthread-stubs=0.4
  - pyparsing=3.2.0
  - pyqt=5.15.10
  - pyqt5-sip=12.13.0
  - python=3.10.16
  - python-dateutil=2.9.0post0
  - python_abi=3.10
  - qt-main=5.15.15
  - scikit-learn=1.6.1
  - setuptools=78.1.1
  - sip=6.7.12
  - six=1.17.0
  - sqlite=3.45.3
  - tbb=2021.13.0
  - threadpoolctl=3.5.0
  - tk=8.6.13
  - tomli=2.0.1
  - tornado=6.5.1
  - tzdata=2025b
  - ucrt=10.0.22621.0
  - unicodedata2=15.1.0
  - vc=14.42
  - vc14_runtime=14.42.34438
  - vs2015_runtime=14.42.34438
  - wheel=0.45.1
  - xorg-libxau=1.0.12
  - xorg-libxdmcp=1.1.5
  - xz=5.8.1
  - xz-tools=5.8.1
  - zlib=1.3.1
  - zstd=1.5.7
  - pip:
      # pip packages must be nested under "- pip:" to be installed with pip
      - absl-py==2.3.0
      - astunparse==1.6.3
      - cachetools==5.5.2
      - certifi==2025.4.26
      - charset-normalizer==3.4.2
      - flatbuffers==25.2.10
      - gast==0.4.0
      - google-auth==2.40.3
      - google-auth-oauthlib==0.4.6
      - google-pasta==0.2.0
      - grpcio==1.73.0
      - h5py==3.14.0
      - idna==3.10
      - keras==2.10.0
      - keras-preprocessing==1.1.2
      - libclang==18.1.1
      - markdown==3.8
      - markupsafe==3.0.2
      - numpy==1.23.5
      - oauthlib==3.2.2
      - opt-einsum==3.4.0
      - packaging==25.0
      - protobuf==3.19.6
      - pyasn1==0.6.1
      - pyasn1-modules==0.4.2
      - requests==2.32.4
      - requests-oauthlib==2.0.0
      - rsa==4.9.1
      - scipy==1.15.3
      - tensorboard==2.10.1
      - tensorboard-data-server==0.6.1
      - tensorboard-plugin-wit==1.8.1
      - tensorflow==2.10.0
      - tensorflow-estimator==2.10.0
      - tensorflow-io-gcs-filesystem==0.31.0
      - termcolor==3.1.0
      - typing-extensions==4.14.0
      - urllib3==2.4.0
      - werkzeug==3.1.3
      - wrapt==1.17.2
```
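Create and activate the environment in one step (the environment name comes from the `name:` field, `tf_gpu`):
```bash
conda env create -f environment.yml
conda activate tf_gpu
```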
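Before training, it helps to confirm that the activated environment actually sees the GPU. A minimal sketch; `describe_tf_setup` is a hypothetical helper name, and on a CPU-only TensorFlow build the CUDA/cuDNN fields will simply come back as `None` and the GPU list will be empty:

```python
import tensorflow as tf


def describe_tf_setup():
    """Collect the TensorFlow version, CUDA build info, and visible GPUs."""
    build = tf.sysconfig.get_build_info()  # dict of build-time settings
    return {
        "tf_version": tf.__version__,
        "built_with_cuda": tf.test.is_built_with_cuda(),
        # Keys may be absent on CPU-only builds, hence .get()
        "cuda_version": build.get("cuda_version"),
        "cudnn_version": build.get("cudnn_version"),
        "gpus": [d.name for d in tf.config.list_physical_devices("GPU")],
    }
```

Run `print(describe_tf_setup())` inside `tf_gpu`; an empty `gpus` list means training will silently fall back to the CPU.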
2. Dataset Preparation
2.1 Download and Extract
- Source: Alibaba Cloud Tianchi [Garbage Classification Dataset]
https://tianchi.aliyun.com/dataset/138860
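2.2 Directory Structure
```text
project-root/
├── dataset/
│   ├── Harmful/        # hazardous waste
│   ├── Kitchen/        # kitchen (food) waste
│   ├── Other/          # other (residual) waste
│   └── Recyclable/     # recyclable waste
├── evalImageSet/       # your own test images (used by predict.py)
├── clean_data.py
├── train.py
├── visualize.py
├── predict.py
└── environment.yml
```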
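Because `flow_from_directory` (used in section 4) turns each subfolder of `dataset/` into a class and numbers the classes in alphanumeric order of the folder names, it is worth checking the layout and the per-class image counts before training. A minimal stdlib sketch; `summarize_dataset` and the extension set are my own assumptions, not part of the original project:

```python
import os

# Assumed set of image extensions; adjust to match your data.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".ppm", ".tif", ".tiff"}


def summarize_dataset(data_dir):
    """Return (class_indices, counts) for a folder-per-class dataset.

    class_indices follows the alphanumeric folder order that
    flow_from_directory uses; counts is the number of image files per class.
    """
    classes = sorted(
        d for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d))
    )
    class_indices = {name: i for i, name in enumerate(classes)}
    counts = {}
    for name in classes:
        total = 0
        for _, _, files in os.walk(os.path.join(data_dir, name)):
            total += sum(os.path.splitext(f)[1].lower() in IMAGE_EXTS for f in files)
        counts[name] = total
    return class_indices, counts
```

For the four folders above, `summarize_dataset("dataset/")` should return the indices `{'Harmful': 0, 'Kitchen': 1, 'Other': 2, 'Recyclable': 3}`, the same mapping that predict.py hard-codes in section 7. Large differences between the counts are a sign of class imbalance worth keeping in mind when reading accuracy.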
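3. Data Cleaning
3.1 Purpose
- Automatically remove images that cannot be opened or are truncated, so that training is not interrupted.
3.2 Implementation
```python
# clean_data.py
import os
from PIL import Image, ImageFile

# Let Pillow open truncated images instead of failing immediately
ImageFile.LOAD_TRUNCATED_IMAGES = True

DATA_DIR = "dataset/"
bad_images = []

for root, _, files in os.walk(DATA_DIR):
    for fname in files:
        path = os.path.join(root, fname)
        try:
            with Image.open(path) as img:
                # verify() checks the file structure without decoding pixel data
                img.verify()
        except Exception:
            # Note: any file Pillow cannot open (including non-image files) is flagged
            bad_images.append(path)

if bad_images:
    print(f"Deleting {len(bad_images)} corrupted images:")
    for p in bad_images:
        os.remove(p)
        print("  ✔", p)
else:
    print("✅ No corrupted images found")
```
Run it once before training:
```bash
python clean_data.py
```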
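`verify()` is a lightweight check: per the Pillow documentation it tries to detect a broken file without actually decoding the image data, so a JPEG that is cut off partway can pass it. A stricter variant, sketched below under my own assumptions (the hypothetical `find_bad_images` helper, the extension filter, and reporting before deleting), forces a full decode with `img.load()` and relies on `ImageFile.LOAD_TRUNCATED_IMAGES` staying at its default `False`, so that a truncated file raises an error:

```python
import os
from PIL import Image

# Assumed image extensions; other files are skipped rather than deleted.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff"}


def find_bad_images(data_dir):
    """Return paths of image files that Pillow cannot fully decode.

    Relies on ImageFile.LOAD_TRUNCATED_IMAGES being False (Pillow's default),
    so decoding a truncated file raises an OSError.
    """
    bad = []
    for root, _, files in os.walk(data_dir):
        for fname in sorted(files):
            if os.path.splitext(fname)[1].lower() not in IMAGE_EXTS:
                continue  # leave non-image files alone
            path = os.path.join(root, fname)
            try:
                with Image.open(path) as img:
                    img.load()  # decode every pixel, not just the header
            except Exception:
                bad.append(path)
    return bad
```

Print the returned list first, review it, and only then delete the files with `os.remove`, so a wrong extension filter cannot wipe out good data.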
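4. Data Augmentation
4.1 Augmentation techniques
- Geometric transforms: rotation, shifting, shearing, zooming
- Color transforms: brightness, channel shift
- Flipping and filling: horizontal flip + reflective border filling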
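What "horizontal flip + border reflection" means at the pixel level can be shown with plain NumPy. This is an illustration, not Keras's implementation: the Keras documentation describes `fill_mode='reflect'` as `abcddcba|abcd|dcbaabcd` (the edge pixel is repeated), which corresponds to NumPy's `mode="symmetric"` rather than NumPy's own `"reflect"`. The helper names are hypothetical:

```python
import numpy as np


def horizontal_flip(img):
    """Mirror an (H, W, C) image left to right, as horizontal_flip=True does."""
    return img[:, ::-1, :]


def reflect_fill(row, pad):
    """Extend a 1-D row past its borders the way fill_mode='reflect' is documented.

    Keras's 'reflect' (abcddcba|abcd|dcbaabcd) repeats the edge value,
    which matches NumPy's 'symmetric' padding mode.
    """
    return np.pad(row, pad, mode="symmetric")


row = np.array([1, 2, 3, 4])
print(reflect_fill(row, 2))             # [2 1 1 2 3 4 4 3]
print(np.pad(row, 2, mode="reflect"))   # [3 2 1 2 3 4 3 2] (edge not repeated)
```

Filling by reflection avoids the flat black or constant borders that `fill_mode='constant'` would leave after a rotation or shift.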
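4.2 Code example
```python
# Data-generator part of train.py
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32

# Training data: scale pixels to [0, 1] and apply random augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,           # hold out 20% of each class for validation
    rotation_range=20,              # random rotation within ±20 degrees
    width_shift_range=0.1,          # horizontal shift up to 10% of the width
    height_shift_range=0.1,         # vertical shift up to 10% of the height
    shear_range=10,                 # shear angle in degrees
    zoom_range=0.2,                 # random zoom in [0.8, 1.2]
    brightness_range=[0.8, 1.2],    # random brightness factor
    channel_shift_range=15,         # random channel-value shift (before rescaling, 0-255 scale)
    horizontal_flip=True,
    fill_mode='reflect'             # fill exposed borders by mirroring the image
)

# Validation data: rescale only, so validation metrics are measured on unmodified images.
# The same validation_split selects the same validation files, disjoint from the training subset.
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_gen = train_datagen.flow_from_directory(
    "dataset/",
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training'
)
val_gen = val_datagen.flow_from_directory(
    "dataset/",
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False
)
```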
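Note that `validation_split` does not draw a random 20%. To my reading of the Keras source, the directory iterator sorts each class's filenames and gives the first `validation_split` fraction to the `'validation'` subset and the remainder to `'training'`. The hypothetical helper below reproduces that rule in plain Python so you can see which files land where; if filenames follow collection order (for example by source or date), the validation set may not be representative:

```python
def split_like_keras(filenames, validation_split):
    """Split one class's files the way (as I understand it) Keras's
    directory iterator does: sort, then the first fraction is validation."""
    files = sorted(filenames)
    cut = int(validation_split * len(files))
    return files[cut:], files[:cut]  # (training, validation)


files = [f"img_{i:02d}.jpg" for i in range(10)]
train_files, val_files = split_like_keras(files, 0.2)
print(val_files)  # ['img_00.jpg', 'img_01.jpg']
```

Because the split is deterministic, two generators created with the same `validation_split` see non-overlapping training and validation files, which is what lets the validation generator use different preprocessing from the training one.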
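5. Model Building and Training
5.1 Model architecture
- Convolution + pooling layers: extract features at multiple levels
- Batch normalization: stabilizes and speeds up training
- Global average pooling: few parameters, less prone to overfitting
- Fully connected + Dropout: classification output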
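To make "global average pooling: few parameters" concrete, the layer arithmetic can be done by hand. The plain-Python sketch below (helper names are hypothetical) follows the Keras defaults this model relies on, 3×3 `Conv2D` with `'valid'` padding and 2×2 `MaxPooling2D` with stride 2, assumes 4 classes, and compares the section 5.2 head (`GlobalAveragePooling2D`) with the `Flatten` head used in the complete training script:

```python
def conv_out(size, kernel=3):
    """Spatial size after a 'valid' (unpadded) stride-1 convolution."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Spatial size after max pooling with pool size = stride = 2, 'valid' padding."""
    return (size - pool) // pool + 1


def count_params(input_size=128, channels=(3, 32, 64, 128), num_classes=4,
                 head="gap", batch_norm=True):
    """Return (feature-map side length before the head, total parameter count)."""
    size, total = input_size, 0
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        total += 3 * 3 * c_in * c_out + c_out   # conv kernel + bias
        if batch_norm:
            total += 4 * c_out                  # gamma, beta, moving mean, moving variance
        size = pool_out(conv_out(size))
    features = channels[-1] if head == "gap" else size * size * channels[-1]
    total += features * 128 + 128               # Dense(128)
    total += 128 * num_classes + num_classes    # Dense(num_classes)
    return size, total


print(count_params(head="gap", batch_norm=True))       # (14, 111172)
print(count_params(head="flatten", batch_norm=False))  # (14, 3305156)
```

`model.summary()` should report the same totals (in the GAP model, 448 of them are the non-trainable batch-norm moving statistics). With `Flatten`, the first dense layer alone holds about 3.2 million of the roughly 3.3 million parameters, which is where the overfitting risk comes from.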
5.2 Training script
```python
# train.py (continued): train_gen and val_gen come from the data-generator code in section 4.2
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, MaxPooling2D,
    GlobalAveragePooling2D, Dense, Dropout
)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

print("✅ GPU:", tf.config.list_physical_devices('GPU'))

num_classes = train_gen.num_classes  # one class per subfolder of dataset/

model = Sequential([
    Conv2D(32, 3, activation='relu', input_shape=(128, 128, 3)),
    BatchNormalization(), MaxPooling2D(),
    Conv2D(64, 3, activation='relu'),
    BatchNormalization(), MaxPooling2D(),
    Conv2D(128, 3, activation='relu'),
    BatchNormalization(), MaxPooling2D(),
    GlobalAveragePooling2D(),     # one value per feature map instead of flattening
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
model.summary()

callbacks = [
    # Stop after 5 epochs without val_loss improvement and restore the best weights
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    # Halve the learning rate after 3 epochs without val_loss improvement
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
]

history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=50,
    callbacks=callbacks
)

model.save("custom_garbage_classifier.h5")
print("✅ Model saved to custom_garbage_classifier.h5")
```
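How the two callbacks interact is easier to see on a made-up `val_loss` sequence. The function below is a simplified re-implementation of their patience counters (it ignores `min_delta`, `cooldown`, and `min_lr`, and the loss values are invented for illustration); it is not Keras's own code:

```python
def simulate_callbacks(val_losses, lr=1e-4, factor=0.5, lr_patience=3, stop_patience=5):
    """Simplified patience logic of ReduceLROnPlateau + EarlyStopping on val_loss.

    Returns (epochs_run, best_epoch_index, learning rate used in each epoch).
    """
    best = float("inf")
    best_epoch = 0
    lr_wait = stop_wait = 0
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)                   # learning rate in effect during this epoch
        if loss < best:                  # improvement: reset both counters
            best, best_epoch = loss, epoch
            lr_wait = stop_wait = 0
            continue
        lr_wait += 1
        stop_wait += 1
        if lr_wait >= lr_patience:       # plateau: shrink the learning rate
            lr *= factor
            lr_wait = 0
        if stop_wait >= stop_patience:   # no improvement for too long: stop
            return epoch + 1, best_epoch, lrs
    return len(val_losses), best_epoch, lrs


losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.69, 0.70, 0.70, 0.70, 0.70, 0.70, 0.60]
epochs_run, best_epoch, lrs = simulate_callbacks(losses)
print(epochs_run, best_epoch)  # 12 6
```

In this trace the learning rate is halved after the 6th and 10th epochs, training stops after 12 epochs (so the final 0.60 is never reached), and `restore_best_weights=True` would roll back to the weights from the 7th epoch (index 6). Because `patience=3` is shorter than `patience=5`, the model always gets at least one reduced-learning-rate attempt before early stopping.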
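6. Visualizing the Training Process
```python
# visualize.py
# `history` is the object returned by model.fit() in train.py, so run this code at the
# end of train.py (or in the same session); a separate process has no `history`.
import matplotlib.pyplot as plt

# Loss curves
plt.figure()
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.title("Loss")
plt.xlabel("Epoch")
plt.legend()
plt.show()

# Accuracy curves
plt.figure()
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='val_acc')
plt.title("Accuracy")
plt.xlabel("Epoch")
plt.legend()
plt.show()
```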
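If you would rather keep `visualize.py` as a separate script, one option is to have train.py write `history.history` to a JSON file and let `visualize.py` read it back. A minimal sketch; the file names `history.json` and `training_curves.png` and both helper functions are my own assumptions:

```python
import json

import matplotlib
matplotlib.use("Agg")  # render to image files; no display window needed
import matplotlib.pyplot as plt


def save_history(history_dict, path="history.json"):
    """Call at the end of train.py as save_history(history.history)."""
    with open(path, "w") as f:
        json.dump({k: [float(v) for v in vals] for k, vals in history_dict.items()}, f)


def plot_history(path="history.json", out_png="training_curves.png"):
    """Read the saved history and draw loss and accuracy side by side."""
    with open(path) as f:
        hist = json.load(f)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    for ax, metric in ((ax_loss, "loss"), (ax_acc, "accuracy")):
        ax.plot(hist[metric], label=f"train_{metric}")
        ax.plot(hist[f"val_{metric}"], label=f"val_{metric}")
        ax.set_title(metric.capitalize())
        ax.set_xlabel("Epoch")
        ax.legend()
    fig.tight_layout()
    fig.savefig(out_png)
    plt.close(fig)
    return out_png
```

A widening gap between the training and validation curves is the usual sign of overfitting; if validation loss keeps falling, the early-stopping settings may be stopping too soon.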
Complete training code. This single script first cleans the dataset and then trains a simpler variant of the section 5 model: a `Flatten` head instead of global average pooling, no batch normalization, Adam with its default learning rate, and at most 15 epochs.
```python
import os
from PIL import Image, ImageFile

# Allow Pillow to load truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Dataset path
DATA_DIR = "dataset/"

# Step 1: automatically remove all corrupted or truncated images
bad_images = []
for root, _, files in os.walk(DATA_DIR):
    for fname in files:
        path = os.path.join(root, fname)
        try:
            with Image.open(path) as img:
                img.verify()
        except Exception:
            bad_images.append(path)

if bad_images:
    print(f"Found {len(bad_images)} bad images. Removing…")
    for p in bad_images:
        os.remove(p)
        print("  Removed", p)
else:
    print("No corrupted images found.")

# ---- Training script below ---- #
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping

# Check whether a GPU is available
print("✅ GPU devices:", tf.config.list_physical_devices('GPU'))

# Parameters
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 15

# Data augmentation + preprocessing for the training subset
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True
)
# Validation subset: rescale only, no augmentation
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_gen = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training'
)
val_gen = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False
)

# Custom CNN architecture
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(train_gen.num_classes, activation='softmax')
])

# Compile the model
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Print the model structure
model.summary()

# EarlyStopping to limit overfitting
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Train the model
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=[early_stop]
)

# Save the model
model.save("custom_garbage_classifier.h5")
print("✅ Training finished; model saved as custom_garbage_classifier.h5")
```
7. Single-Image Inference and Explainable AI
```python
# predict.py
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

model = load_model("custom_garbage_classifier.h5")
img_path = "evalImageSet/5.jpg"
IMG_SIZE = (128, 128)  # must match the training input size

# Preprocess exactly as in training: resize, scale to [0, 1], add a batch dimension
img = load_img(img_path, target_size=IMG_SIZE)
x = img_to_array(img) / 255.0
x = np.expand_dims(x, 0)

probs = model.predict(x)[0]
class_idx = int(np.argmax(probs))

# Same mapping as train_gen.class_indices (class folders in alphabetical order)
class_indices = {'Harmful': 0, 'Kitchen': 1, 'Other': 2, 'Recyclable': 3}
labels = {v: k for k, v in class_indices.items()}

print(f"▶ {img_path} → {labels[class_idx]} ({probs[class_idx]:.1%})")
print("Per-class probabilities:")
for i, p in enumerate(probs):
    print(f"  {labels[i]:<12}: {p:.2%}")

plt.imshow(img)
plt.title(f"{labels[class_idx]} ({probs[class_idx]:.1%})")
plt.axis('off')
plt.show()
```
- Optional: use Grad-CAM to visualize the regions the model focuses on.
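Grad-CAM weights the feature maps of a convolutional layer by the average gradient of the predicted class score with respect to those maps, which gives a coarse heatmap of where the model "looked". A minimal sketch, assuming TensorFlow 2.x and a model like the ones above; `grad_cam` is a hypothetical helper, and defaulting to the last `Conv2D` layer is my assumption (any convolutional layer can be passed by name):

```python
import numpy as np
import tensorflow as tf


def grad_cam(model, x, class_idx=None, conv_layer_name=None):
    """Return (heatmap in [0, 1] at the conv layer's resolution, class index).

    x is a preprocessed batch holding one image, shape (1, H, W, 3).
    """
    if conv_layer_name is None:
        # Assumption: explain the last Conv2D layer of the model
        conv_layer_name = [layer.name for layer in model.layers
                           if isinstance(layer, tf.keras.layers.Conv2D)][-1]
    # Auxiliary model returning both the conv feature maps and the predictions
    grad_model = tf.keras.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(conv_layer_name).output, model.outputs[0]],
    )
    with tf.GradientTape() as tape:
        conv_maps, preds = grad_model(x)
        if class_idx is None:
            class_idx = int(tf.argmax(preds[0]))
        class_score = preds[:, class_idx]
    grads = tape.gradient(class_score, conv_maps)                 # (1, h, w, channels)
    channel_weights = tf.reduce_mean(grads, axis=(0, 1, 2))       # one weight per channel
    cam = tf.reduce_sum(conv_maps[0] * channel_weights, axis=-1)  # (h, w)
    cam = tf.nn.relu(cam)                                         # keep positive evidence only
    cam = cam / (tf.reduce_max(cam) + 1e-8)                       # scale to [0, 1]
    return cam.numpy(), class_idx
```

For the section 5.2 model and a 128×128 input, the last `Conv2D` output is 28×28, so resize the heatmap (for example with `tf.image.resize(cam[..., None], (128, 128))`) before overlaying it on the image with `plt.imshow(..., alpha=0.4, cmap='jet')`.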
Complete inference code
```python
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# 1. Load the model
model = load_model("custom_garbage_classifier.h5")

# 2. Path of the image to analyze
img_path = "evalImageSet/5.jpg"  # replace with your own image

# 3. Load and preprocess
IMG_SIZE = (128, 128)
img = load_img(img_path, target_size=IMG_SIZE)
x = img_to_array(img) / 255.0    # normalize to [0, 1]
x = np.expand_dims(x, axis=0)    # shape becomes (1, 128, 128, 3)

# 4. Predict
probs = model.predict(x)[0]      # vector with one probability per class
class_idx = int(np.argmax(probs))  # index of the predicted class

# 5. Map the index back to a class name
# class_indices must match train_gen.class_indices from training;
# flow_from_directory numbers the class folders in alphabetical order.
# Replace it with your actual mapping if your folders differ.
class_indices = {'Harmful': 0, 'Kitchen': 1, 'Other': 2, 'Recyclable': 3}
labels = {v: k for k, v in class_indices.items()}
pred_label = labels[class_idx]
pred_prob = probs[class_idx]

# 6. Print the results
print(f"▶ Image: {img_path}")
print(f"Predicted class: {pred_label}, confidence: {pred_prob:.4%}")
print("\nPer-class probabilities:")
for idx, p in enumerate(probs):
    print(f"  {labels[idx]:<12}: {p:.2%}")

# 7. (Optional) show the image
plt.imshow(img)
plt.title(f"Pred: {pred_label} ({pred_prob:.1%})")
plt.axis('off')
plt.show()
```
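Hard-coding `class_indices` works only as long as the class folders never change. A more robust pattern, sketched here with a hypothetical `class_indices.json` file and helper names of my own, is to save `train_gen.class_indices` at the end of training and load it in predict.py; `top_k` then lists the most likely classes:

```python
import json

import numpy as np


def save_class_indices(class_indices, path="class_indices.json"):
    """In train.py: save_class_indices(train_gen.class_indices)."""
    with open(path, "w") as f:
        json.dump(class_indices, f)


def load_labels(path="class_indices.json"):
    """Return {index: class_name}, the inverse of class_indices."""
    with open(path) as f:
        return {int(v): k for k, v in json.load(f).items()}


def top_k(probs, labels, k=3):
    """Return the k most probable (class_name, probability) pairs, highest first."""
    order = np.argsort(probs)[::-1][:k]
    return [(labels[int(i)], float(probs[i])) for i in order]
```

In predict.py, `labels = load_labels()` replaces the hard-coded dictionary, and `top_k(probs, labels, k=2)` shows whether a low-confidence prediction is torn between two specific classes.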
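8. CNN vs. Transformer
Aspect | CNN | Transformer |
---|---|---|
Core building blocks | Conv2D + pooling | Self-attention + feed-forward |
Receptive field | Grows as layers are stacked | Global within a single layer |
Parameter sharing | The same kernel is reused across spatial (or temporal) positions | The Q/K/V projection matrices are shared across all tokens; the attention weights are computed per token pair from the input |
Position handling | Translation-equivariant (pooling adds partial invariance); local order is implicit in the kernel, so no positional encoding is needed | Self-attention alone is permutation-equivariant (order-agnostic), so positional encodings must be added |
Parallelism | High (local, all positions in parallel) | Very high (global, all token pairs in parallel) |
Compute per layer | O(N·K²·C_in·C_out) | O(N²·D + N·D²) |

(N: number of spatial positions or tokens; K: kernel size; C_in, C_out: input and output channels; D: embedding dimension.)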
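- Common ground: both rely on the same underlying tensor operations, backpropagation, optimizers, and regularization methods.
- Choosing: decide by whether the task depends mainly on local or global relationships. The two can also be combined, e.g. Conformer (interleaved convolution and self-attention), the hybrid variant of ViT (CNN feature maps fed into a Transformer), or CLIP with a ResNet image encoder paired with a Transformer text encoder.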
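The complexity row can be made concrete with rough multiply-accumulate (MAC) counts. A back-of-the-envelope sketch with hypothetical helper names: the attention count includes the Q/K/V/output projections (4·N·D²) plus the QKᵀ and attention-times-V products (2·N²·D), and ignores softmax, biases, normalization, and the feed-forward block; the example sizes are arbitrary:

```python
def conv_macs(h, w, k, c_in, c_out):
    """MACs of one stride-1 KxK convolution producing an h x w x c_out output."""
    return h * w * k * k * c_in * c_out


def attention_macs(n, d):
    """MACs of one single-head self-attention layer over n tokens of width d."""
    projections = 4 * n * d * d  # Q, K, V and output projections
    mixing = 2 * n * n * d       # Q @ K^T and softmax(scores) @ V
    return projections + mixing


for side in (14, 28, 56):
    n = side * side  # treat every spatial position as one token
    print(side, conv_macs(side, side, 3, 128, 128), attention_macs(n, 128))
```

Doubling the side length multiplies N by 4: the convolution cost grows 4×, while the N²·D term grows 16×. This is why ViT-style models split an image into patches rather than treating every pixel as a token.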
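With the material above, you should now have a working understanding of this approach, and I have shown the basic usage of every step. If you can tie it all together, I'm confident you will be well on your way.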
Best
Wenhao (楠博万)