From 9ace203bddb83c90fe4fd695670a4dcf9a090717 Mon Sep 17 00:00:00 2001
From: VishaL <119810373+vis-prime@users.noreply.github.com>
Date: Tue, 5 Dec 2023 23:30:17 +0530
Subject: [PATCH] feat: add splat mesh (#48)
* feat: add splat mesh
* fix shaderMaterial
---
.storybook/stories/Splat.stories.ts | 73 ++++
README.md | 41 ++
src/core/Splat.ts | 648 ++++++++++++++++++++++++++++
src/core/index.ts | 1 +
src/core/shaderMaterial.ts | 15 +-
5 files changed, 774 insertions(+), 4 deletions(-)
create mode 100644 .storybook/stories/Splat.stories.ts
create mode 100644 src/core/Splat.ts
diff --git a/.storybook/stories/Splat.stories.ts b/.storybook/stories/Splat.stories.ts
new file mode 100644
index 0000000..c900ed9
--- /dev/null
+++ b/.storybook/stories/Splat.stories.ts
@@ -0,0 +1,73 @@
+import * as THREE from 'three'
+import { Setup } from '../Setup'
+import { Meta } from '@storybook/html'
+import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls'
+import { GUI } from 'lil-gui'
+import { Splat, SplatLoader } from '../../src/core/Splat'
+
+export default {
+ title: 'Abstractions/Splat',
+} as Meta
+
+let gui: GUI
+
+export const SplatStory = async () => {
+ gui = new GUI({ title: 'Splat Story', closeFolders: true })
+ const { renderer, scene, camera } = Setup()
+
+ const controls = new OrbitControls(camera, renderer.domElement)
+ controls.target.set(0, 1, 0)
+ controls.update()
+
+ scene.background = new THREE.Color('white')
+
+ camera.position.set(10, 10, 10)
+
+ loadSplats(renderer, camera, scene)
+}
+
+async function loadSplats(renderer: THREE.WebGLRenderer, camera: THREE.PerspectiveCamera, scene: THREE.Scene) {
+ const cakewalk = 'https://huggingface.co/cakewalk/splat-data/resolve/main'
+ const dylanebert = 'https://huggingface.co/datasets/dylanebert/3dgs/resolve/main/kitchen'
+
+ const loader = new SplatLoader(renderer)
+ const [shoeSplat, plushSplat, kitchenSplat] = await Promise.all([
+ loader.loadAsync(`${cakewalk}/nike.splat`),
+ loader.loadAsync(`${cakewalk}/plush.splat`),
+ loader.loadAsync(`${dylanebert}/kitchen-7k.splat`),
+ ])
+
+ const shoe1 = new Splat(shoeSplat, camera, { alphaTest: 0.1 })
+ shoe1.scale.setScalar(0.5)
+ shoe1.position.set(0, 1.6, 2)
+ scene.add(shoe1)
+
+ // This will re-use the same data, only one load, one parse, one worker, one buffer
+ const shoe2 = new Splat(shoeSplat, camera, { alphaTest: 0.1 })
+ shoe2.scale.setScalar(0.5)
+ shoe2.position.set(0, 1.6, -1.5)
+ shoe2.rotation.set(Math.PI, 0, Math.PI)
+ scene.add(shoe2)
+
+ const plush = new Splat(plushSplat, camera, { alphaTest: 0.1 })
+ plush.scale.setScalar(0.5)
+ plush.position.set(-1.5, 1.6, 1)
+ scene.add(plush)
+
+ const kitchen = new Splat(kitchenSplat, camera)
+ kitchen.position.set(0, 0.25, 0)
+ scene.add(kitchen)
+
+ // add gui
+ const folder = gui.addFolder('SPLAT')
+
+ folder.add(shoe1, 'visible').name('Shoe 1 visible')
+
+ folder.add(shoe2, 'visible').name('Shoe 2 visible')
+
+ folder.add(plush, 'visible').name('Plush visible')
+
+ folder.add(kitchen, 'visible').name('Kitchen visible')
+}
+
+SplatStory.storyName = 'Default'
diff --git a/README.md b/README.md
index 951b8f6..c977b1f 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,7 @@ import { pcss, ... } from '@pmndrs/vanilla'
Outlines
Billboard
Text
+ Splat
Gizmos
@@ -607,6 +608,46 @@ const preloadRelatedParams = {
},
```
+#### Splat
+
+[![storybook](https://img.shields.io/badge/-storybook-%23ff69b4)](https://pmndrs.github.io/drei-vanilla/?path=/story/abstractions-splat--splat-story)
+
+[drei counterpart](https://github.com/pmndrs/drei#splat)
+
+A declarative abstraction around [antimatter15/splat](https://github.com/antimatter15/splat). It supports re-use, multiple splats with correct depth sorting, splats can move and behave as a regular object3d's, supports alphahash & alphatest, and stream-loading.
+
+```js
+const loader = new SplatLoader(renderer)
+
+const [shoeSplat, plushSplat, kitchenSplat] = await Promise.all([
+ loader.loadAsync(`shoe.splat`),
+ loader.loadAsync(`plush.splat`),
+ loader.loadAsync(`kitchen.splat`),
+])
+
+const shoe1 = new Splat(shoeSplat, camera, { alphaTest: 0.1 })
+shoe1.position.set(0, 1.6, 2)
+scene.add(shoe1)
+
+// This will re-use the same data, only one load, one parse, one worker, one buffer
+const shoe2 = new Splat(shoeSplat, camera, { alphaTest: 0.1 })
+scene.add(shoe2)
+
+const plush = new Splat(plushSplat, camera, { alphaTest: 0.1 })
+scene.add(plush)
+
+const kitchen = new Splat(kitchenSplat, camera)
+scene.add(kitchen)
+```
+
+In order to depth sort multiple splats correctly you can either use alphaTest, for instance with a low value. But keep in mind that this can show a slight outline under some viewing conditions.
+
+You can also use alphaHash, but this can be slower and create some noise, you would typically get rid of the noise in postprocessing with a TAA pass. You don't have to use alphaHash on all splats.
+
+```js
+const plush = new Splat(plushSplat, camera, { alphaHash: true })
+```
+
#### Sprite Animator
[![storybook](https://img.shields.io/badge/-storybook-%23ff69b4)](https://pmndrs.github.io/drei-vanilla/?path=/story/misc-spriteanimator--sprite-animator-story)
diff --git a/src/core/Splat.ts b/src/core/Splat.ts
new file mode 100644
index 0000000..caab347
--- /dev/null
+++ b/src/core/Splat.ts
@@ -0,0 +1,648 @@
+// Based on:
+// Kevin Kwok https://github.com/antimatter15/splat
+// Quadjr https://github.com/quadjr/aframe-gaussian-splatting
+// Adapted by:
+// Paul Henschel twitter.com/0xca0a
+
+import * as THREE from 'three'
+import { shaderMaterial } from '../core/shaderMaterial'
+
+export type SplatMaterialType = {
+ alphaTest?: number
+ alphaHash?: boolean
+ centerAndScaleTexture?: THREE.DataTexture
+ covAndColorTexture?: THREE.DataTexture
+ viewport?: THREE.Vector2
+ focal?: number
+}
+
+export type TargetMesh = THREE.Mesh & {
+ ready: boolean
+ sorted: boolean
+ pm: THREE.Matrix4
+ vm1: THREE.Matrix4
+ vm2: THREE.Matrix4
+ viewport: THREE.Vector4
+}
+
+export type SharedState = {
+ url: string
+ gl: THREE.WebGLRenderer
+ worker: Worker
+ manager: THREE.LoadingManager
+ stream: ReadableStreamDefaultReader
+ loading: boolean
+ loaded: boolean
+ loadedVertexCount: number
+ rowLength: number
+ maxVertexes: number
+ chunkSize: number
+ totalDownloadBytes: number
+ numVertices: number
+ bufferTextureWidth: number
+ bufferTextureHeight: number
+ centerAndScaleData: Float32Array
+ covAndColorData: Uint32Array
+ covAndColorTexture: THREE.DataTexture
+ centerAndScaleTexture: THREE.DataTexture
+ connect(target: TargetMesh): () => void
+ update(target: TargetMesh, camera: THREE.Camera, hashed: boolean): void
+ onProgress?: (event: ProgressEvent) => void
+}
+
+const SplatMaterial = /* @__PURE__ */ shaderMaterial(
+ {
+ alphaTest: 0,
+ viewport: /* @__PURE__ */ new THREE.Vector2(1980, 1080),
+ focal: 1000.0,
+ centerAndScaleTexture: null,
+ covAndColorTexture: null,
+ },
+ /*glsl*/ `
+ precision highp sampler2D;
+ precision highp usampler2D;
+ out vec4 vColor;
+ out vec3 vPosition;
+ uniform vec2 resolution;
+ uniform vec2 viewport;
+ uniform float focal;
+ attribute uint splatIndex;
+ uniform sampler2D centerAndScaleTexture;
+ uniform usampler2D covAndColorTexture;
+
+ vec2 unpackInt16(in uint value) {
+ int v = int(value);
+ int v0 = v >> 16;
+ int v1 = (v & 0xFFFF);
+ if((v & 0x8000) != 0)
+ v1 |= 0xFFFF0000;
+ return vec2(float(v1), float(v0));
+ }
+
+ void main () {
+ ivec2 texSize = textureSize(centerAndScaleTexture, 0);
+ ivec2 texPos = ivec2(splatIndex%uint(texSize.x), splatIndex/uint(texSize.x));
+ vec4 centerAndScaleData = texelFetch(centerAndScaleTexture, texPos, 0);
+ vec4 center = vec4(centerAndScaleData.xyz, 1);
+ vec4 camspace = modelViewMatrix * center;
+ vec4 pos2d = projectionMatrix * camspace;
+
+ float bounds = 1.2 * pos2d.w;
+ if (pos2d.z < -pos2d.w || pos2d.x < -bounds || pos2d.x > bounds
+ || pos2d.y < -bounds || pos2d.y > bounds) {
+ gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
+ return;
+ }
+
+ uvec4 covAndColorData = texelFetch(covAndColorTexture, texPos, 0);
+ vec2 cov3D_M11_M12 = unpackInt16(covAndColorData.x) * centerAndScaleData.w;
+ vec2 cov3D_M13_M22 = unpackInt16(covAndColorData.y) * centerAndScaleData.w;
+ vec2 cov3D_M23_M33 = unpackInt16(covAndColorData.z) * centerAndScaleData.w;
+ mat3 Vrk = mat3(
+ cov3D_M11_M12.x, cov3D_M11_M12.y, cov3D_M13_M22.x,
+ cov3D_M11_M12.y, cov3D_M13_M22.y, cov3D_M23_M33.x,
+ cov3D_M13_M22.x, cov3D_M23_M33.x, cov3D_M23_M33.y
+ );
+
+ mat3 J = mat3(
+ focal / camspace.z, 0., -(focal * camspace.x) / (camspace.z * camspace.z),
+ 0., focal / camspace.z, -(focal * camspace.y) / (camspace.z * camspace.z),
+ 0., 0., 0.
+ );
+
+ mat3 W = transpose(mat3(modelViewMatrix));
+ mat3 T = W * J;
+ mat3 cov = transpose(T) * Vrk * T;
+ vec2 vCenter = vec2(pos2d) / pos2d.w;
+ float diagonal1 = cov[0][0] + 0.3;
+ float offDiagonal = cov[0][1];
+ float diagonal2 = cov[1][1] + 0.3;
+ float mid = 0.5 * (diagonal1 + diagonal2);
+ float radius = length(vec2((diagonal1 - diagonal2) / 2.0, offDiagonal));
+ float lambda1 = mid + radius;
+ float lambda2 = max(mid - radius, 0.1);
+ vec2 diagonalVector = normalize(vec2(offDiagonal, lambda1 - diagonal1));
+ vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
+ vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);
+ uint colorUint = covAndColorData.w;
+ vColor = vec4(
+ float(colorUint & uint(0xFF)) / 255.0,
+ float((colorUint >> uint(8)) & uint(0xFF)) / 255.0,
+ float((colorUint >> uint(16)) & uint(0xFF)) / 255.0,
+ float(colorUint >> uint(24)) / 255.0
+ );
+ vPosition = position;
+
+ gl_Position = vec4(
+ vCenter
+ + position.x * v2 / viewport * 2.0
+ + position.y * v1 / viewport * 2.0, pos2d.z / pos2d.w, 1.0);
+ }
+ `,
+ /*glsl*/ `
+ #include
+ #include
+ in vec4 vColor;
+ in vec3 vPosition;
+ void main () {
+ float A = -dot(vPosition.xy, vPosition.xy);
+ if (A < -4.0) discard;
+ float B = exp(A) * vColor.a;
+ vec4 diffuseColor = vec4(vColor.rgb, B);
+ #include
+ #include
+ gl_FragColor = diffuseColor;
+ #include
+ #include <${parseInt(THREE.REVISION.replace(/\D+/g, '')) >= 154 ? 'colorspace_fragment' : 'encodings_fragment'}>
+ }
+ `
+)
+
+function createWorker(self: any) {
+ let matrices: Float32Array = null!
+ let offset = 0
+
+ function sortSplats(view: Float32Array, hashed: boolean = false) {
+ const vertexCount = matrices.length / 16
+ const threshold = -0.0001
+
+ let maxDepth = -Infinity
+ let minDepth = Infinity
+ const depthList = new Float32Array(vertexCount)
+ const sizeList = new Int32Array(depthList.buffer)
+ const validIndexList = new Int32Array(vertexCount)
+
+ let validCount = 0
+ for (let i = 0; i < vertexCount; i++) {
+ // Sign of depth is reversed
+ const depth =
+ view[0] * matrices[i * 16 + 12] + view[1] * matrices[i * 16 + 13] + view[2] * matrices[i * 16 + 14] + view[3]
+ // Skip behind of camera and small, transparent splat
+ if (hashed || (depth < 0 && matrices[i * 16 + 15] > threshold * depth)) {
+ depthList[validCount] = depth
+ validIndexList[validCount] = i
+ validCount++
+ if (depth > maxDepth) maxDepth = depth
+ if (depth < minDepth) minDepth = depth
+ }
+ }
+
+ // This is a 16 bit single-pass counting sort
+ const depthInv = (256 * 256 - 1) / (maxDepth - minDepth)
+ const counts0 = new Uint32Array(256 * 256)
+ for (let i = 0; i < validCount; i++) {
+ sizeList[i] = ((depthList[i] - minDepth) * depthInv) | 0
+ counts0[sizeList[i]]++
+ }
+ const starts0 = new Uint32Array(256 * 256)
+ for (let i = 1; i < 256 * 256; i++) starts0[i] = starts0[i - 1] + counts0[i - 1]
+ const depthIndex = new Uint32Array(validCount)
+ for (let i = 0; i < validCount; i++) depthIndex[starts0[sizeList[i]]++] = validIndexList[i]
+ return depthIndex
+ }
+
+ self.onmessage = (e: {
+ data: { method: string; length: number; key: string; view: Float32Array; matrices: Float32Array; hashed: boolean }
+ }) => {
+ if (e.data.method == 'push') {
+ if (offset === 0) matrices = new Float32Array(e.data.length)
+ const new_matrices = new Float32Array(e.data.matrices)
+ matrices.set(new_matrices, offset)
+ offset += new_matrices.length
+ } else if (e.data.method == 'sort') {
+ if (matrices !== null) {
+ const indices = sortSplats(new Float32Array(e.data.view), e.data.hashed)
+ // @ts-ignore
+ self.postMessage({ indices, key: e.data.key }, [indices.buffer])
+ }
+ }
+ }
+}
+
+export class SplatLoader extends THREE.Loader {
+ // WebGLRenderer, needs to be filled out!
+ gl: THREE.WebGLRenderer
+ // Default chunk size for lazy loading
+ chunkSize: number
+
+ constructor(gl: THREE.WebGLRenderer, chunkSize = 25000) {
+ super()
+ this.gl = gl
+ this.chunkSize = chunkSize
+ }
+
+ async loadAsync(url: string, onProgress?: (event: ProgressEvent) => void, onError?: (event: ErrorEvent) => void) {
+ return new Promise((res) => this.load(url, res, onProgress, onError))
+ }
+
+ load(
+ url: string,
+ onLoad: (data: SharedState) => void,
+ onProgress?: (event: ProgressEvent) => void,
+ onError?: (event: ErrorEvent) => void
+ ) {
+ const shared = {
+ gl: this.gl,
+ url: this.manager.resolveURL(url),
+ worker: new Worker(
+ URL.createObjectURL(
+ new Blob(['(', createWorker.toString(), ')(self)'], {
+ type: 'application/javascript',
+ })
+ )
+ ),
+ manager: this.manager,
+ update: (target: TargetMesh, camera: THREE.Camera, hashed: boolean) => update(camera, shared, target, hashed),
+ connect: (target: TargetMesh) => connect(shared, target),
+ loading: false,
+ loaded: false,
+ loadedVertexCount: 0,
+ chunkSize: this.chunkSize,
+ totalDownloadBytes: 0,
+ numVertices: 0,
+ rowLength: 3 * 4 + 3 * 4 + 4 + 4,
+ maxVertexes: 0,
+ bufferTextureWidth: 0,
+ bufferTextureHeight: 0,
+ stream: null!,
+ centerAndScaleData: null!,
+ covAndColorData: null!,
+ covAndColorTexture: null!,
+ centerAndScaleTexture: null!,
+ onProgress,
+ }
+ load(shared)
+ .then(onLoad)
+ .catch((e) => {
+ onError?.(e)
+ shared.manager.itemError(shared.url)
+ })
+ }
+}
+
+async function load(shared: SharedState) {
+ shared.manager.itemStart(shared.url)
+ const data = await fetch(shared.url)
+
+ if (data.body === null) throw 'Failed to fetch file'
+ const _totalDownloadBytes = data.headers.get('Content-Length')
+ const totalDownloadBytes = _totalDownloadBytes ? parseInt(_totalDownloadBytes) : undefined
+ if (totalDownloadBytes == undefined) throw 'Failed to get content length'
+ shared.stream = data.body.getReader()
+ shared.totalDownloadBytes = totalDownloadBytes
+ shared.numVertices = Math.floor(shared.totalDownloadBytes / shared.rowLength)
+ const context = shared.gl.getContext()
+ const maxTextureSize = context.getParameter(context.MAX_TEXTURE_SIZE)
+ shared.maxVertexes = maxTextureSize * maxTextureSize
+
+ if (shared.numVertices > shared.maxVertexes) shared.numVertices = shared.maxVertexes
+ shared.bufferTextureWidth = maxTextureSize
+ shared.bufferTextureHeight = Math.floor((shared.numVertices - 1) / maxTextureSize) + 1
+
+ shared.centerAndScaleData = new Float32Array(shared.bufferTextureWidth * shared.bufferTextureHeight * 4)
+ shared.covAndColorData = new Uint32Array(shared.bufferTextureWidth * shared.bufferTextureHeight * 4)
+ shared.centerAndScaleTexture = new THREE.DataTexture(
+ shared.centerAndScaleData,
+ shared.bufferTextureWidth,
+ shared.bufferTextureHeight,
+ THREE.RGBAFormat,
+ THREE.FloatType
+ )
+
+ shared.centerAndScaleTexture.needsUpdate = true
+ shared.covAndColorTexture = new THREE.DataTexture(
+ shared.covAndColorData,
+ shared.bufferTextureWidth,
+ shared.bufferTextureHeight,
+ THREE.RGBAIntegerFormat,
+ THREE.UnsignedIntType
+ )
+ shared.covAndColorTexture.internalFormat = 'RGBA32UI'
+ shared.covAndColorTexture.needsUpdate = true
+ return shared
+}
+
+async function lazyLoad(shared: SharedState) {
+ shared.loading = true
+ let bytesDownloaded = 0
+ let bytesProcessed = 0
+ const chunks: Array = []
+ let lastReportedProgress = 0
+ const lengthComputable = shared.totalDownloadBytes !== 0
+ while (true) {
+ try {
+ const { value, done } = await shared.stream.read()
+ if (done) break
+ bytesDownloaded += value.length
+
+ if (shared.totalDownloadBytes != undefined) {
+ const percent = (bytesDownloaded / shared.totalDownloadBytes) * 100
+ if (shared.onProgress && percent - lastReportedProgress > 1) {
+ const event = new ProgressEvent('progress', {
+ lengthComputable,
+ loaded: bytesDownloaded,
+ total: shared.totalDownloadBytes,
+ })
+ shared.onProgress(event)
+ lastReportedProgress = percent
+ }
+ }
+
+ chunks.push(value)
+ const bytesRemains = bytesDownloaded - bytesProcessed
+ if (shared.totalDownloadBytes != undefined && bytesRemains > shared.rowLength * shared.chunkSize) {
+ const vertexCount = Math.floor(bytesRemains / shared.rowLength)
+ const concatenatedChunksbuffer = new Uint8Array(bytesRemains)
+ let offset = 0
+ for (const chunk of chunks) {
+ concatenatedChunksbuffer.set(chunk, offset)
+ offset += chunk.length
+ }
+ chunks.length = 0
+ if (bytesRemains > vertexCount * shared.rowLength) {
+ const extra_data = new Uint8Array(bytesRemains - vertexCount * shared.rowLength)
+ extra_data.set(concatenatedChunksbuffer.subarray(bytesRemains - extra_data.length, bytesRemains), 0)
+ chunks.push(extra_data)
+ }
+ const buffer = new Uint8Array(vertexCount * shared.rowLength)
+ buffer.set(concatenatedChunksbuffer.subarray(0, buffer.byteLength), 0)
+ const matrices = pushDataBuffer(shared, buffer.buffer, vertexCount)
+ shared.worker.postMessage(
+ { method: 'push', src: shared.url, length: shared.numVertices * 16, matrices: matrices.buffer },
+ [matrices.buffer]
+ )
+ bytesProcessed += vertexCount * shared.rowLength
+
+ if (shared.onProgress) {
+ const event = new ProgressEvent('progress', {
+ lengthComputable,
+ loaded: shared.totalDownloadBytes,
+ total: shared.totalDownloadBytes,
+ })
+ shared.onProgress(event)
+ }
+ }
+ } catch (error) {
+ console.error(error)
+ break
+ }
+ }
+
+ if (bytesDownloaded - bytesProcessed > 0) {
+ // Concatenate the chunks into a single Uint8Array
+ const concatenatedChunks = new Uint8Array(chunks.reduce((acc, chunk) => acc + chunk.length, 0))
+ let offset = 0
+ for (const chunk of chunks) {
+ concatenatedChunks.set(chunk, offset)
+ offset += chunk.length
+ }
+ const numVertices = Math.floor(concatenatedChunks.byteLength / shared.rowLength)
+ const matrices = pushDataBuffer(shared, concatenatedChunks.buffer, numVertices)
+ shared.worker.postMessage(
+ { method: 'push', src: shared.url, length: numVertices * 16, matrices: matrices.buffer },
+ [matrices.buffer]
+ )
+ }
+ shared.loaded = true
+ shared.manager.itemEnd(shared.url)
+}
+
+function update(camera: THREE.Camera, shared: SharedState, target: TargetMesh, hashed: boolean) {
+ camera.updateMatrixWorld()
+ shared.gl.getCurrentViewport(target.viewport)
+ // @ts-ignore
+ target.material.viewport.x = target.viewport.z
+ // @ts-ignore
+ target.material.viewport.y = target.viewport.w
+ target.material.focal = (target.viewport.w / 2.0) * Math.abs(camera.projectionMatrix.elements[5])
+
+ if (target.ready) {
+ if (hashed && target.sorted) return
+ target.ready = false
+ const view = new Float32Array([
+ target.modelViewMatrix.elements[2],
+ -target.modelViewMatrix.elements[6],
+ target.modelViewMatrix.elements[10],
+ target.modelViewMatrix.elements[14],
+ ])
+ shared.worker.postMessage({ method: 'sort', src: shared.url, key: target.uuid, view: view.buffer, hashed }, [
+ view.buffer,
+ ])
+ if (hashed && shared.loaded) target.sorted = true
+ }
+}
+
+function connect(shared: SharedState, target: TargetMesh) {
+ if (!shared.loading) lazyLoad(shared)
+
+ target.ready = false
+ target.pm = new THREE.Matrix4()
+ target.vm1 = new THREE.Matrix4()
+ target.vm2 = new THREE.Matrix4()
+ target.viewport = new THREE.Vector4()
+
+ const splatIndexArray = new Uint32Array(shared.bufferTextureWidth * shared.bufferTextureHeight)
+ const splatIndexes = new THREE.InstancedBufferAttribute(splatIndexArray, 1, false)
+ splatIndexes.setUsage(THREE.DynamicDrawUsage)
+
+ const geometry = (target.geometry = new THREE.InstancedBufferGeometry())
+ const positionsArray = new Float32Array(6 * 3)
+ const positions = new THREE.BufferAttribute(positionsArray, 3)
+ geometry.setAttribute('position', positions)
+ positions.setXYZ(2, -2.0, 2.0, 0.0)
+ positions.setXYZ(1, 2.0, 2.0, 0.0)
+ positions.setXYZ(0, -2.0, -2.0, 0.0)
+ positions.setXYZ(5, -2.0, -2.0, 0.0)
+ positions.setXYZ(4, 2.0, 2.0, 0.0)
+ positions.setXYZ(3, 2.0, -2.0, 0.0)
+ positions.needsUpdate = true
+ geometry.setAttribute('splatIndex', splatIndexes)
+ geometry.instanceCount = 1
+
+ function listener(e: { data: { key: string; indices: Uint32Array } }) {
+ if (target && e.data.key === target.uuid) {
+ const indexes = new Uint32Array(e.data.indices)
+ // @ts-ignore
+ geometry.attributes.splatIndex.set(indexes)
+ geometry.attributes.splatIndex.needsUpdate = true
+ geometry.instanceCount = indexes.length
+ target.ready = true
+ }
+ }
+ shared.worker.addEventListener('message', listener)
+
+ async function wait() {
+ while (true) {
+ const centerAndScaleTextureProperties = shared.gl.properties.get(shared.centerAndScaleTexture)
+ const covAndColorTextureProperties = shared.gl.properties.get(shared.covAndColorTexture)
+ if (
+ centerAndScaleTextureProperties?.__webglTexture &&
+ covAndColorTextureProperties?.__webglTexture &&
+ shared.loadedVertexCount > 0
+ ) {
+ break
+ }
+ await new Promise((resolve) => setTimeout(resolve, 10))
+ }
+ target.ready = true
+ }
+
+ wait()
+ return () => shared.worker.removeEventListener('message', listener)
+}
+
+function pushDataBuffer(shared: SharedState, buffer: ArrayBufferLike, vertexCount: number) {
+ const context = shared.gl.getContext()
+ if (shared.loadedVertexCount + vertexCount > shared.maxVertexes) {
+ vertexCount = shared.maxVertexes - shared.loadedVertexCount
+ }
+ if (vertexCount <= 0) throw 'Failed to parse file'
+
+ const u_buffer = new Uint8Array(buffer)
+ const f_buffer = new Float32Array(buffer)
+ const matrices = new Float32Array(vertexCount * 16)
+
+ const covAndColorData_uint8 = new Uint8Array(shared.covAndColorData.buffer)
+ const covAndColorData_int16 = new Int16Array(shared.covAndColorData.buffer)
+ for (let i = 0; i < vertexCount; i++) {
+ const quat = new THREE.Quaternion(
+ -(u_buffer[32 * i + 28 + 1] - 128) / 128.0,
+ (u_buffer[32 * i + 28 + 2] - 128) / 128.0,
+ (u_buffer[32 * i + 28 + 3] - 128) / 128.0,
+ -(u_buffer[32 * i + 28 + 0] - 128) / 128.0
+ )
+ quat.invert()
+ const center = new THREE.Vector3(f_buffer[8 * i + 0], f_buffer[8 * i + 1], -f_buffer[8 * i + 2])
+ const scale = new THREE.Vector3(f_buffer[8 * i + 3 + 0], f_buffer[8 * i + 3 + 1], f_buffer[8 * i + 3 + 2])
+
+ const mtx = new THREE.Matrix4()
+ mtx.makeRotationFromQuaternion(quat)
+ mtx.transpose()
+ mtx.scale(scale)
+ const mtx_t = mtx.clone()
+ mtx.transpose()
+ mtx.premultiply(mtx_t)
+ mtx.setPosition(center)
+
+ const cov_indexes = [0, 1, 2, 5, 6, 10]
+ let max_value = 0.0
+ for (let j = 0; j < cov_indexes.length; j++) {
+ if (Math.abs(mtx.elements[cov_indexes[j]]) > max_value) {
+ max_value = Math.abs(mtx.elements[cov_indexes[j]])
+ }
+ }
+
+ let destOffset = shared.loadedVertexCount * 4 + i * 4
+ shared.centerAndScaleData[destOffset + 0] = center.x
+ shared.centerAndScaleData[destOffset + 1] = -center.y
+ shared.centerAndScaleData[destOffset + 2] = center.z
+ shared.centerAndScaleData[destOffset + 3] = max_value / 32767.0
+
+ destOffset = shared.loadedVertexCount * 8 + i * 4 * 2
+ for (let j = 0; j < cov_indexes.length; j++) {
+ covAndColorData_int16[destOffset + j] = (mtx.elements[cov_indexes[j]] * 32767.0) / max_value
+ }
+
+ // RGBA
+ destOffset = shared.loadedVertexCount * 16 + (i * 4 + 3) * 4
+ const col = new THREE.Color(
+ u_buffer[32 * i + 24 + 0] / 255,
+ u_buffer[32 * i + 24 + 1] / 255,
+ u_buffer[32 * i + 24 + 2] / 255
+ )
+ col.convertSRGBToLinear()
+ covAndColorData_uint8[destOffset + 0] = col.r * 255
+ covAndColorData_uint8[destOffset + 1] = col.g * 255
+ covAndColorData_uint8[destOffset + 2] = col.b * 255
+ covAndColorData_uint8[destOffset + 3] = u_buffer[32 * i + 24 + 3]
+
+ // Store scale and transparent to remove splat in sorting process
+ mtx.elements[15] = (Math.max(scale.x, scale.y, scale.z) * u_buffer[32 * i + 24 + 3]) / 255.0
+ for (let j = 0; j < 16; j++) matrices[i * 16 + j] = mtx.elements[j]
+ }
+
+ while (vertexCount > 0) {
+ let width = 0
+ let height = 0
+ const xoffset = shared.loadedVertexCount % shared.bufferTextureWidth
+ const yoffset = Math.floor(shared.loadedVertexCount / shared.bufferTextureWidth)
+ if (shared.loadedVertexCount % shared.bufferTextureWidth != 0) {
+ width = Math.min(shared.bufferTextureWidth, xoffset + vertexCount) - xoffset
+ height = 1
+ } else if (Math.floor(vertexCount / shared.bufferTextureWidth) > 0) {
+ width = shared.bufferTextureWidth
+ height = Math.floor(vertexCount / shared.bufferTextureWidth)
+ } else {
+ width = vertexCount % shared.bufferTextureWidth
+ height = 1
+ }
+
+ const centerAndScaleTextureProperties = shared.gl.properties.get(shared.centerAndScaleTexture)
+ context.bindTexture(context.TEXTURE_2D, centerAndScaleTextureProperties.__webglTexture)
+ context.texSubImage2D(
+ context.TEXTURE_2D,
+ 0,
+ xoffset,
+ yoffset,
+ width,
+ height,
+ context.RGBA,
+ context.FLOAT,
+ shared.centerAndScaleData,
+ shared.loadedVertexCount * 4
+ )
+
+ const covAndColorTextureProperties = shared.gl.properties.get(shared.covAndColorTexture)
+ context.bindTexture(context.TEXTURE_2D, covAndColorTextureProperties.__webglTexture)
+ context.texSubImage2D(
+ context.TEXTURE_2D,
+ 0,
+ xoffset,
+ yoffset,
+ width,
+ height,
+ // @ts-ignore
+ context.RGBA_INTEGER,
+ context.UNSIGNED_INT,
+ shared.covAndColorData,
+ shared.loadedVertexCount * 4
+ )
+ shared.gl.resetState()
+
+ shared.loadedVertexCount += width * height
+ vertexCount -= width * height
+ }
+ return matrices
+}
+
+export class Splat extends THREE.Mesh {
+ camera: THREE.Camera
+ shared: any
+ alphaHash: boolean
+ constructor(shared: any, camera: THREE.Camera, { toneMapped = false, alphaTest = 0, alphaHash = false } = {}) {
+ super()
+ this.shared = shared
+ this.camera = camera
+ this.alphaHash = alphaHash
+ this.frustumCulled = false
+ this.onBeforeRender = () => shared.update(this, camera, alphaHash)
+ this.material = new SplatMaterial()
+ Object.assign(this.material, {
+ transparent: !alphaHash,
+ depthTest: true,
+ alphaTest: alphaHash ? 0 : alphaTest,
+ centerAndScaleTexture: shared.centerAndScaleTexture,
+ covAndColorTexture: shared.covAndColorTexture,
+ depthWrite: alphaHash ? true : alphaTest > 0,
+ blending: alphaHash ? THREE.NormalBlending : THREE.CustomBlending,
+ blendSrcAlpha: THREE.OneFactor,
+ alphaHash: !!alphaHash,
+ toneMapped: toneMapped,
+ })
+ shared.connect(this)
+ }
+
+ splatUpdate() {
+ this.shared.update(this, this.camera, this.alphaHash)
+ }
+}
diff --git a/src/core/index.ts b/src/core/index.ts
index cefce5c..595bf08 100644
--- a/src/core/index.ts
+++ b/src/core/index.ts
@@ -14,6 +14,7 @@ export * from './SpriteAnimator'
export * from './Outlines'
export * from './Billboard'
export * from './Text'
+export * from './Splat'
// Gizmos
export * from './Grid'
diff --git a/src/core/shaderMaterial.ts b/src/core/shaderMaterial.ts
index 02c0c76..aafbd40 100644
--- a/src/core/shaderMaterial.ts
+++ b/src/core/shaderMaterial.ts
@@ -34,15 +34,22 @@ export function shaderMaterial(
onInit?: (material: ShaderMaterialInstance) => void
) {
const entries = Object.entries(uniforms)
- const uniformDefs = Object.fromEntries(entries.map(([name, value]) => [name, { value }])) as {
- [K in keyof TProps]: { value: TProps[K] }
- }
class Material extends THREE.ShaderMaterial {
static key = THREE.MathUtils.generateUUID()
constructor(parameters?: ShaderMaterialParameters) {
- super({ uniforms: uniformDefs, vertexShader, fragmentShader })
+ super({
+ uniforms: entries.reduce((acc, [name, value]) => {
+ const uniform = THREE.UniformsUtils.clone({ [name]: { value } })
+ return {
+ ...acc,
+ ...uniform,
+ }
+ }, {}),
+ vertexShader,
+ fragmentShader,
+ })
for (const [name] of entries) {
Object.defineProperty(this, name, {