Sat- nerf深度损失

发布于:2025-03-04 ⋅ 阅读:(19) ⋅ 点赞:(0)

首先损失函数定义在metrics.py,代码如下:

class DepthLoss(torch.nn.Module):
    def __init__(self, lambda_ds=1.0):
        super().__init__()
        # 初始化lambda_ds参数,用于调节深度损失的权重,并且将其缩小为原来的1/3
        self.lambda_ds = lambda_ds / 3.
        # 初始化均方误差损失函数(MSELoss),并设置reduce=False表示不对损失值进行平均
        self.loss = torch.nn.MSELoss(reduce=False)

    def forward(self, inputs, targets, weights=1.):
        # 创建一个字典,用来存储不同类型的深度损失
        loss_dict = {}

        # 默认使用'coarse'(粗略)类型
        typ = 'coarse'
        # 计算输入深度与目标深度之间的损失,并将结果存入字典
        loss_dict[f'{typ}_ds'] = self.loss(inputs['depth_coarse'], targets)

        # 如果输入中包含'fine'(精细)类型的深度数据
        if 'depth_fine' in inputs:
            typ = 'fine'
            # 计算精细深度的损失,并存入字典
            loss_dict[f'{typ}_ds'] = self.loss(inputs['depth_fine'], targets)

        # 对每个损失项应用权重
        for k in loss_dict.keys():
            # 计算加权的平均损失,并乘以lambda_ds来调整损失的权重
            loss_dict[k] = self.lambda_ds * torch.mean(weights * loss_dict[k])

        # 计算所有损失项的总和
        loss = sum(l for l in loss_dict.values())

        # 返回总损失以及包含各个深度损失的字典
        return loss, loss_dict

需要三个输入inputs, targets, weights=1(inputs为输入深度,target为gt,weight为权重),得到两个输出loss, loss_dict(loss为总和,loss_dict为记录单个损失的字典)。

当main函数运行到

system = NeRF_pl(args)  # 初始化 NeRF 模型系统,传入配置参数,为模型训练做好准备工作,确保所有需要的配置和资源都已经到位。

开始调用Nerf_pl 的_init_,会在其中实例化 DepthLoss 类:

self.depth_loss = DepthLoss(lambda_ds=args.ds_lambda)  # 初始化深度损失对象,传入深度监督系数

当运行trainer.fit(system)时,训练启动:

当执行trainer.fit(system)时,Lightning接管了训练过程,
Lightning首先调用prepare_data()准备数据集
然后调用configure_optimizers()设置优化器和学习率调度器

训练循环:

Lightning自动开始训练循环,每个epoch包含:

训练步骤: Lightning自动从train_dataloader()加载数据:

    def train_dataloader(self):
        """创建并返回训练数据加载器字典

        根据配置参数创建不同模态(颜色/深度)的训练数据加载器。当self.depth为True时,
        会同时创建颜色数据和深度数据的加载器。数据加载器使用4个工作进程进行数据加载,
        启用内存锁页(pin_memory)以加速GPU数据传输,并自动进行批次数据打乱。

        Returns:
            dict: 包含数据加载器的字典,键为模态名称("color"/"depth"),
                值为对应的torch.utils.data.DataLoader实例
        """
        # 创建颜色数据的训练集加载器(第一个数据集)
        a = DataLoader(self.train_dataset[0],
                       shuffle=True,
                       num_workers=4,
                       batch_size=self.args.batch_size,
                       pin_memory=True)
        loaders = {"color": a}

        # 当需要加载深度数据时,创建第二个数据加载器
        if self.depth:
            b = DataLoader(self.train_dataset[1],#数据从上面dataloade 的self.train_dataset[1],这是一个SatelliteDataset_depth类的实例,在prepare_data()方法中创建
                           shuffle=True,
                           num_workers=4,
                           batch_size=self.args.batch_size,
                           pin_memory=True)
            loaders["depth"] = b#通过batch["depth"]访问

        return loaders

