# Problem description
I want to use a CNN for raster downscaling, following the code released with the paper "Global high-resolution total water storage anomalies from self-supervised data assimilation using deep learning algorithms" (https://gitlab.ethz.ch/spacegeodesy_public/grace_seda/-/tree/main/code_v2019?ref_type=heads). When running the model training (train.py) I get an error. My adapted script:
import numpy as np
from pathlib import Path
from tqdm import tqdm
from astropy.time import Time
import os
from osgeo import gdal
import tensorflow as tf
def genSampleMat(data, idx_sample_tmp):
    TimePosition_tmp = TimePosition[idx_sample_tmp, :]  # time/lat/lon rows of the selected samples
    Sample_tmp = np.zeros((len(idx_sample_tmp), PatchSize, PatchSize))  # output stack of patches
    if np.mod(PatchSize, 2) == 0:  # Even PatchSize: the center pixel is the top-left of the central 4 pixels
        for i_sample in range(0, len(idx_sample_tmp)):
            idx_time_tmp = MJD == TimePosition_tmp[i_sample, 0]  # boolean mask picking this sample's time step (MJD match)
            # Center coordinates of the patch
            lat_c = TimePosition_tmp[i_sample, 1]
            lon_c = TimePosition_tmp[i_sample, 2]
            # Double-side difference --> odd number of pixels with a clearly defined central pixel:
            # select the lat/lon indices within PatchSize/2*Resolution of the center
            idx_lat_tmp = np.where(abs(lat - lat_c) <= PatchSize/2*Resolution)[0]
            idx_lon_tmp = np.where(abs(lon - lon_c) <= PatchSize/2*Resolution)[0]
            # We first concatenate longitude, then latitude is just NaN with fixed longitude size
            Patch_tmp = data[idx_time_tmp, idx_lat_tmp[0]:idx_lat_tmp[-1]+1, idx_lon_tmp[0]:idx_lon_tmp[-1]+1]
if len(idx_lon_tmp) < PatchSize + 1: # Close to the longitude transition
idx_lon_side = np.where(abs(lon+np.sign(lon_c)*360 - lon_c) <= PatchSize/2*Resolution)[0]
Patch_side = data[idx_time_tmp, idx_lat_tmp[0]:idx_lat_tmp[-1]+1, idx_lon_side[0]:idx_lon_side[-1]+1]
if lon_c < 0: # Close to the west side. Concatenate the side patch to left
Patch_tmp = np.concatenate((Patch_side, Patch_tmp), axis=2)
elif lon_c > 0:
Patch_tmp = np.concatenate((Patch_tmp, Patch_side), axis=2)
if len(idx_lat_tmp) < PatchSize + 1: # Close to the poles
num_nanrow = PatchSize + 1 - len(idx_lat_tmp)
                Patch_nan = np.full((1, num_nanrow, PatchSize + 1), np.nan)
if lat_c > 0:
Patch_tmp = np.concatenate((Patch_nan, Patch_tmp), axis=1)
elif lat_c < 0:
Patch_tmp = np.concatenate((Patch_tmp, Patch_nan), axis=1)
            # For the normal case, remove the top row and left column --> the center pixel becomes the top-left corner
Patch_tmp = Patch_tmp[:, 1:, 1:]
# Fill the NaNs using the mean of the patch
Patch_mean = np.nanmean(Patch_tmp)
            if not np.isnan(Patch_mean):
Patch_tmp[np.isnan(Patch_tmp)] = Patch_mean
else:
Patch_tmp[np.isnan(Patch_tmp)] = 0
            Sample_tmp[i_sample, :, :] = Patch_tmp  # store the processed patch
return Sample_tmp
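To make the even-PatchSize convention above concrete: with PatchSize = 32 and Resolution = 0.1 the half-width is 1.6 degrees, so the <= test should select 33 grid lines, and trimming the first row/column leaves a 32x32 patch whose nominal center is the top-left pixel of the central 2x2 block. A toy check (my addition, on an exact integer axis so floating point cannot interfere):

import numpy as np
axis = np.arange(0, 100)                            # toy grid with spacing 1 (exact integers)
c = 50                                              # patch center
idx = np.where(np.abs(axis - c) <= 32 / 2 * 1)[0]   # PatchSize/2 * Resolution, with Resolution = 1
print(len(idx))      # 33 points: c-16 ... c+16
print(len(idx[1:]))  # 32 after trimming the first entry, as in Patch_tmp[:, 1:, 1:]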
def genSampleMat_TemporalInvariant(data, idx_sample_tmp):
TimePosition_tmp = TimePosition[idx_sample_tmp, :]
Sample_tmp = np.zeros((len(idx_sample_tmp), PatchSize, PatchSize))
if np.mod(PatchSize, 2) == 0: # Even size, the center is defined as the left up corner of the center 4 pixels
for i_sample in range(0, len(idx_sample_tmp)):
lat_c = TimePosition_tmp[i_sample, 1]
lon_c = TimePosition_tmp[i_sample, 2]
idx_lat_tmp = np.where(abs(lat - lat_c) <= PatchSize/2*Resolution)[0]
idx_lon_tmp = np.where(abs(lon - lon_c) <= PatchSize/2*Resolution)[0]
            # We first concatenate longitude, then latitude is just NaN with fixed longitude size
Patch_tmp = data[idx_lat_tmp[0]:idx_lat_tmp[-1]+1, idx_lon_tmp[0]:idx_lon_tmp[-1]+1]
if len(idx_lon_tmp) < PatchSize + 1: # Close to the longitude transition
idx_lon_side = np.where(abs(lon+np.sign(lon_c)*360 - lon_c) <= PatchSize/2*Resolution)[0]
Patch_side = data[idx_lat_tmp[0]:idx_lat_tmp[-1]+1, idx_lon_side[0]:idx_lon_side[-1]+1]
if lon_c < 0: # Close to the west side. Concatenate the side patch to left
Patch_tmp = np.concatenate((Patch_side, Patch_tmp), axis=1)
elif lon_c > 0:
Patch_tmp = np.concatenate((Patch_tmp, Patch_side), axis=1)
if len(idx_lat_tmp) < PatchSize + 1: # Close to the poles
num_nanrow = PatchSize + 1 - len(idx_lat_tmp)
                Patch_nan = np.full((num_nanrow, PatchSize + 1), np.nan)
if lat_c > 0:
Patch_tmp = np.concatenate((Patch_nan, Patch_tmp), axis=0)
elif lat_c < 0:
Patch_tmp = np.concatenate((Patch_tmp, Patch_nan), axis=0)
# For the normal case, we remove the left and top line --> CP is the left top corner
Patch_tmp = Patch_tmp[1:, 1:]
# Fill the NaNs using the mean of the patch
Patch_mean = np.nanmean(Patch_tmp)
            if not np.isnan(Patch_mean):
Patch_tmp[np.isnan(Patch_tmp)] = Patch_mean
else:
Patch_tmp[np.isnan(Patch_tmp)] = 0
Sample_tmp[i_sample, :, :] = Patch_tmp
return Sample_tmp
# %% Input parameters
datapath = "D:/4_P emission inventory/2_CNN downscaling/0p1res/"
outpath = "D:/4_P emission inventory/2_CNN downscaling/output/TFs/"
NumBatch = 4
BatchSize = 512
Resolution = 0.1  # grid resolution in degrees -- output resolution or input resolution??? (below it is used as the spacing of the lat/lon axes)
PatchSize = 32
Region = "Global"
LatMax = 90
LatMin = -90
LonMax = 180
LonMin = -180
# %% Load data
def load_tif(file_path):
dataset = gdal.Open(file_path)
if dataset is None:
print(f"Failed to open file: {file_path}")
return None
data = dataset.ReadAsArray()
dataset = None
return data
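Note that load_tif discards the georeferencing. If the lat/lon axes defined further below should come from the files rather than being hard-coded, a sketch of a variant (my addition, assuming north-up rasters with no rotation):

def load_tif_with_geo(file_path):
    # Like load_tif, but also returns the affine GeoTransform
    dataset = gdal.Open(file_path)
    if dataset is None:
        print(f"Failed to open file: {file_path}")
        return None, None
    geotransform = dataset.GetGeoTransform()  # (lon_min, dlon, 0, lat_max, 0, -dlat)
    data = dataset.ReadAsArray()
    dataset = None
    return data, geotransform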
def load_and_stack_data(datapath, start_year, end_year):
stacked_data = {}
for year in range(start_year, end_year + 1):
for month in range(1, 13):
rf_predict_path = os.path.join(datapath, f"rf_predict_0p1/rf_{year}{month:02d}.tif")
geomschem_path = os.path.join(datapath, f"geomschem_res_0p1/{year}{month:02d}.tif")
rf_predict_data = load_tif(rf_predict_path)
geomschem_data = load_tif(geomschem_path)
if 'rf_predict' not in stacked_data:
stacked_data['rf_predict'] = []
if 'geomschem' not in stacked_data:
stacked_data['geomschem'] = []
stacked_data['rf_predict'].append(rf_predict_data)
stacked_data['geomschem'].append(geomschem_data)
for variable in stacked_data:
stacked_data[variable] = np.stack(stacked_data[variable], axis=-1)
return stacked_data
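One caveat with load_and_stack_data: load_tif returns None for a missing file, and np.stack would then fail with an unrelated-looking error. A minimal guard (my addition) that could be placed right after the two load_tif calls inside the loop:

            # Hypothetical guard: fail fast if either monthly file is missing
            if rf_predict_data is None or geomschem_data is None:
                raise FileNotFoundError(f"Missing input raster for {year}-{month:02d}")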
# Example usage
datapath = "D:/4_P emission inventory/2_CNN downscaling/0p1res/"
start_year = 2004
end_year = 2019
data = load_and_stack_data(datapath, start_year, end_year)
# Reorder dimensions from (lat, lon, time) to (time, lat, lon)
rf_predict = np.moveaxis(data['rf_predict'], [0, 1, 2], [1,2,0])
geomschem = np.moveaxis(data['geomschem'], [0, 1, 2], [1, 2, 0])
start_date = '2004-01-01'
end_date = '2019-12-31'
# Build monthly time stamps between the start and end dates:
# convert the start and end dates to astropy Time objects
time_range = Time([start_date, end_date], format='iso', scale='utc')
# Convert the Time objects to Modified Julian Date (MJD)
mjd_start, mjd_end = time_range.mjd
# Generate evenly spaced MJD values between the start and end MJD
num_steps = 192  # number of time steps: 16 years x 12 months
MJD = np.linspace(mjd_start, mjd_end, num_steps)
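Note that np.linspace spreads the 192 stamps evenly in days, so they drift away from the actual month starts. The matching in genSampleMat still works because TimePosition is built from this same MJD array, but if true calendar-month stamps are wanted, a sketch using the already-imported astropy Time (my variant):

month_starts = [f"{y}-{m:02d}-01" for y in range(2004, 2020) for m in range(1, 13)]
MJD = Time(month_starts, format='iso', scale='utc').mjd  # 192 monthly MJD values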
num_time, num_lat, num_lon = geomschem.shape  # (192, 1800, 3600) on the 0.1-degree grid
# lat = np.arange(num_lat)
# lon = np.arange(num_lon)
# lon_grid, lat_grid= np.meshgrid(lon, lat)
maskpath = "D:/4_P emission inventory/2_CNN downscaling/"
LandMask = np.genfromtxt((maskpath + "landmask.csv"), delimiter=',')  # land/ocean mask on the same 0.1-degree grid
print(LandMask.shape)  # (1800, 3600)
rf_predict[:, LandMask==0] = np.NaN
geomschem[:, LandMask==0] = np.NaN
lat = np.arange(90, -90, -0.1)
lon = np.arange(-180, 180, 0.1)
print(lat.shape)  # (1800,)
print(lon.shape)  # (3600,)
lon_grid, lat_grid = np.meshgrid(lon, lat)
print(lon_grid.shape)  # (1800, 3600)
print(lat_grid.shape)  # (1800, 3600)
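These axes are the northern/western cell edges (90, 89.9, ... and -180, -179.9, ...). If the rasters are registered to cell centers, as GeoTIFFs commonly are, the axes could be shifted by half a cell; a sketch under that assumption (my addition):

lat = np.arange(90, -90, -0.1) - 0.05   # cell centers: 89.95 ... -89.95 (1800 values)
lon = np.arange(-180, 180, 0.1) + 0.05  # cell centers: -179.95 ... 179.95 (3600 values)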
# %% Get the valid data
# idx_lon = np.where(np.logical_and(lon_grid >= LonMin, lon_grid <= LonMax))
# idx_lat = np.where(np.logical_and(lat_grid >= LatMin, lat_grid <= LatMax))
# data_in = geomschem[0, idx_lat, idx_lon]
# Identify the grid points inside the specified lat/lon bounds
grid_in = np.logical_and(np.logical_and(lon_grid >= LonMin, lon_grid <= LonMax),\
np.logical_and(lat_grid >= LatMin, lat_grid <= LatMax))
# grid_in.shape: (1800, 3600)
data_in = geomschem[0, grid_in]  # geomschem is 3-D with shape (192, 1800, 3600); take the first time step
lon_in = lon_grid[grid_in]
lat_in = lat_grid[grid_in]
# Compute the number of samples and get the coordinates of the valid pixels
num_date = len(MJD)
# num_date = geomschem.shape[0]
idx_val = ~np.isnan(data_in)  # boolean mask: True where data_in is valid, False at NaNs
# idx_val.size equals the number of grid cells inside the region
lat_val = lat_in[idx_val]
lon_val = lon_in[idx_val]  # lat_val and lon_val are the coordinates of the valid pixels
num_pixel = np.sum(idx_val)  # number of True entries, i.e. the count of valid pixels
num_sample = num_date * num_pixel
num_file = np.ceil(num_sample / (NumBatch*BatchSize)).astype(np.int64)
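For scale: each TFRecord file holds NumBatch*BatchSize = 4*512 = 2048 samples, and num_sample = 192*num_pixel. With a hypothetical 100,000 valid land pixels that is 19,200,000 samples and exactly 9,375 files, so on a global 0.1-degree grid this preprocessing writes a very large number of files.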
# Time-position matrix: one (MJD, lat, lon) row per sample
TimePosition = np.vstack((np.repeat(MJD, num_pixel),\
np.tile(lat_val, num_date),\
np.tile(lon_val, num_date))).T
# Shuffle the sample order with a fixed seed for reproducibility
np.random.seed(1996)
idx_sample_shuffle = np.arange(num_sample)
np.random.shuffle(idx_sample_shuffle)
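The repeat/tile layout of TimePosition is easy to misread, so here is a tiny worked example (my addition, hypothetical values): every valid pixel is listed for the first date, then every pixel for the second date, and so on.

import numpy as np
t = np.array([1., 2.])        # two dates
la = np.array([10., 20.])     # latitudes of two valid pixels
lo = np.array([30., 40.])     # longitudes of the same pixels
TP = np.vstack((np.repeat(t, 2), np.tile(la, 2), np.tile(lo, 2))).T
print(TP)
# [[ 1. 10. 30.]
#  [ 1. 20. 40.]
#  [ 2. 10. 30.]
#  [ 2. 20. 40.]]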
# %% Robust normalization based on percentiles
# Robust min-max normalization based on percentiles:
# use a high percentile (99th / 99.99th) of each variable as its max and a low percentile (1st / 0.01st) as its min
feature_max_list = [np.nanpercentile(geomschem[:, grid_in], 99.00),
np.nanpercentile(rf_predict[:, grid_in], 99.99),
90,
180]  # for latitude and longitude, the max/min are fixed at 90/-90 and 180/-180
feature_min_list = [np.nanpercentile(geomschem[:, grid_in], 1.00),
np.nanpercentile(rf_predict[:, grid_in], 0.01),
-90,
-180]
# Save the normalization scalers
np.savetxt(outpath + Region + "_Scaler_max.csv", np.array(feature_max_list), delimiter=",")
np.savetxt(outpath + Region + "_Scaler_min.csv", np.array(feature_min_list), delimiter=",")
# Normalize each feature: (value - min) / (max - min)
feature_geomschem = (geomschem - feature_min_list[0]) / (feature_max_list[0] - feature_min_list[0])
feature_rf_predict = (rf_predict - feature_min_list[1]) / (feature_max_list[1] - feature_min_list[1])
feature_lat = (lat_grid - feature_min_list[2]) / (feature_max_list[2] - feature_min_list[2])
feature_lon = (lon_grid - feature_min_list[3]) / (feature_max_list[3] - feature_min_list[3])
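Since the scalers are saved to CSV above, undoing the min-max scaling at prediction time is a one-liner; a minimal sketch (my addition; pred_normalized is a hypothetical model output on the geomschem channel):

s_max = np.genfromtxt(outpath + Region + "_Scaler_max.csv", delimiter=",")
s_min = np.genfromtxt(outpath + Region + "_Scaler_min.csv", delimiter=",")
pred_physical = pred_normalized * (s_max[0] - s_min[0]) + s_min[0]  # back to physical units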
# %%
# Write the shuffled samples into TFRecord files
for i_file in tqdm(range(num_file)):
    # Each file holds BatchSize*NumBatch samples; the last file takes whatever remains
    idx_st = i_file * BatchSize * NumBatch
    idx_ed = min(idx_st + BatchSize * NumBatch, num_sample)
    num_SampleInFile = idx_ed - idx_st
    # Generate the samples for this file
idx_sample_tmp = idx_sample_shuffle[idx_st:idx_ed]
# For geomschem, we need both normalized and unnormalized data
Sample_geomschem_ori = genSampleMat(geomschem, idx_sample_tmp)
Sample_geomschem = genSampleMat(feature_geomschem, idx_sample_tmp)
# For the others, we just need normalized data
Sample_rf_predict = genSampleMat(feature_rf_predict, idx_sample_tmp)
Sample_Lat = genSampleMat_TemporalInvariant(feature_lat, idx_sample_tmp)
Sample_Lon = genSampleMat_TemporalInvariant(feature_lon, idx_sample_tmp)
    # Set the file name with a zero-padded four-digit index
    fname = outpath + Region + "_" + str(i_file + 1).zfill(4) + ".tfrecords"
# Write the data to TFRecord
with tf.io.TFRecordWriter(fname) as file_writer:
for i_sample in range(num_SampleInFile):
record_bytes = tf.train.Example(features=tf.train.Features(feature={
"geomschem_ori": tf.train.Feature(float_list=tf.train.FloatList(value=Sample_geomschem_ori[i_sample, :, :].reshape(-1))),
"geomschem": tf.train.Feature(float_list=tf.train.FloatList(value=Sample_geomschem[i_sample, :, :].reshape(-1))),
"rf_predict": tf.train.Feature(float_list=tf.train.FloatList(value=Sample_rf_predict[i_sample, :, :].reshape(-1))),
"Lat": tf.train.Feature(float_list=tf.train.FloatList(value=Sample_Lat[i_sample, :, :].reshape(-1))),
"Lon": tf.train.Feature(float_list=tf.train.FloatList(value=Sample_Lon[i_sample, :, :].reshape(-1))),
"Shape": tf.train.Feature(int64_list = tf.train.Int64List(value=Sample_geomschem[i_sample, :, :].shape))
})).SerializeToString()
file_writer.write(record_bytes)
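For completeness, a sketch of how these records could be read back with tf.data (my addition; the original repo's train.py may parse them differently):

feature_spec = {
    key: tf.io.FixedLenFeature([PatchSize * PatchSize], tf.float32)
    for key in ["geomschem_ori", "geomschem", "rf_predict", "Lat", "Lon"]
}
feature_spec["Shape"] = tf.io.FixedLenFeature([2], tf.int64)

def parse_example(record):
    ex = tf.io.parse_single_example(record, feature_spec)
    return {k: tf.reshape(ex[k], [PatchSize, PatchSize]) for k in ex if k != "Shape"}

ds = tf.data.TFRecordDataset([fname]).map(parse_example).batch(BatchSize)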
# Error
Sample_geomschem_ori = genSampleMat(geomschem, idx_sample_tmp)
C:\Windows\TEMP/ipykernel_16548/628079931.py:43: RuntimeWarning: Mean of empty slice
Patch_mean = np.nanmean(Patch_tmp)
Traceback (most recent call last):
File "C:\Windows\TEMP/ipykernel_16548/3737866055.py", line 1, in <module>
Sample_geomschem_ori = genSampleMat(geomschem, idx_sample_tmp)
File "C:\Windows\TEMP/ipykernel_16548/628079931.py", line 49, in genSampleMat
    Sample_tmp[i_sample, :, :] = Patch_tmp  # store the processed patch
ValueError: could not broadcast input array from shape (32,31) into shape (32,32)
I checked: the failure is in the first function, genSampleMat(data, idx_sample_tmp), at the patch-extraction line:
def genSampleMat(data, idx_sample_tmp):
    ...
            Patch_tmp = data[idx_time_tmp, idx_lat_tmp[0]:idx_lat_tmp[-1]+1, idx_lon_tmp[0]:idx_lon_tmp[-1]+1]

Running that line on its own in the console gives a different error:
Traceback (most recent call last):
File "C:\Windows\TEMP/ipykernel_16548/2507885614.py", line 1, in <module>
Patch_tmp = data[idx_time_tmp, idx_lat_tmp[0]:idx_lat_tmp[-1]+1, idx_lon_tmp[0]:idx_lon_tmp[-1]+1]
TypeError: unhashable type: 'numpy.ndarray'
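For reference, both messages look consistent with the code above. The ValueError says one patch came out (32, 31) instead of (32, 32): with lon = np.arange(-180, 180, 0.1) the 0.1-degree steps are not exactly representable, so the strict test abs(lon - lon_c) <= PatchSize/2*Resolution can lose a boundary point for some centers, selecting 32 longitudes instead of 33 (31 remain after the first column is trimmed). A small tolerance in the selection should be safe; a sketch of the change (my suggestion, not from the original repo):

eps = Resolution / 10.0  # tolerance far below one grid step
idx_lat_tmp = np.where(np.abs(lat - lat_c) <= PatchSize / 2 * Resolution + eps)[0]
idx_lon_tmp = np.where(np.abs(lon - lon_c) <= PatchSize / 2 * Resolution + eps)[0]

The TypeError, by contrast, only appears when the line is run on its own: in the console, data is the dict returned by load_and_stack_data (inside genSampleMat, data is the array passed as an argument), and indexing a dict with a tuple that contains NumPy arrays raises "unhashable type: 'numpy.ndarray'". So the second error is an artifact of re-running the line outside the function, not a separate bug.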