Cesium开发--自定义Primitive实现3D Gaussian Splatting渲染

发布于:2025-02-21 ⋅ 阅读:(14) ⋅ 点赞:(0)

源码地址:gitee github

自定义GaussianSplatPrimitive

GaussianSplatPrimitive类定义

class GaussianSplatPrimitive {
  constructor(options) {

    this.vertexArray = options.vertexArray;
    this.uniformMap = options.uniformMap;
    this.vertexShaderSource = options.vertexShaderSource;
    this.fragmentShaderSource = options.fragmentShaderSource;
    this.renderState = options.renderState;
    this.modelMatrix = options.modelMatrix;
    this.instanceCount = options.instanceCount;
    //this.framebuffer = options.framebuffer;
    this.show = true;
    this.commandToExecute = undefined;
  }

  createCommand(context) {
    let shaderProgram = Cesium.ShaderProgram.fromCache({
      context: context,
      attributeLocations: this.vertexArray.attributes,
      vertexShaderSource: this.vertexShaderSource,
      fragmentShaderSource: this.fragmentShaderSource,
      debugShaders: true,
      logShaderCompilation: true,
    });

    let cachedRenderState = Cesium.RenderState.fromCache(this.renderState);
    return new Cesium.DrawCommand({
      owner: this,
      vertexArray: this.vertexArray,
      primitiveType: Cesium.PrimitiveType.TRIANGLE_FAN, 
      uniformMap: this.uniformMap,
      modelMatrix: this.modelMatrix,
      instanceCount: this.instanceCount,
      shaderProgram: shaderProgram,
      //framebuffer: this.framebuffer,
      renderState: cachedRenderState,
      pass: Cesium.Pass.TRANSLUCENT,
      //pass: Cesium.Pass.OPAQUE
    });
  }

  update(frameState) {
    if (!this.show) {
      return;
    }
    if (!Cesium.defined(this.commandToExecute)) {
      this.commandToExecute = this.createCommand(frameState.context);
    }
    frameState.commandList.push(this.commandToExecute);
  }

  isDestroyed() {
    return false;
  }

  destroy() {
    if (Cesium.defined(this.commandToExecute)) {
      this.commandToExecute.shaderProgram =
        this.commandToExecute.shaderProgram &&
        this.commandToExecute.shaderProgram.destroy();
    }
    return Cesium.destroyObject(this);
  }
}

splat数据获取与处理

获取buffer

可以通过fetch方式获取splat文件的内容buffer。

计算高斯点个数

计算公式:buffer长度/每个高斯点的长度(每行存储一个高斯点数据)。

每行的数据内容和字节占用情况

  • 数据内容及数据类型
    // XYZ - Position (Float32)
    // XYZ - Scale (Float32)
    // RGBA - colors (uint8)
    // IJKL - quaternion/rot (uint8)
  • 数据类型及字节(Byte)和位(Bit)占用
    Float32,单精度浮点数,32位,4字节(每个字节由8位组成)。
    Uint8,8 位无符号整型,8位,1字节。
const RowSizeBytes = 3 * 4 + 3 * 4 + 4 + 4;
const splatCount = buffer.byteLength / RowSizeBytes;

高斯点排序和索引

计算每个点的深度,并按深度排序。
深度值是使用模型视图矩阵与position属性进行计算的。

function runSort(buffer, vertexCount, viewProj) {
  if (!buffer) return;
  if (viewProj.equals(Cesium.Matrix4.IDENTITY)) return;
  
  const f_buffer = new Float32Array(buffer);
  // 初始化最大深度、最小深度和顶点大小列表
  let maxDepth = -Infinity;
  let minDepth = Infinity;
  let sizeList = new Int32Array(vertexCount);
  // 遍历顶点,计算深度并更新最大深度和最小深度
  for (let i = 0; i < vertexCount; i++) {
    let depth =
      ((viewProj[2] * f_buffer[8 * i + 0] +
        viewProj[6] * f_buffer[8 * i + 1] +
        viewProj[10] * f_buffer[8 * i + 2]) *
        4096) |
      0;
    sizeList[i] = depth;
    if (depth > maxDepth) maxDepth = depth;
    if (depth < minDepth) minDepth = depth;
  }

  // This is a 16 bit single-pass counting sort
  let depthInv = (256 * 256) / (maxDepth - minDepth);
  // 初始化计数数组
  let counts0 = new Uint32Array(256 * 256);
  // 遍历顶点,计算索引并更新计数数组
  for (let i = 0; i < vertexCount; i++) {
    sizeList[i] = ((sizeList[i] - minDepth) * depthInv) | 0;
    counts0[sizeList[i]]++;
  }
  // 初始化起始位置数组
  let starts0 = new Uint32Array(256 * 256);
  // 计算起始位置
  for (let i = 1; i < 256 * 256; i++)
    starts0[i] = starts0[i - 1] + counts0[i - 1];
    // 初始化深度索引数组
  depthIndex = new Uint32Array(vertexCount);
  // 遍历顶点,根据计数数组和起始位置数组更新深度索引数组
  for (let i = 0; i < vertexCount; i++) 
      depthIndex[starts0[sizeList[i]]++] = i;
  return [depthIndex, viewProj, vertexCount];
}