可以看到深度数据从上面prepare_data()方法中创建的self.train_dataset[1],这是一个SatelliteDataset_depth类的实例。
接下来,调用training_step()处理每个批次:

    def training_step(self, batch, batch_nb):
        self.log("lr", train_utils.get_learning_rate(self.optimizer))
        self.train_steps += 1

        rays = batch["color"]["rays"] # (B, 11)
        rgbs = batch["color"]["rgbs"] # (B, 3)
        ts = None if not self.use_ts else batch["color"]["ts"].squeeze() # (B, 1)

        results = self(rays, ts)
        if 'beta_coarse' in results and self.get_current_epoch(self.train_steps) < 2:
            loss, loss_dict = self.loss_without_beta(results, rgbs)
        else:
            loss, loss_dict = self.loss(results, rgbs)
        self.args.noise_std *= 0.9

        if self.depth:
            tmp = self(batch["depth"]["rays"], batch["depth"]["ts"].squeeze())
            kp_depths = torch.flatten(batch["depth"]["depths"][:, 0])
            kp_weights = 1. if self.args.ds_noweights else torch.flatten(batch["depth"]["depths"][:, 1])
            loss_depth, tmp = self.depth_loss(tmp, kp_depths, kp_weights)#tmp是作为imput输入进去了,kp_depths是target,kp_weights是权重
            if self.train_steps < self.ds_drop :
                loss += loss_depth
            for k in tmp.keys():
                loss_dict[k] = tmp[k]

        self.log("train/loss", loss)
        typ = "fine" if "rgb_fine" in results else "coarse"

        with torch.no_grad():
            psnr_ = metrics.psnr(results[f"rgb_{typ}"], rgbs)
            self.log("train/psnr", psnr_)
        for k in loss_dict.keys():
            self.log("train/{}".format(k), loss_dict[k])

        self.log('train_psnr', psnr_, on_step=True, on_epoch=True, prog_bar=True)
        return {'loss': loss}

而在这中间 results = self(rays, ts)相当于就是隐式调用了Nerf_pl的forward算法,另外我们可以看到我们的目标self.depth_loss被调用了,tmp是作为imput输入进去了,kp_depths是target,kp_weights是权重。
那接下来我们一个一个分析这三个输入都是从哪来的。

预测深度(tmp)的完整产生过程

1. 初始数据:深度射线信息

# 在training_step中获取的初始数据,来自于train_dataloader(),
rays = batch["depth"]["rays"]  # 形状为[B, 11]的张量
ts = batch["depth"]["ts"].squeeze()  # 时间戳信息

这里的rays包含11个通道的信息:

  • rays[:, 0:3]: 射线原点坐标
  • rays[:, 3:6]: 射线方向向量
  • rays[:, 6:7]: 近平面距离
  • rays[:, 7:8]: 远平面距离
  • rays[:, 8:11]: 太阳方向向量

2. 模型前向传递:self方法调用

# training_step中的调用
tmp = self(batch["depth"]["rays"], batch["depth"]["ts"].squeeze())

这调用了NeRF_plforward方法(在main.py中):

def forward(self, rays, ts):
    chunk_size = self.args.chunk
    batch_size = rays.shape[0]
    results = defaultdict(list)
    
    for i in range(0, batch_size, chunk_size):
        rendered_ray_chunks = render_rays(self.models, self.args, 
                                          rays[i:i + chunk_size],
                                          ts[i:i + chunk_size] if ts is not None else None)
        
        for k, v in rendered_ray_chunks.items():
            results[k] += [v]
            
    for k, v in results.items():
        results[k] = torch.cat(v, 0)
        
    return results

该方法将射线分成小块,处理后合并结果。

3. 在forward中调用了render_rays函数:渲染逻辑

进入rendering.py中的render_rays函数:

def render_rays(models, args, rays, ts):
    # 获取配置参数
    N_samples = args.n_samples  
    N_importance = args.n_importance
    variant = args.model  # "sat-nerf"
    
    # 分解射线信息
    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]
    near, far = rays[:, 6:7], rays[:, 7:8]
    
    # 采样深度点
    z_steps = torch.linspace(0, 1, N_samples, device=rays.device)
    z_vals = near * (1-z_steps) + far * z_steps  # 线性采样
    
    # 添加随机扰动
    if perturb > 0:
        # 采样点添加扰动代码...
    
    # 计算3D采样点坐标
    xyz_coarse = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals.unsqueeze(2)
    
    # 根据模型类型调用相应的inference函数
    if variant == "sat-nerf":
        from models.satnerf import inference
        sun_d = rays[:, 8:11]
        rays_t = models['t'](ts) if ts is not None else None
        result = inference(models["coarse"], args, xyz_coarse, z_vals, 
                           rays_d=None, sun_d=sun_d, rays_t=rays_t)
        # 太阳光校正相关代码...
    
    # 组织结果
    result_ = {}
    for k in result.keys():
        result_[f"{k}_coarse"] = result[k]
    
    # 如果需要精细采样
    if N_importance > 0:
        # 精细采样相关代码...
    
    return result_

