GEE+本地XGboot分类
我想做提取耕地提取,想到了一篇董金玮老师的一篇论文,这个论文是先提取的耕地,再做作物分类,耕地的提取代码是开源的。
但这个代码直接在云端上进行分类,GEE会爆内存,因此我准备把数据下载到本地,使用GPU加速进行XGboot提取耕地。
董老师的代码涉及到了100多个波段特征,我删减到了45个波段,然后分块进行了数据下载:
数据下载代码:
// ========================================
// 1. 初始化与区域选择
// ========================================
// 选择第一个区域作为AOI
var aoiFeature = fenqu.first();
var aoi = aoiFeature.geometry();
// 可视化AOI(可选)
Map.addLayer(aoi, {color: 'blue'}, 'AOI');
// 中心定位到AOI,缩放级别10(可选)
Map.centerObject(aoi, 10);
// ========================================
// 2. 划分AOI为16个块
// ========================================
// 定义划分块数(4x4网格)
var numCols = 4;
var numRows = 4;
// 获取AOI的边界和范围
var aoiBounds = aoi.bounds();
var coords = ee.List(aoiBounds.coordinates().get(0));
var xMin = ee.Number(ee.List(coords.get(0)).get(0));
var yMin = ee.Number(ee.List(coords.get(0)).get(1));
var xMax = ee.Number(ee.List(coords.get(2)).get(0));
var yMax = ee.Number(ee.List(coords.get(2)).get(1));
// 计算AOI的宽度和高度
var aoiWidth = xMax.subtract(xMin);
var aoiHeight = yMax.subtract(yMin);
// 计算每个块的宽度和高度
var tileWidth = aoiWidth.divide(numCols);
var tileHeight = aoiHeight.divide(numRows);
// 要排除的块的ID
var excludeTiles = ['0_3', '0_2', '3_0']; // 左上角、第二行第一个、右下角
// 生成4x4网格,但排除特定块
var grid = ee.FeatureCollection(
ee.List.sequence(0, numCols - 1).map(function(col) {
return ee.List.sequence(0, numRows - 1).map(function(row) {
var tileId = ee.String(col).cat('_').cat(ee.String(row));
var xmin = xMin.add(tileWidth.multiply(ee.Number(col)));
var ymin = yMin.add(tileHeight.multiply(ee.Number(row)));
var xmax = xmin.add(tileWidth);
var ymax = ymin.add(tileHeight);
var rectangle = ee.Geometry.Rectangle([xmin, ymin, xmax, ymax]);
return ee.Feature(rectangle, {
'tile': tileId
});
});
}).flatten()
).filter(ee.Filter.inList('tile', excludeTiles).not());
// 可视化网格(可选)
Map.addLayer(grid, {color: 'red'}, 'Grid');
// ========================================
// 3. 定义数据处理和导出函数
// ========================================
function processAndExport(tileFeature) {
var tileID = ee.String(tileFeature.get('tile'));
print('Processing Tile:', tileID);
var region = tileFeature.geometry();
// 2. 定义时间范围、波段及区域
var year = 2023;
var startDate = ee.Date.fromYMD(year, 1, 1);
var endDate = ee.Date.fromYMD(year, 12, 31);
var bands = ['B2', 'B3', 'B4', 'B8']; // 蓝、绿、红、近红外
// 3. 云掩膜函数:基于SCL波段
function maskS2clouds(image) {
var scl = image.select('SCL');
// SCL分类值: 3(云)、8(阴影云)
var cloudMask = scl.neq(3).and(scl.neq(8));
return image.updateMask(cloudMask)
.clip(region)
.copyProperties(image, ["system:time_start"]);
}
// 4. 添加光谱指数函数
function addSpectralIndices(image) {
// 计算NDVI
var ndvi = image.normalizedDifference(['B8', 'B4']).rename('NDVI');
// 计算EVI
var evi = image.expression(
'2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))', {
'NIR': image.select('B8'),
'RED': image.select('B4'),
'BLUE': image.select('B2')
}
).rename('EVI');
// 计算GNDVI
var gndvi = image.normalizedDifference(['B8', 'B3']).rename('GNDVI');
// 计算SAVI
var savi = image.expression(
'((NIR - RED) / (NIR + RED + 0.5)) * 1.5', {
'NIR': image.select('B8'),
'RED': image.select('B4')
}
).rename('SAVI');
// 计算MSAVI2
var msavi2 = image.expression(
'0.5 * (2 * NIR + 1 - sqrt((2 * NIR + 1)**2 - 8 * (NIR - RED)))', {
'NIR': image.select('B8'),
'RED': image.select('B4')
}
).rename('MSAVI2');
// 计算NDWI
var ndwi = image.normalizedDifference(['B3', 'B8']).rename('NDWI');
// 计算NDSI
var ndsi = image.normalizedDifference(['B3', 'B11']).rename('NDSI');
// 计算NDSVI
var ndsvi = image.normalizedDifference(['B11', 'B4']).rename('NDSVI');
// 计算NDTI
var ndti = image.normalizedDifference(['B11', 'B12']).rename('NDTI');
// 计算RENDVI
var rendvi = image.normalizedDifference(['B8', 'B5']).rename('RENDVI');
// 计算REP
var rep = image.expression(
'(705 + 35 * ((0.5 * (B6 + B4) - B2) / (B5 - B2))) / 1000', {
'B2': image.select('B2'),
'B4': image.select('B4'),
'B5': image.select('B5'),
'B6': image.select('B6'),
'B8': image.select('B8')
}
).rename('REP');
// 添加所有计算的波段
return image.addBands([ndvi, evi, gndvi, savi, msavi2, ndwi, ndsi, ndsvi, ndti, rendvi, rep]);
}
// 5. 加载并预处理Sentinel-2 L2A影像集合
var sentinel = ee.ImageCollection("COPERNICUS/S2_SR"); // 确保使用正确的Sentinel-2影像集合
var s2 = sentinel
.filterBounds(region)
.filterDate(startDate, endDate)
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20)) // 初步云量过滤
.map(maskS2clouds)
.map(addSpectralIndices)
.select(['B2', 'B3', 'B4', 'B8', 'NDVI', 'EVI', 'GNDVI', 'SAVI', 'MSAVI2', 'NDWI', 'NDSI', 'NDSVI', 'NDTI', 'RENDVI', 'REP']);
// 6. 计算月度NDVI最大值
var months = ee.List.sequence(1, 12);
var monthlyMaxNDVI = months.map(function(month) {
var monthStart = ee.Date.fromYMD(year, month, 1);
var monthEnd = monthStart.advance(1, 'month');
var monthlyNDVI = s2
.filterDate(monthStart, monthEnd)
.select('NDVI')
.max();
// 使用 ee.String 和 .cat() 正确拼接字符串
var bandName = ee.String('NDVI_month_').cat(ee.Number(month).format('%02d'));
return monthlyNDVI.rename(bandName);
});
print(monthlyMaxNDVI,"monthlyMaxNDVI" )
// 将所有月份的最大NDVI合并为一个图像
var monthlyMaxNDVIImage = ee.Image.cat.apply(null, monthlyMaxNDVI)
print(monthlyMaxNDVIImage,"monthlyMaxNDVIImage" )
// 7. 提取年度统计特征
var Year_Bands = ['B2', 'B3', 'B4', 'B8', 'NDVI', 'EVI', 'GNDVI', 'SAVI', 'MSAVI2', 'NDWI', 'NDSI', 'NDSVI', 'NDTI', 'RENDVI', 'REP'];
var annualStats = s2.select(Year_Bands)
.reduce(ee.Reducer.mean()
.combine(ee.Reducer.max(), null, true)
.combine(ee.Reducer.stdDev(), null, true));
// 重命名年度统计特征的波段
var statNames = ['mean', 'max', 'stdDev'];
var newBandNames = [];
Year_Bands.forEach(function(band) {
statNames.forEach(function(stat) {
newBandNames.push(band + '_' + stat);
});
});
annualStats = annualStats.rename(newBandNames);
// 将月度NDVI最大值和年度统计特征合并
annualStats = ee.Image.cat([annualStats, monthlyMaxNDVIImage]);
// 9. 合并所有特征
var finalImage = ee.Image.cat([annualStats])
.clip(region);
// 可视化示例(可选)
// Map.addLayer(finalImage.select('NDVI_seasonal'), {min: 0, max: 1, palette: ['white', 'green']}, 'NDVI Seasonal');
// 10. 导出数据到Google Drive
var output_name='tile_' + tileID.getInfo()
var name2=output_name.replace('.', '').replace('.', '')
print(finalImage.toFloat())
Export.image.toDrive({
image: finalImage.toFloat(),
description: name2,
scale: 10,
folder: "download_tiles",
region: region,
maxPixels: 1e13
});
}
// ========================================
// 4. 应用函数到每个块
// ========================================
// 注意:Google Earth Engine 同时只能运行有限的Export任务(通常为3个)。
// 因此,建议分批次运行或手动触发每个块的导出任务。
// 将网格转换为特征集合列表
var gridFeatures = grid.toList(grid.size());
// 获取总块数
var totalTiles = grid.size().getInfo();
// 定义每批次导出的数量(如果需要批量控制,可以在这里调整)
var batchSize = 1;
// 处理并导出每个块
// 注意:Google Earth Engine 不支持并行启动大量导出任务,请手动管理导出任务
gridFeatures.evaluate(function(list) {
list.forEach(function(feature) {
processAndExport(ee.Feature(feature));
});
});
// 打印总块数和导出说明
print('Total tiles:', totalTiles);
print('导出已启动。请在任务管理器中检查导出状态。');
然后下载完成后,用gdal做一下镶嵌(设置tile为256,LZW压缩),波段太多,导致数据非常大。最好再做一个金字塔
import os
from osgeo import gdal
# 输入和输出路径
input_dir = r"几十个波段数据"
output_file = "mosaic_result_gdal.tif"
# 获取所有tif文件
tif_files = []
for file in os.listdir(input_dir):
if file.endswith('.tif'):
tif_files.append(os.path.join(input_dir, file))
# 构建VRT
vrt = gdal.BuildVRT("temp.vrt", tif_files)
vrt = None
# 转换VRT为GeoTiff
gdal.Translate(
output_file,
"temp.vrt",
format="GTiff",
creationOptions=[
"COMPRESS=LZW",
"TILED=YES",
"BLOCKXSIZE=256",
"BLOCKYSIZE=256",
"BIGTIFF=YES"
]
)
镶嵌完,可以放进GIS软件中查看一下。
数据分类
在此之前,需要先准备点数据,我是准备了两个点数据矢量(耕地矢量和非耕地矢量),字段属性crop为1代表耕地,0代表非耕地。如果你是做多类别,你可以多做几个矢量。
然后开始安装环境:
(1)安装CUDA,用GPU加速运行,也可以CPU,都差不多,xgboot计算量不大;
(2)安装conda,然后使用下面的命令安装环境:
conda create --prefix D:\conda_ENV\xgboot_env python=3.10
conda activate D:\conda_ENV\xgboot_env
conda install -c conda-forge numpy pandas geopandas rasterio scikit-learn tqdm
然后就可以开始分类了,代码如下:
import geopandas as gpd
import rasterio
from rasterio.sample import sample_gen
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import xgboost as xgb
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, roc_auc_score
from tqdm import tqdm # 用于进度指示
# 读取矢量数据
CROP_FILE = r"耕地样本点.shp"
OTHERS_FILE = r"非耕地样本点.shp"
TIF_PATH = r"mosaic_result_gdal.tif"
cropland = gpd.read_file(CROP_FILE)
non_cropland = gpd.read_file(OTHERS_FILE)
cropland['crop'] = 1
non_cropland['crop'] = 0
samples = pd.concat([cropland, non_cropland], ignore_index=True)
with rasterio.open(TIF_PATH) as src:
band_count = src.count
coords = [(point.x, point.y) for point in samples.geometry]
pixel_values = list(src.sample(coords))
pixel_values = np.array(pixel_values)
feature_columns = [f'band_{i+1}' for i in range(band_count)]
features = pd.DataFrame(pixel_values, columns=feature_columns)
features['crop'] = samples['crop'].values
# 保存特征名称以供预测阶段使用
feature_names = feature_columns.copy()
# 数据预处理
features.dropna(inplace=True)
X = features.drop('crop', axis=1)
y = features['crop']
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# 训练模型
dtrain = xgb.DMatrix(X_train, label=y_train, feature_names=feature_names)
dtest = xgb.DMatrix(X_test, label=y_test, feature_names=feature_names)
params = {
'objective': 'binary:logistic',
'tree_method': 'hist', # 修改为 'hist'
'device': 'gpu', # 添加 'device' 参数
'eval_metric': 'auc',
'eta': 0.1,
'max_depth': 10,
'subsample': 0.8,
'colsample_bytree': 0.8,
'seed': 42
}
evallist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 100
print("开始训练模型...")
bst = xgb.train(params, dtrain, num_round, evallist, early_stopping_rounds=10, verbose_eval=True)
print("模型训练完成。\n")
# 评估模型
print("开始评估模型...")
y_pred_prob = bst.predict(dtest)
y_pred = (y_pred_prob > 0.5).astype(int)
accuracy = accuracy_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_pred_prob)
conf_matrix = confusion_matrix(y_test, y_pred)
report = classification_report(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(f'AUC: {auc}')
print('Confusion Matrix:')
print(conf_matrix)
print('Classification Report:')
print(report)
print("模型评估完成。\n")
# 应用模型进行栅格分类
print("开始进行栅格分类...")
with rasterio.open(TIF_PATH) as src:
profile = src.profile.copy()
profile.update(
dtype=rasterio.uint8,
count=1,
compress='lzw'
)
# 计算窗口总数用于进度指示
windows = list(src.block_windows(1))
total_windows = len(windows)
with rasterio.open('classified.tif', 'w', **profile) as dst:
for ji, window in tqdm(windows, total=total_windows, desc="栅格分类进度"):
data = src.read(window=window)
# data.shape = (bands, height, width)
bands, height, width = data.shape
data = data.reshape(bands, -1).transpose() # shape: (num_pixels, bands)
# 创建 DataFrame 并赋予特征名称
df = pd.DataFrame(data, columns=feature_names)
# 创建 DMatrix
dmatrix = xgb.DMatrix(df, feature_names=feature_names)
# 预测
predictions = bst.predict(dmatrix)
predictions = (predictions > 0.5).astype(np.uint8)
# 重塑为原窗口形状
out_image = predictions.reshape(height, width)
# 写入输出栅格
dst.write(out_image, 1, window=window)
print("栅格分类完成。")
训练完成后,就开始分类了,就出结果了:
自此,从数据下载到分类处理完毕。
样本数据多的话,也可以考虑用CNN,但分类速度比不上xgboot。
参考:
You N , Dong J , Huang J ,et al.The 10-m crop type maps in Northeast China during 2017–2019[J].Scientific Data, 2021, 8(1).DOI:10.1038/s41597-021-00827-9.