参数设置

let gaussianSplatPrimitive = new GaussianSplatPrimitive({
      vertexArray: splatsVertexArray,
      uniformMap: uniformMap,
      modelMatrix: modelMatrix,  
      vertexShaderSource: vertexSource,
      fragmentShaderSource: fragmentSource,
      renderState: splatRenderState,
      instanceCount: vertexCount,
    });

vertexArray

splat顶点

每个高斯点用一个矩形来表示,所以顶点数为4。

const triangleVertices = new Float32Array([-2, -2, 2, -2, 2, 2, -2, 2]);
    const triBuffer = Cesium.Buffer.createVertexBuffer({
      context: InWebGLContext,
      typedArray: triangleVertices,
      usage: Cesium.BufferUsage.DYNAMIC_DRAW,
    });

splat排序索引

cloudIndexPnts为高斯点的排序数据,使用Uint32Array类型。
注意,不是顶点索引。

const orderedSplatIndices = Cesium.Buffer.createVertexBuffer({
      context: InWebGLContext,
      typedArray: cloudIndexPnts,
      usage: Cesium.BufferUsage.DYNAMIC_DRAW,
    });

属性

顶点属性、排序索引属性。

const attributes = [
      {
        index: 0,
        enabled: true,
        vertexBuffer: triBuffer,
        componentsPerAttribute: 2,
        componentDatatype: Cesium.ComponentDatatype.FLOAT,
        normalize: false,
        offsetInBytes: 0,
        strideInBytes: 0,
        instanceDivisor: 0,
      },
      {
        index: 1,
        enabled: true,
        vertexBuffer: orderedSplatIndices, 
        componentsPerAttribute: 1,
        componentDatatype: Cesium.ComponentDatatype.INT,
        normalize: false,
        offsetInBytes: 0,
        strideInBytes: 0,
        bindAsInteger: true, 
        instanceDivisor: 1,
      },
    ];

定义Cesium.VertexArray

 const splatsVertexArray = new Cesium.VertexArray({
      context: InWebGLContext,
      attributes: attributes,
    });

uniformMap

参数定义如代码所示,重点是纹理的生成。

let uniformMap = {
      u_texture: function () {
        return tempSplatTexture;
      },
      focal: function () {
        return new Cesium.Cartesian2(1000, 1000);
      },
    };

生成纹理buffer

将 .splat 文件缓冲区转换为纹理(也许这个纹理文件应该是原生格式,因为它很容易将其加载到 WebGL 中)。
转换过程:
对高斯点进行遍历,对每个点的数据(位置、颜色、旋转、缩放)读取、处理、转换并存放到新的变量中。最终的buffer中,包括位置XYZ、颜色RGBA、协方差。

  • 位置、颜色
    按索引获取和存储数据,无需特别处理。
  • 协方差
    • 获取缩放scale和旋转rotation属性数据,并对rotation进行归一化处理(四元数归一化)。
    • 计算协方差sigma,sigma=R * S(R对应rotation,S对应scale)。
    • 通过packHalf2x16将sigma数据处理成uint类型。

这里用到了数据类型 Float32ArrayUint8ArrayUint32Array

packHalf2x16
将2个32位的小数包装为2个16位的小数然后再包装为1个32位的uint数据。