这个函数处理射线,生成采样点,并调用合适的模型inference函数。

4. 又在render_rays中调用satnerf的inference函数

进入models/satnerf.pyinference函数:

def inference(model, args, rays_xyz, z_vals, rays_d=None, sun_d=None, rays_t=None):
    N_rays = rays_xyz.shape[0]
    N_samples = rays_xyz.shape[1]
    xyz_ = rays_xyz.view(-1, 3)  # 展平为[N_rays*N_samples, 3]
    
    # 处理额外输入
    rays_d_ = None if rays_d is None else torch.repeat_interleave(rays_d, repeats=N_samples, dim=0)
    sun_d_ = None if sun_d is None else torch.repeat_interleave(sun_d, repeats=N_samples, dim=0)
    rays_t_ = None if rays_t is None else torch.repeat_interleave(rays_t, repeats=N_samples, dim=0)
    
    # 分块运行NeRF模型
    chunk = args.chunk
    batch_size = xyz_.shape[0]
    
    out_chunks = []
    for i in range(0, batch_size, chunk):
        out_chunks += [model(xyz_[i:i+chunk],
                           input_dir=None if rays_d_ is None else rays_d_[i:i + chunk],
                           input_sun_dir=None if sun_d_ is None else sun_d_[i:i + chunk],
                           input_t=None if rays_t_ is None else rays_t_[i:i + chunk])]
    out = torch.cat(out_chunks, 0)
    
    # 处理输出
    out = out.view(N_rays, N_samples, model.number_of_outputs)
    rgbs = out[..., :3]       # 颜色
    sigmas = out[..., 3]      # 体密度
    # 其他输出处理...
    
    # 计算alpha合成权重
    deltas = z_vals[:, 1:] - z_vals[:, :-1]
    delta_inf = 1e10 * torch.ones_like(deltas[:, :1])
    deltas = torch.cat([deltas, delta_inf], -1)
    
    noise_std = args.noise_std
    noise = torch.randn(sigmas.shape, device=sigmas.device) * noise_std
    alphas = 1 - torch.exp(-deltas * torch.relu(sigmas + noise))
    alphas_shifted = torch.cat([torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-10], -1)
    transparency = torch.cumprod(alphas_shifted, -1)[:, :-1]
    weights = alphas * transparency
    
    # 计算深度
    depth_final = torch.sum(weights * z_vals, -1)  # 关键:计算加权平均深度
    
    # 组织返回结果
    result = {'rgb': rgb_final,
              'depth': depth_final,  # 这是最终的预测深度
              'weights': weights,
              # 其他输出...
             }
    return result

这个函数调用实际的NeRF模型,处理采样点,并生成体密度和颜色,最后计算深度。

5.inference调用SatNeRF模型

进入satnerf.py中的SatNeRF类的forward方法:

def forward(self, input_xyz, input_dir=None, input_sun_dir=None, input_t=None, sigma_only=False):
    # 将输入坐标送入映射层
    input_xyz = self.mapping[0](input_xyz)
    
    # 通过全连接网络处理
    xyz_ = input_xyz
    for i in range(self.layers):
        if i in self.skips:
            xyz_ = torch.cat([input_xyz, xyz_], -1)
        xyz_ = self.fc_net[2*i](xyz_)
        xyz_ = self.fc_net[2*i + 1](xyz_)
    
    # 获取共享特征
    shared_features = xyz_
    
    # 预测体密度(sigma)
    sigma = self.sigma_from_xyz(shared_features)
    if sigma_only:
        return sigma
    
    # 预测颜色和其他属性
    xyz_features = self.feats_from_xyz(shared_features)
    # RGB预测...
    
    # 太阳可见度和天空颜色预测...
    
    # 预测不确定度参数beta
    input_for_beta = torch.cat([xyz_features, input_t], -1)
    beta = self.beta_from_xyz(input_for_beta)
    
    # 组合所有输出
    out = torch.cat([rgb, sigma, sun_v, sky_color, beta], 1)
    
    return out

