diff --git a/setup.sh b/setup.sh index 8cb9461..9308eb9 100755 --- a/setup.sh +++ b/setup.sh @@ -1,5 +1,5 @@ # Read Arguments -TEMP=`getopt -o h --long help,new-env,basic,xformers,flash-attn,diffoctreerast,vox2seq,spconv,mipgaussian,kaolin,nvdiffrast,demo -n 'setup.sh' -- "$@"` +TEMP=`getopt -o h --long help,new-env,basic,xformers,flash-attn,diffoctreerast,vox2seq,spconv,mipgaussian,kaolin,nvdiffrast,gsplat,demo -n 'setup.sh' -- "$@"` eval set -- "$TEMP" @@ -16,6 +16,7 @@ ERROR=false MIPGAUSSIAN=false KAOLIN=false NVDIFFRAST=false +GSPLAT=false DEMO=false if [ "$#" -eq 1 ] ; then @@ -35,6 +36,7 @@ while true ; do --mipgaussian) MIPGAUSSIAN=true ; shift ;; --kaolin) KAOLIN=true ; shift ;; --nvdiffrast) NVDIFFRAST=true ; shift ;; + --gsplat) GSPLAT=true ; shift ;; --demo) DEMO=true ; shift ;; --) shift ; break ;; *) ERROR=true ; break ;; @@ -60,6 +62,7 @@ if [ "$HELP" = true ] ; then echo " --mipgaussian Install mip-splatting" echo " --kaolin Install kaolin" echo " --nvdiffrast Install nvdiffrast" + echo " --gsplat Install gsplat" echo " --demo Install all dependencies for demo" return fi @@ -248,3 +251,11 @@ fi if [ "$DEMO" = true ] ; then pip install gradio==4.44.1 gradio_litmodel3d==0.0.1 fi + +if [ "$GSPLAT" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + pip install git+https://github.com/nerfstudio-project/gsplat + else + echo "[GSPLAT] Unsupported platform: $PLATFORM" + fi +fi \ No newline at end of file diff --git a/trellis/renderers/__init__.py b/trellis/renderers/__init__.py index 0339355..489814d 100755 --- a/trellis/renderers/__init__.py +++ b/trellis/renderers/__init__.py @@ -3,6 +3,7 @@ __attributes = { 'OctreeRenderer': 'octree_renderer', 'GaussianRenderer': 'gaussian_render', + 'GSplatRenderer': 'gsplat_renderer', 'MeshRenderer': 'mesh_renderer', } @@ -28,4 +29,5 @@ def __getattr__(name): if __name__ == '__main__': from .octree_renderer import OctreeRenderer from .gaussian_render import GaussianRenderer + from .gsplat_renderer import 
import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict as edict


def _import_gsplat():
    """Import and return the ``gsplat`` module on first use.

    The import is deferred so that merely importing this module (which happens
    whenever ``GSplatRenderer`` is imported from the renderers package) does not
    fail on installations where the optional ``gsplat`` package is absent.

    Raises:
        ImportError: if ``gsplat`` is not installed.
    """
    try:
        import gsplat
    except ImportError as exc:
        raise ImportError(
            "GSplatRenderer requires the optional 'gsplat' package "
            "(setup.sh installs it when run with --gsplat)."
        ) from exc
    return gsplat