function generateTexture(buffer, vertexCount) {
  if (!buffer) return;

  const f_buffer = new Float32Array(buffer);
  const u_buffer = new Uint8Array(buffer);

  let texwidth = 1024 * 2; // Set to your desired width
  let texheight = Math.ceil((2 * vertexCount) / texwidth); // Set to your desired height

  let texdata = new Uint32Array(texwidth * texheight * 4); // 4 components per pixel (RGBA)
  let texdata_c = new Uint8Array(texdata.buffer);
  let texdata_f = new Float32Array(texdata.buffer);

  // Here we convert from a .splat file buffer into a texture
  // With a little bit more foresight perhaps this texture file
  // should have been the native format as it'd be very easy to
  // load it into webgl.
  for (let i = 0; i < vertexCount; i++) {
    // x, y, z
    texdata_f[8 * i + 0] = f_buffer[8 * i + 0];
    texdata_f[8 * i + 1] = f_buffer[8 * i + 1];
    texdata_f[8 * i + 2] = f_buffer[8 * i + 2];

    // r, g, b, a
    texdata_c[4 * (8 * i + 7) + 0] = u_buffer[32 * i + 24 + 0];
    texdata_c[4 * (8 * i + 7) + 1] = u_buffer[32 * i + 24 + 1];
    texdata_c[4 * (8 * i + 7) + 2] = u_buffer[32 * i + 24 + 2];
    texdata_c[4 * (8 * i + 7) + 3] = u_buffer[32 * i + 24 + 3];

    // quaternions
    let scale = [
      f_buffer[8 * i + 3 + 0],
      f_buffer[8 * i + 3 + 1],
      f_buffer[8 * i + 3 + 2],
    ];

    let rot = [
      (u_buffer[32 * i + 28 + 0] - 128) / 128,
      (u_buffer[32 * i + 28 + 1] - 128) / 128,
      (u_buffer[32 * i + 28 + 2] - 128) / 128,
      (u_buffer[32 * i + 28 + 3] - 128) / 128,
    ];

    // Compute the matrix product of S and R (M = S * R)
    const M = [
      1.0 - 2.0 * (rot[2] * rot[2] + rot[3] * rot[3]),
      2.0 * (rot[1] * rot[2] + rot[0] * rot[3]),
      2.0 * (rot[1] * rot[3] - rot[0] * rot[2]),

      2.0 * (rot[1] * rot[2] - rot[0] * rot[3]),
      1.0 - 2.0 * (rot[1] * rot[1] + rot[3] * rot[3]),
      2.0 * (rot[2] * rot[3] + rot[0] * rot[1]),

      2.0 * (rot[1] * rot[3] + rot[0] * rot[2]),
      2.0 * (rot[2] * rot[3] - rot[0] * rot[1]),
      1.0 - 2.0 * (rot[1] * rot[1] + rot[2] * rot[2]),
    ].map((k, i) => k * scale[Math.floor(i / 3)]);

    const sigma = [
      M[0] * M[0] + M[3] * M[3] + M[6] * M[6],
      M[0] * M[1] + M[3] * M[4] + M[6] * M[7],
      M[0] * M[2] + M[3] * M[5] + M[6] * M[8],
      M[1] * M[1] + M[4] * M[4] + M[7] * M[7],
      M[1] * M[2] + M[4] * M[5] + M[7] * M[8],
      M[2] * M[2] + M[5] * M[5] + M[8] * M[8],
    ];

    texdata[8 * i + 4] = packHalf2x16(4 * sigma[0], 4 * sigma[1]);
    texdata[8 * i + 5] = packHalf2x16(4 * sigma[2], 4 * sigma[3]);
    texdata[8 * i + 6] = packHalf2x16(4 * sigma[4], 4 * sigma[5]);
  }
  return [texdata, texwidth, texheight];
}

生成Cesium纹理

如果Cesium不支持,需要修改底层源码,例如
Cesium.PixelFormat.RGBA_INTEGER、Cesium.PixelDatatype.UNSIGNED_INT

const tempSplatTexture = new Cesium.Texture({
      context: InWebGLContext,
      width: splatTexWidth,
      height: splatTexHeight,
      pixelFormat: Cesium.PixelFormat.RGBA_INTEGER, 
      pixelDatatype: Cesium.PixelDatatype.UNSIGNED_INT, 
      source: {
        width: splatTexWidth,
        height: splatTexHeight,
        arrayBufferView: texdata,
      },
      flipY: false,
      sampler: new Cesium.Sampler({
        wrapS: Cesium.TextureWrap.CLAMP_TO_EDGE,
        wrapT: Cesium.TextureWrap.CLAMP_TO_EDGE,
        minificationFilter: Cesium.TextureMinificationFilter.NEAREST,
        magnificationFilter: Cesium.TextureMagnificationFilter.NEAREST,
      }),
    });

modelMatrix

模型矩阵变换,例如平移、旋转、缩放,都可以在这里设置。详细参考官方示例或参考我的github示例。

vertexShaderSource

