d3c83053f2
- CMakeLists.txt - app.md - appicon.ico - backend/ - main.cpp Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
94 lines
2.8 KiB
Python
94 lines
2.8 KiB
Python
"""Backend TripoSR (Stability + Tripo, MIT).
|
|
|
|
Assumes `sources/TripoSR` has been cloned inside the registry. Imports `tsr.system.TSR`.
Downloads the checkpoint from HF (Hugging Face) on the first load (~1.2 GB).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
import os
|
|
import pathlib
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict
|
|
|
|
import numpy as np
|
|
import torch
|
|
import trimesh
|
|
from PIL import Image
|
|
|
|
|
|
def _ensure_sources_on_path() -> pathlib.Path:
|
|
root = pathlib.Path(os.environ.get("FN_REGISTRY_ROOT", "/home/lucas/fn_registry"))
|
|
src = root / "sources" / "TripoSR"
|
|
if not src.exists():
|
|
raise RuntimeError(
|
|
f"TripoSR no clonado en {src}. "
|
|
"git clone --depth=1 https://github.com/VAST-AI-Research/TripoSR.git "
|
|
f"{src}"
|
|
)
|
|
if str(src) not in sys.path:
|
|
sys.path.insert(0, str(src))
|
|
return src
|
|
|
|
|
|
@dataclass
class Handle:
    """A loaded TripoSR model together with what is needed to run it.

    Attributes:
        model: object that is called on a list of images to produce scene
            codes and that exposes ``extract_mesh``.
        rembg_session: session object handed to
            ``tsr.utils.remove_background``.
        device: device string forwarded to the model on every call.
    """

    model: Any
    rembg_session: Any
    device: str

    def infer(self, image: Image.Image, cfg: Dict[str, Any]) -> bytes:
        """Reconstruct a mesh from a single image and return it as GLB bytes.

        Args:
            image: input picture; its background is removed with
                ``tsr.utils.remove_background`` before reconstruction.
            cfg: optional settings. ``"foreground_ratio"`` (float, default
                0.85) is passed to ``resize_foreground``; ``"mc_resolution"``
                (int, default 256) is passed as ``resolution`` to
                ``extract_mesh``.

        Returns:
            The reconstructed mesh serialised by trimesh as ``glb``.
        """
        from tsr.utils import remove_background, resize_foreground

        fg_ratio = float(cfg.get("foreground_ratio", 0.85))
        mc_res = int(cfg.get("mc_resolution", 256))

        fg = remove_background(image, self.rembg_session)
        fg = resize_foreground(fg, fg_ratio)

        # Composite RGBA -> RGB over 0.5 gray (the canonical TripoSR run.py
        # preprocessing). Without this the DINO tokenizer receives 4 channels
        # and fails with:
        # "The size of tensor a (4) must match tensor b (3) at dim 2".
        arr = np.asarray(fg).astype(np.float32) / 255.0
        if arr.shape[-1] == 4:
            alpha = arr[:, :, 3:4]
            arr = arr[:, :, :3] * alpha + (1.0 - alpha) * 0.5
            fg = Image.fromarray((arr * 255.0).astype(np.uint8))

        # Pure inference: no autograd bookkeeping is needed for either call.
        with torch.no_grad():
            scene_codes = self.model([fg], device=self.device)
            meshes = self.model.extract_mesh(
                scene_codes, has_vertex_color=False, resolution=mc_res
            )

        m = meshes[0]
        # Rebuild the mesh from raw vertices/faces so trimesh runs its own
        # processing pass (process=True) before export.
        tm = trimesh.Trimesh(
            vertices=np.asarray(m.vertices),
            faces=np.asarray(m.faces),
            process=True,
        )
        buf = io.BytesIO()
        tm.export(buf, file_type="glb")
        return buf.getvalue()

    def close(self) -> None:
        """Drop the model and rembg session and release cached CUDA memory.

        Safe to call more than once: attributes that were already removed by
        an earlier call are skipped instead of raising ``AttributeError``.
        After the call ``model`` and ``rembg_session`` no longer exist on the
        instance; ``device`` is left untouched.
        """
        for attr in ("model", "rembg_session"):
            vars(self).pop(attr, None)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
|
|
def load() -> Handle:
    """Build a ready-to-use ``Handle`` around the pretrained TripoSR model.

    Makes the TripoSR checkout importable, loads ``stabilityai/TripoSR``
    through ``TSR.from_pretrained``, sets the renderer chunk size to 8192,
    moves the model to CUDA when available (CPU otherwise) and pairs it with
    a fresh rembg session.

    Returns:
        A ``Handle`` holding the model, the rembg session and the device
        string that was selected.
    """
    _ensure_sources_on_path()
    # Imported lazily: `tsr` only becomes importable after the path setup.
    from tsr.system import TSR
    from rembg import new_session

    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    tsr_model = TSR.from_pretrained(
        "stabilityai/TripoSR",
        config_name="config.yaml",
        weight_name="model.ckpt",
    )
    tsr_model.renderer.set_chunk_size(8192)
    tsr_model.to(target)

    session = new_session()
    return Handle(model=tsr_model, rembg_session=session, device=target)
|