class GSplatRenderer:
    """Render a Gaussian representation to an RGB image via ``gsplat.rasterization``.

    Attributes:
        pipe: pipeline flags. They are stored for callers but are not read by
            this class.
        rendering_options: ``resolution``, ``near``, ``far``, ``ssaa`` and
            ``bg_color``. Only ``resolution``, ``ssaa`` and ``bg_color`` are
            consulted by :meth:`render`.
        bg_color: background tensor used by the most recent :meth:`render` call
            (``None`` before the first call).
    """

    # Fixed clip planes handed to gsplat.
    # NOTE(review): rendering_options["near"] / ["far"] are not forwarded to
    # the rasterizer; confirm that this is intended.
    _NEAR_PLANE = 0.01
    _FAR_PLANE = 1000.0

    def __init__(self, rendering_options=None) -> None:
        """Create a renderer.

        Args:
            rendering_options: optional mapping whose entries override the
                default rendering options. ``None`` (the default) keeps the
                defaults; a ``None`` sentinel is used instead of a shared
                mutable ``{}`` default.
        """
        self.pipe = edict({
            "kernel_size": 0.1,
            "convert_SHs_python": False,
            "compute_cov3D_python": False,
            "scale_modifier": 1.0,
            "debug": False,
            "use_mip_gaussian": True
        })
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1,
            "bg_color": 'random',
        })
        if rendering_options:
            self.rendering_options.update(rendering_options)
        self.bg_color = None

    def _make_background(self, device) -> torch.Tensor:
        """Build the 3-component float32 background colour on ``device``.

        With ``bg_color == 'random'`` the background is all ones or all zeros
        with equal probability; otherwise the configured value is converted to
        a tensor.
        """
        bg_option = self.rendering_options["bg_color"]
        if isinstance(bg_option, str) and bg_option == 'random':
            # Single draw: ones when the sample is below 0.5, zeros otherwise.
            fill = 1.0 if np.random.rand() < 0.5 else 0.0
            return torch.full((3,), fill, dtype=torch.float32, device=device)
        return torch.tensor(bg_option, dtype=torch.float32, device=device)

    def render(
            self,
            gaussian,
            extrinsics: torch.Tensor,
            intrinsics: torch.Tensor,
            colors_overwrite: torch.Tensor = None
    ) -> edict:
        """Rasterize ``gaussian`` from a single camera.

        Args:
            gaussian: object exposing ``get_xyz``, ``get_rotation``,
                ``get_scaling``, ``get_opacity`` and ``get_features``.
            extrinsics: camera matrix for one view; it is passed to gsplat as
                the only entry of ``viewmats``.
                NOTE(review): presumably a 4x4 world-to-camera matrix; confirm.
            intrinsics: camera matrix for one view. Its ``[0, 0]``/``[0, 2]``
                entries are multiplied by the pixel width and its
                ``[1, 1]``/``[1, 2]`` entries by the pixel height, so it is
                assumed to be normalised to the image size.
            colors_overwrite: optional per-Gaussian colours used instead of the
                sigmoid of the Gaussian features.

        Returns:
            An ``edict`` whose ``'color'`` entry is the channel-first RGB image.
            When ``ssaa > 1`` it is rendered at ``resolution * ssaa`` and
            downsampled to ``resolution`` x ``resolution``.

        Raises:
            ValueError: if ``rendering_options["resolution"]`` is unset.
            ImportError: if ``gsplat`` is not installed.
        """
        gs = _import_gsplat()

        resolution = self.rendering_options["resolution"]
        if resolution is None:
            raise ValueError(
                "rendering_options['resolution'] must be set before calling render()"
            )
        ssaa = self.rendering_options["ssaa"]

        # Render at the supersampled size; downsample at the end if needed.
        height = resolution * ssaa
        width = resolution * ssaa

        # Keep the background on the same device as the Gaussians instead of
        # assuming the default CUDA device.
        means = gaussian.get_xyz
        self.bg_color = self._make_background(means.device)

        # Convert the intrinsics to pixel units for the render size.
        Ks_scaled = intrinsics.clone()
        Ks_scaled[0, 0] *= width
        Ks_scaled[1, 1] *= height
        Ks_scaled[0, 2] *= width
        Ks_scaled[1, 2] *= height
        Ks_scaled = Ks_scaled.unsqueeze(0)

        if colors_overwrite is not None:
            colors = colors_overwrite
        else:
            colors = torch.sigmoid(gaussian.get_features.squeeze(1))

        # Rasterize with gsplat; only the colour output is used.
        render_colors, _, _ = gs.rasterization(
            means=means,
            quats=F.normalize(gaussian.get_rotation, dim=-1),
            # NOTE(review): scales are divided by the unscaled intrinsics[0, 0];
            # the reason is not visible here, so confirm before changing it.
            scales=gaussian.get_scaling / intrinsics[0, 0],
            opacities=gaussian.get_opacity.squeeze(-1),
            colors=colors.unsqueeze(0),
            viewmats=extrinsics.unsqueeze(0),
            Ks=Ks_scaled,
            width=width,
            height=height,
            near_plane=self._NEAR_PLANE,
            far_plane=self._FAR_PLANE,
            radius_clip=3.0,
            eps2d=0.3,
            render_mode="RGB",
            backgrounds=self.bg_color.unsqueeze(0),
            camera_model="pinhole"
        )

        # First (only) camera, RGB channels, moved to channel-first layout.
        rendered_image = render_colors[0, ..., 0:3].permute(2, 0, 1)

        # Apply supersampling if needed
        if ssaa > 1:
            # Index the batch dimension explicitly; a bare squeeze() would
            # also drop spatial dimensions of size 1.
            rendered_image = F.interpolate(
                rendered_image[None],
                size=(resolution, resolution),
                mode='bilinear',
                align_corners=False,
                antialias=True
            )[0]

        return edict({'color': rendered_image})