这个神经网络预测每个点的体密度、颜色和其他属性。

6. 从体密度到深度的转换

回到inference函数中,关键的深度计算步骤:

# 计算alpha合成权重
alphas = 1 - torch.exp(-deltas * torch.relu(sigmas + noise))
# ...处理alphas...
weights = alphas * transparency  # 体密度权重

# 预测深度:沿射线的加权平均深度
depth_final = torch.sum(weights * z_vals, -1)  # [N_rays]

体密度通过alpha合成转换为权重,然后用这些权重计算加权平均深度。

7. 结果整合与返回

最终在render_rays函数中:

result_ = {}
for k in result.keys():
    result_[f"{k}_coarse"] = result[k]

# 深度结果作为'depth_coarse'返回

这就形成了tmp中的'depth_coarse'键,即粗略深度预测。

完整数据流向总结

  1. 初始数据: batch["depth"]["rays"]batch["depth"]["ts"] (来自数据加载器)
  2. NeRF_pl.forward: 将射线分块并调用render_rays
  3. render_rays: 处理射线,生成采样点,调用satnerf.inference
  4. satnerf.inference:
    • 处理采样点
    • 调用SatNeRF模型预测体密度
    • 将体密度转换为权重
    • 计算加权平均深度
  5. SatNeRF.forward: 神经网络计算每个点的体密度
  6. 结果整合: 深度值被组织为'depth_coarse'键返回

这就是tmp中预测深度的完整产生过程,从最初的射线输入到最终的深度预测值。

SatNeRF中真实深度(kp_depths)的完整产生过程

下面我详细梳理从原始数据到最终在training_step中使用的kp_depths的完整生成过程,按照数据流向逐步说明:

1. 原始数据:卫星图像和元数据

最初的原始数据包括:

  • 卫星图像(RGB图像)
  • JSON元数据文件,包含相机参数(RPC)和关键点信息

这些数据存储在args.root_dirargs.img_dir指定的目录中。

2. 数据集创建:load_dataset函数

prepare_data方法中通过调用load_dataset函数加载数据:

# main.py中的prepare_data方法
def prepare_data(self):
    self.train_dataset = [] + load_dataset(self.args, split="train")
    self.val_dataset = [] + load_dataset(self.args, split="val")

load_dataset函数(在datasets/init.py中)会根据数据类型创建不同的数据集:

def load_dataset(args, split):
    if args.data == 'sat':
        ds_list = []
        from .satellite import SatelliteDataset
        ds_list.append(SatelliteDataset(args.root_dir, args.img_dir, split, img_downscale=args.img_downscale, cache_dir=args.cache_dir))
        
        if split == 'train' and args.ds_lambda > 0:
            from .satellite_depth import SatelliteDataset_depth
            ds_list.append(SatelliteDataset_depth(args.root_dir, args.img_dir, split, img_downscale=args.img_downscale, cache_dir=args.cache_dir))
        
        return ds_list
    # ...

args.ds_lambda > 0时,会创建SatelliteDataset_depth类的实例作为深度数据集。

3. 深度数据集初始化:SatelliteDataset_depth

SatelliteDataset_depth__init__方法中处理关键点数据:

# datasets/satellite_depth.py
def __init__(self, root_dir, img_dir, split='train', img_downscale=1.0, cache_dir=None):
    super().__init__(root_dir, img_dir, split, img_downscale, cache_dir)

4. 如果split == “train”,调用load_train_split()方法

tie_points是通过多视图几何(MVG)从多张卫星图像中三角测量得到的3D点坐标。

def load_train_split(self):
    with open(os.path.join(self.json_dir, "train.txt"), "r") as f:
        json_files = f.read().split("\n")
    self.json_files = [os.path.join(self.json_dir, json_p) for json_p in json_files]
    if os.path.exists(self.json_dir + "/pts3d.npy"):
        self.tie_points = np.load(self.json_dir + "/pts3d.npy")
        self.all_rays, self.all_depths, self.all_ids = self.load_depth_data(self.json_files, self.tie_points, verbose=True)
    else:
        raise FileNotFoundError("Could not find {}".format(self.json_dir + "/pts3d.npy"))