这里使用了cesium内置的uniform,包括czm_modelView、czm_projection、czm_viewport
需要外部传入的有纹理、相机焦距、高斯点位置、索引。

 let vertexShaderSource = `
			#version 300 es
			precision highp float;
			precision highp int;
						
      		uniform highp usampler2D u_texture; 			
			uniform vec2 focal;

			in vec2 position;
			in int index; 

			out vec4 vColor;
			out vec2 vPosition;
      
			void main () {
        		uvec4 cen = texelFetch(u_texture, ivec2((uint(index) & 0x3ffu) << 1, uint(index) >> 10), 0);
				
				vec4 cam = czm_modelView * vec4(uintBitsToFloat(cen.xyz), 1);
				vec4 pos2d = czm_projection * cam; 

				float clip = 1.2 * pos2d.w;
				if (pos2d.z < -clip || pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) {
					gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
					return;
				}

        		uvec4 cov = texelFetch(u_texture, ivec2(((uint(index) & 0x3ffu) << 1) | 1u, uint(index) >> 10), 0);
				vec2 u1 = unpackHalf2x16(cov.x), u2 = unpackHalf2x16(cov.y), u3 = unpackHalf2x16(cov.z);
				mat3 Vrk = mat3(u1.x, u1.y, u2.x, u1.y, u2.y, u3.x, u2.x, u3.x, u3.y);

				mat3 J = mat3(
					focal.x / cam.z, 0., -(focal.x * cam.x) / (cam.z * cam.z), 
					0., -focal.y / cam.z, (focal.y * cam.y) / (cam.z * cam.z), 
					0., 0., 0.
				);

				mat3 T = transpose(mat3(czm_modelView)) * J;
				mat3 cov2d = transpose(T) * Vrk * T;

				float mid = (cov2d[0][0] + cov2d[1][1]) / 2.0;
				float radius = length(vec2((cov2d[0][0] - cov2d[1][1]) / 2.0, cov2d[0][1]));

				float lambda1 = mid + radius, lambda2 = mid - radius;

				if(lambda2 < 0.0) return;
				vec2 diagonalVector = normalize(vec2(cov2d[0][1], lambda1 - cov2d[0][0]));
				vec2 majorAxis = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
				vec2 minorAxis = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);

				vColor = clamp(pos2d.z/pos2d.w+1.0, 0.0, 1.0) * vec4((cov.w) & 0xffu, (cov.w >> 8) & 0xffu, (cov.w >> 16) & 0xffu, (cov.w >> 24) & 0xffu) / 255.0;
				vPosition = position;
				vec2 vCenter = vec2(pos2d) / pos2d.w;
				gl_Position =  vec4(
					vCenter 
					+ position.x * majorAxis / czm_viewport.zw 
					+ position.y * minorAxis / czm_viewport.zw, 0.0, 1.0);

			}
		`.trim();
    let vertexSource = new Cesium.ShaderSource({
      sources: [vertexShaderSource],
    });

fragmentShaderSource

let fragmentShaderSource = `
			#version 300 es
			precision highp float;

			in vec4 vColor;
			in vec2 vPosition;

			void main () {
				float A = -dot(vPosition, vPosition);
				if (A < -4.0) discard;
				float B = exp(A) * vColor.a;
				out_FragColor = vec4(B * vColor.rgb, B);
			}
		`.trim();
let fragmentSource = new Cesium.ShaderSource({
  sources: [fragmentShaderSource],
});

splatRenderState

需要启用颜色混合。

数据更新

主要是对primitive纹理和depthIndex属性数据更新。

纹理更新

待补充。

索引更新

待补充。

Cesium源码适配splat渲染

根据Cesium版本的支持情况,按需修改。
例如我用的是cesium 1.123版本,对源码主要做了以下几处修改:
(1)增加对像素格式和类型的支持
Cesium.PixelFormat.RGBA_INTEGER、Cesium.PixelDatatype.UNSIGNED_INT
(2)增加了usampler2D类型。
(3)VertexArray增加自定义属性bindAsInteger,需要使用vertexAttribIPointer
涉及的代码文件例如:
packages\engine\Source\Core\WebGLConstants.js
packages\engine\Source\Core\webGLConstantToGlslType.js
packages\engine\Source\Core\PixelFormat.js
packages\engine\Source\Renderer\PixelDatatype.js
packages\engine\Source\Renderer\VertexArray.js

在场景中添加自定义的Primitive

在场景中添加自定义的primitive,并设置相机的视角,相关代码参考示例即可。

性能优化

WebWorker

为了提高性能,避免页面卡顿,可以结合WebWorker实现多线程。
可以使用WebWorker优化的地方,例如fetch、纹理生成、深度排序索引。

WASM

纹理生成、深度排序索引也可使用其他跨平台语言实现,通过WASM方式使用。

3DTiles

后续会单独讲。

渲染效果美化

待补充。

参考资料:
[1] https://github.com/antimatter15/splat
[2] https://github.com/CesiumGS/cesium/blob/main/Apps/Sandcastle/gallery/development/Custom%20Primitive.html
[3] https://github.com/TheBell/CesiumSplatViewer