""" vertices = mesh.vertices.cpu().numpy() faces = mesh.faces.cpu().numpy() @@ -432,14 +434,15 @@ def to_glb( fill_holes_resolution=1024, fill_holes_num_views=1000, debug=debug, - verbose=verbose, + verbose=verbose ) # parametrize mesh vertices, faces, uvs = parametrize_mesh(vertices, faces) # bake texture - observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100) + observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100, + gs_renderer=gs_renderer) masks = [np.any(observation > 0, axis=-1) for observation in observations] extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] diff --git a/trellis/utils/render_utils.py b/trellis/utils/render_utils.py index 8187c84..e80d7ff 100644 --- a/trellis/utils/render_utils.py +++ b/trellis/utils/render_utils.py @@ -4,7 +4,7 @@ import utils3d from PIL import Image -from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer +from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer, GSplatRenderer from ..representations import Octree, Gaussian, MeshExtractResult from ..modules import sparse as sp from .random_utils import sphere_hammersley_sequence @@ -40,7 +40,7 @@ def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): return extrinsics, intrinsics -def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs): +def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, gs_renderer='gsplat', **kwargs): if isinstance(sample, Octree): renderer = OctreeRenderer() renderer.rendering_options.resolution = options.get('resolution', 512) @@ -50,7 +50,10 @@ def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=N renderer.rendering_options.ssaa = options.get('ssaa', 4) renderer.pipe.primitive = sample.primitive elif 
isinstance(sample, Gaussian): - renderer = GaussianRenderer() + if gs_renderer == 'gsplat': + renderer = GSplatRenderer() + else: + renderer = GaussianRenderer() renderer.rendering_options.resolution = options.get('resolution', 512) renderer.rendering_options.near = options.get('near', 0.8) renderer.rendering_options.far = options.get('far', 1.6) @@ -96,14 +99,15 @@ def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2 return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) -def render_multiview(sample, resolution=512, nviews=30): +def render_multiview(sample, resolution=512, nviews=30, gs_renderer='gsplat'): r = 2 fov = 40 cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] yaws = [cam[0] for cam in cams] pitchs = [cam[1] for cam in cams] extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) - res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}, + gs_renderer=gs_renderer) return res['color'], extrinsics, intrinsics @@ -113,4 +117,4 @@ def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 1 yaw = [y + yaw_offset for y in yaw] pitch = [offset[1] for _ in range(4)] extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) - return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) \ No newline at end of file