读取训练图像列表
加载3D关键点(tie_points)
调用load_depth_data方法处理深度数据

5.从4中调用load_depth_data处理深度数据

load_depth_data方法是生成真实深度数据的核心:

def load_depth_data(self, json_files, tie_points, verbose=False):
    all_rays, all_depths, all_sun_dirs, all_weights = [], [], [], []
    all_ids = []
    kp_weights = self.load_keypoint_weights_for_depth_supervision(json_files, tie_points)

    for t, json_p in enumerate(json_files):
        # 读取JSON数据
        d = sat_utils.read_dict_from_json(json_p)
        img_id = sat_utils.get_file_id(d["img"])

        # 获取关键点信息
        pts2d = np.array(d["keypoints"]["2d_coordinates"])/ self.img_downscale
        pts3d = np.array(tie_points[d["keypoints"]["pts3d_indices"], :])
        rpc = sat_utils.rescale_rpc(rpcm.RPCModel(d["rpc"], dict_format="rpcm"), 1.0 / self.img_downscale)

        # 生成射线
        cols, rows = pts2d.T
        min_alt, max_alt = float(d["min_alt"]), float(d["max_alt"])
        rays = get_rays(cols, rows, rpc, min_alt, max_alt)
        rays = self.normalize_rays(rays)
        all_rays += [rays]

        # 获取太阳方向
        sun_dirs = self.get_sun_dirs(float(d["sun_elevation"]), float(d["sun_azimuth"]), rays.shape[0])
        all_sun_dirs += [sun_dirs]

        # 标准化3D坐标
        pts3d = torch.from_numpy(pts3d).type(torch.FloatTensor)
        pts3d[:, 0] -= self.center[0]
        pts3d[:, 1] -= self.center[1]
        pts3d[:, 2] -= self.center[2]
        pts3d[:, 0] /= self.range
        pts3d[:, 1] /= self.range
        pts3d[:, 2] /= self.range

        # 计算深度值
        depths = torch.linalg.norm(pts3d - rays[:, :3], axis=1)
        all_depths += [depths[:, np.newaxis]]
        
        # 获取权重
        current_weights = torch.from_numpy(kp_weights[d["keypoints"]["pts3d_indices"]]).type(torch.FloatTensor)
        all_weights += [current_weights[:, np.newaxis]]
        
        all_ids += [t * torch.ones(rays.shape[0], 1)]

    # 组合所有数据
    all_ids = torch.cat(all_ids, 0)
    all_rays = torch.cat(all_rays, 0)  # (len(json_files)*h*w, 8)
    all_depths = torch.cat(all_depths, 0)  # (len(json_files)*h*w, 1)
    all_weights = torch.cat(all_weights, 0)
    all_depths = torch.hstack([all_depths, all_weights])  # 深度和权重合并
    all_sun_dirs = torch.cat(all_sun_dirs, 0)  # (len(json_files)*h*w, 3)
    all_rays = torch.hstack([all_rays, all_sun_dirs])  # (len(json_files)*h*w, 11)
    
    return all_rays, all_depths, all_ids

该方法处理每个JSON文件的关键点,计算射线和深度值,并整合成训练数据。

6. 在5中调用了计算关键点权重:load_keypoint_weights_for_depth_supervision方法

def load_keypoint_weights_for_depth_supervision(self, json_files, tie_points):
    # 初始化权重数组
    kp_weights = np.ones(len(tie_points))
    
    # 收集所有关键点的2D-3D对应关系
    all_obs = {}
    for json_p in json_files:
        with open(json_p) as f:
            d = json.load(f)
        
        if "keypoints" in d.keys():
            # 获取RPC模型
            rpc = rpcm.RPCModel(d["rpc"], dict_format="rpcm")
            # 获取2D坐标和对应的3D索引
            pts2d = np.array(d["keypoints"]["2d_coordinates"]) / self.img_downscale
            pts3d_indices = d["keypoints"]["pts3d_indices"]
            
            # 收集观察
            for i, idx in enumerate(pts3d_indices):
                if idx not in all_obs:
                    all_obs[idx] = []
                all_obs[idx].append((json_p, pts2d[i], rpc))
    
    # 计算每个3D点的重投影误差
    for idx, obs_list in all_obs.items():
        if len(obs_list) >= 2:  # 至少需要2个观察
            # 计算重投影误差
            reproj_err = compute_reprojection_error(tie_points[idx], obs_list)
            # 根据重投影误差设置权重
            kp_weights[idx] = compute_keypoint_weight(reproj_err)
    
    return kp_weights

这个方法计算每个3D关键点的权重,基于其重投影误差。重投影误差越小,权重越大。

7. 数据集的__getitem__方法:准备批次数据

__getitem__方法定义了如何访问数据集中的一个样本:

def __getitem__(self, idx):
    # 获取训练样本
    if self.train:
        sample = {"rays": self.all_rays[idx], "depths": self.all_depths[idx], "ts": self.all_ids[idx].long()}
    else:
        # 验证集处理...
    return sample

对于训练集,直接返回预先计算好的射线、深度和时间戳信息。
如果把数据集想象成一本书:

__init__相当于准备整本书和目录
__len__告诉你书有多少页
__getitem__允许你翻到任意一页并读取内容

DataLoader就像是一个阅读助手,它会按照你指定的顺序(随机或顺序)一次翻几页(批次大小)给你看。

8. 数据加载器:train_dataloader方法

main.pytrain_dataloader方法中创建数据加载器:

def train_dataloader(self):
    a = DataLoader(self.train_dataset[0],
                   shuffle=True,
                   num_workers=4,
                   batch_size=self.args.batch_size,
                   pin_memory=True)
    loaders = {"color": a}
    if self.depth:
        b = DataLoader(self.train_dataset[1],
                       shuffle=True,
                       num_workers=4,
                       batch_size=self.args.batch_size,
                       pin_memory=True)
        loaders["depth"] = b
    return loaders

这个方法创建数据加载器,将SatelliteDataset_depth的数据作为loaders["depth"]返回。

9. 训练步骤:training_step方法

最后,在training_step方法中获取真实深度:

def training_step(self, batch, batch_nb):
    # ...
    
    if self.depth:
        tmp = self(batch["depth"]["rays"], batch["depth"]["ts"].squeeze())
        
        # 获取真实深度数据
        kp_depths = torch.flatten(batch["depth"]["depths"][:, 0])  # 第一列是深度值
        kp_weights = 1. if self.args.ds_noweights else torch.flatten(batch["depth"]["depths"][:, 1])  # 第二列是权重
        
        # 计算深度损失
        loss_depth, tmp = self.depth_loss(tmp, kp_depths, kp_weights)
        # ...

batch["depth"]["depths"][:, 0]就是我们要找的真实深度值kp_depths

完整数据流向总结

初始数据:3D关键点坐标,存储在/pts3d.npy文件中
数据加载:load_train_split方法加载关键点并调用load_depth_data
深度计算:

从JSON文件读取2D对应关系和相机参数
计算相机射线
标准化3D点坐标
计算3D点到射线原点的距离作为深度值

权重计算:基于重投影误差计算每个点的可靠性权重
数据组织:将深度值和权重组合成all_depths张量
批次访问:通过__getitem__方法访问预计算的深度数据
训练使用:在training_step中,kp_depths从batch[“depth”][“depths”][:, 0]提取

回到training_step()

经过

        if self.depth:
            tmp = self(batch["depth"]["rays"], batch["depth"]["ts"].squeeze())
            kp_depths = torch.flatten(batch["depth"]["depths"][:, 0])
            kp_weights = 1. if self.args.ds_noweights else torch.flatten(batch["depth"]["depths"][:, 1])
            loss_depth, tmp = self.depth_loss(tmp, kp_depths, kp_weights)

就已经拿到了最后的损失了。

验证步骤: 定期调用val_dataloader()加载验证数据,并执行validation_step()

前向传播:

在training_step()和validation_step()中,代码调用了self(rays, ts)
这实际上是隐式调用了forward()方法,因为在Python中,当一个类实例被当作函数调用时,会自动调用其__call__方法,而Lightning模型的__call__会调用forward()

钩子函数:

Lightning通过一系列"钩子函数"(如training_step, validation_step等)自动组织训练流程
只需实现这些钩子函数,而不需要手动调用它们