code
dataset_mesh.py
import numpy as npimport torch
from render import utilfrom render import meshfrom render import renderfrom render import light
from .dataset import Dataset
################################################################################ Reference dataset using mesh & rendering###############################################################################
class DatasetMesh(Dataset):
    """Dataset whose samples are renderings of a reference mesh.

    Each item is produced on demand by rendering ``ref_mesh`` with
    ``render.render_mesh`` under the environment light loaded from
    ``FLAGS.envmap``. In validation mode the camera orbits the object; in
    training mode a random pose is drawn for every item.
    """

    def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False):
        """Store the render settings and prepare the reference mesh.

        Args:
            ref_mesh: Reference mesh; must expose ``t_pos_idx``, ``v_pos`` and
                a ``material`` mapping with at least ``'kd'`` and ``'ks'``.
            glctx: Rasterizer context forwarded to ``render.render_mesh``.
            cam_radius: Distance the camera is translated along -Z.
            FLAGS: Configuration object (reads ``train_res``, ``display_res``,
                ``texture_res``, ``local_rank``, ``envmap``, ``env_scale``,
                ``cam_near_far``, ``spp``, ``layers``, ``iter``, ``batch``).
            validate: When True, serve the fixed 50-frame rotation sequence.
        """
        # Init
        self.glctx = glctx
        self.cam_radius = cam_radius
        self.FLAGS = FLAGS
        self.validate = validate
        self.fovy = np.deg2rad(45)
        self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0]

        if self.FLAGS.local_rank == 0:
            print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))

        # Sanity test training texture resolution
        ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())
        if 'normal' in ref_mesh.material:
            ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())

        # The parentheses matter: without them ``and`` binds tighter than
        # ``or`` and every rank (not only rank 0) would print the warning
        # whenever the second dimension is too small.
        too_small = (FLAGS.texture_res[0] < ref_texture_res[0]
                     or FLAGS.texture_res[1] < ref_texture_res[1])
        if self.FLAGS.local_rank == 0 and too_small:
            print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))

        # Load environment map texture
        self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

        self.ref_mesh = mesh.compute_tangents(ref_mesh)

    def _rotate_scene(self, itr):
        """Return the camera for frame ``itr`` of the display turntable.

        The Y rotation angle is ``itr / 50`` of a full turn, so 50 consecutive
        frames make one revolution. Returns ``(mv, mvp, campos, resolution,
        spp)`` with a leading batch dimension on the three tensors.
        """
        proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])

        # Smooth rotation for display.
        ang = (itr / 50) * np.pi * 2
        mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(0.4) @ util.rotate_y(ang))
        mvp = proj_mtx @ mv
        # Camera position is the translation column of the inverse model-view.
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp

    def _random_scene(self):
        """Return a randomly posed camera at training resolution.

        Returns ``(mv, mvp, campos, resolution, spp)`` with a leading batch
        dimension on the three tensors.
        """
        # ==============================================================================================
        #  Setup projection matrix
        # ==============================================================================================
        iter_res = self.FLAGS.train_res
        proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])

        # ==============================================================================================
        #  Random camera & light position
        # ==============================================================================================

        # Random rotation/translation matrix for optimization.
        mv = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.25)
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), iter_res, self.FLAGS.spp  # Add batch dimension

    def __len__(self):
        """Return 50 in validation mode, else ``(FLAGS.iter + 1) * FLAGS.batch``."""
        return 50 if self.validate else (self.FLAGS.iter + 1) * self.FLAGS.batch

    def __getitem__(self, itr):
        """Render one sample and return it with the camera that produced it.

        Returns a dict with keys ``'mv'``, ``'mvp'``, ``'campos'``,
        ``'resolution'``, ``'spp'`` and ``'img'`` (the ``'shaded'`` buffer).
        """
        # ==============================================================================================
        #  Randomize scene parameters
        # ==============================================================================================

        if self.validate:
            mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr)
        else:
            mv, mvp, campos, iter_res, iter_spp = self._random_scene()

        img = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp,
                                 num_layers=self.FLAGS.layers, msaa=True, background=None)['shaded']
        # ^eebb2a

        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : iter_res,
            'spp' : iter_spp,
            'img' : img
        }
light.py
import osimport numpy as npimport torchimport nvdiffrast.torch as dr
from . import utilfrom . import renderutils as ru
############################################### Utility functions#############################################
class cubemap_mip(torch.autograd.Function):
    """Cubemap downsampling step with a custom backward pass.

    The forward pass pools the cubemap with ``util.avg_pool_nhwc`` using a
    2x2 window. The backward pass builds the gradient of the finer level by
    sampling the incoming gradient with a linearly filtered cubemap lookup.
    """

    @staticmethod
    def forward(ctx, cubemap):
        """Return ``cubemap`` pooled by ``util.avg_pool_nhwc`` with a (2, 2) window."""
        return util.avg_pool_nhwc(cubemap, (2,2))

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient for the finer cubemap level.

        Args:
            dout: Gradient w.r.t. the pooled cubemap; indexed as
                (face, height, width, channels) by the code below.

        Returns:
            Tensor of shape ``(6, 2*res, 2*res, channels)``.
        """
        res = dout.shape[1] * 2
        device = dout.device  # keep scratch tensors on the gradient's device
        out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device=device)

        # The texel grid and the scaled gradient are identical for all six
        # faces, so build them once instead of once per face.
        texel_centers = torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device=device)
        gy, gx = torch.meshgrid(texel_centers, texel_centers, indexing='ij')
        # 0.25 spreads each coarse gradient over the 2x2 fine texels it pooled.
        scaled_grad = dout[None, ...] * 0.25

        for s in range(6):
            # NOTE(review): assumes util.cube_to_dir does not modify gx/gy in place - confirm.
            v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
            out[s, ...] = dr.texture(scaled_grad, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
        return out
#################################################### Split-sum environment map light source with automatic mipmap generation###################################################
class EnvironmentLight(torch.nn.Module):
    """Trainable cubemap environment light with a prefiltered mip chain.

    ``base`` is the learnable cubemap. ``build_mips`` derives the specular
    chain (``self.specular``) and the diffuse map (``self.diffuse``) that
    ``shade`` samples.
    """

    # The mip chain stops once the last level is no larger than this.
    LIGHT_MIN_RES = 16

    # Roughness range mapped linearly onto all specular levels but the last.
    MIN_ROUGHNESS = 0.08
    MAX_ROUGHNESS = 0.5

    def __init__(self, base):
        """Wrap a detached copy of ``base`` as the trainable parameter."""
        super().__init__()
        self.mtx = None
        self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=True)
        # The same parameter is additionally registered under 'env_base'.
        self.register_parameter('env_base', self.base)

    def xfm(self, mtx):
        """Set the matrix applied to lookup vectors in ``shade``."""
        self.mtx = mtx

    def clone(self):
        """Return a new light holding a detached copy of ``base``."""
        return EnvironmentLight(self.base.clone().detach())

    def clamp_(self, min=None, max=None):
        """Clamp ``base`` in place to ``[min, max]``."""
        self.base.clamp_(min, max)

    def get_mip(self, roughness):
        """Map roughness values to (fractional) specular mip levels.

        Roughness below ``MAX_ROUGHNESS`` is spread over levels
        ``0 .. len(specular) - 2``; the rest is mapped onto the last step.
        Requires ``build_mips`` to have populated ``self.specular``.
        """
        lo = self.MIN_ROUGHNESS
        hi = self.MAX_ROUGHNESS
        num_levels = len(self.specular)

        low_band = (torch.clamp(roughness, lo, hi) - lo) / (hi - lo) * (num_levels - 2)
        high_band = (torch.clamp(roughness, hi, 1.0) - hi) / (1.0 - hi) + num_levels - 2
        return torch.where(roughness < hi, low_band, high_band)

    def build_mips(self, cutoff=0.99):
        """Rebuild ``self.specular`` and ``self.diffuse`` from ``base``."""
        chain = [self.base]
        while chain[-1].shape[1] > self.LIGHT_MIN_RES:
            chain.append(cubemap_mip.apply(chain[-1]))
        self.specular = chain

        # The diffuse map is derived from the coarsest level before it is
        # replaced by its specular-filtered version below.
        self.diffuse = ru.diffuse_cubemap(chain[-1])

        roughness_span = self.MAX_ROUGHNESS - self.MIN_ROUGHNESS
        for level in range(len(chain) - 1):
            roughness = (level / (len(chain) - 2)) * roughness_span + self.MIN_ROUGHNESS
            chain[level] = ru.specular_cubemap(chain[level], roughness, cutoff)
        chain[-1] = ru.specular_cubemap(chain[-1], 1.0, cutoff)

    def regularizer(self):
        """Return the mean absolute deviation of ``base`` from its channel mean."""
        rgb = self.base
        white = (rgb[..., 0:1] + rgb[..., 1:2] + rgb[..., 2:3]) / 3.0
        return torch.mean(torch.abs(rgb - white))

    def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True):
        """Shade G-buffer inputs with this environment light.

        Args:
            gb_pos: Surface positions.
            gb_normal: Shading normals.
            kd: Base colour.
            ks: Packed material terms; channel 0 scales the final result,
                channel 1 is read as roughness, channel 2 as metallic.
            view_pos: Camera position(s).
            specular: When False only the diffuse term is evaluated.

        Returns:
            The shaded colour multiplied by ``1 - ks[..., 0:1]``.
        """
        wo = util.safe_normalize(view_pos - gb_pos)

        if not specular:
            diff_col = kd
        else:
            roughness = ks[..., 1:2]  # y component
            metallic = ks[..., 2:3]   # z component
            spec_col = (1.0 - metallic)*0.04 + kd * metallic
            diff_col = kd * (1.0 - metallic)

        reflvec = util.safe_normalize(util.reflect(wo, gb_normal))
        nrmvec = gb_normal

        if self.mtx is not None:  # Rotate lookup
            mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')

            def _rotate(vec):
                # Flatten the two middle axes for xfm_vectors, then restore the shape.
                flat = vec.view(vec.shape[0], vec.shape[1] * vec.shape[2], vec.shape[3])
                return ru.xfm_vectors(flat, mtx).view(*vec.shape)

            reflvec = _rotate(reflvec)
            nrmvec = _rotate(nrmvec)
        # ^51f458

        # Diffuse lookup
        diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube')
        shaded_col = diffuse * diff_col

        if specular:
            # Lookup FG term from lookup texture
            NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4)
            fg_uv = torch.cat((NdotV, roughness), dim=-1)
            if not hasattr(self, '_FG_LUT'):
                # Loaded lazily on first use and cached on the instance.
                lut = np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2)
                self._FG_LUT = torch.as_tensor(lut, dtype=torch.float32, device='cuda')
            fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp')

            # Roughness adjusted specular env lookup
            miplevel = self.get_mip(roughness)
            coarser_levels = [m[None, ...] for m in self.specular[1:]]
            spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=coarser_levels, mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube')

            # Compute aggregate lighting
            reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2]
            shaded_col += spec * reflectance

        return shaded_col * (1.0 - ks[..., 0:1])  # Modulate by hemisphere visibility
#################################################### Load and store#####################################################
# Load from latlong .HDR file
def _load_env_hdr(fn, scale=1.0):
    """Build an ``EnvironmentLight`` from a lat-long image file.

    The image returned by ``util.load_image`` is multiplied by ``scale``,
    converted to a 512x512 cubemap, and the mip chain is built before the
    light is returned.
    """
    latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
    env = EnvironmentLight(util.latlong_to_cubemap(latlong_img, [512, 512]))
    env.build_mips()
    return env
def load_env(fn, scale=1.0):
    """Load an environment light from ``fn``.

    Only files whose extension is ``.hdr`` (case-insensitive) are supported.

    Args:
        fn: Path of the environment map.
        scale: Multiplier forwarded to the loader.

    Returns:
        The object returned by ``_load_env_hdr``.

    Raises:
        AssertionError: If the extension is not supported.
    """
    ext = os.path.splitext(fn)[1]
    if ext.lower() == ".hdr":
        return _load_env_hdr(fn, scale)
    # Raised explicitly (instead of ``assert False``) so the check is not
    # stripped under ``python -O``, which would make this return None.
    # AssertionError is kept so existing callers see the same exception type.
    raise AssertionError("Unknown envlight extension %s" % ext)
def save_env_map(fn, light):
    """Write the base cubemap of ``light`` to ``fn`` as a lat-long image.

    Args:
        fn: Output path passed to ``util.save_image_raw``.
        light: The ``EnvironmentLight`` to save.

    Raises:
        AssertionError: If ``light`` is not an ``EnvironmentLight``.
    """
    # One explicit check replaces the assert plus the redundant ``if``:
    # under ``python -O`` the assert vanished and a wrong type silently
    # wrote nothing.
    if not isinstance(light, EnvironmentLight):
        raise AssertionError("Can only save EnvironmentLight currently")
    color = util.cubemap_to_latlong(light.base, [512, 1024])
    util.save_image_raw(fn, color.detach().cpu().numpy())
################################################## Create trainable env map with random initialization##################################################
def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
    """Create an ``EnvironmentLight`` with a uniformly random base cubemap.

    The six ``base_res`` x ``base_res`` RGB faces are drawn from
    ``torch.rand`` and mapped to ``[bias, bias + scale)``.
    """
    noise = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda')
    return EnvironmentLight(noise * scale + bias)
render.py
import torchimport nvdiffrast.torch as dr
from . import utilfrom . import renderutils as rufrom . import light
# ==============================================================================================
#  Helper functions
# ==============================================================================================
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Interpolate vertex attributes with ``dr.interpolate``.

    ``attr`` is made contiguous first. Attribute derivatives are requested
    for all attributes only when ``rast_db`` is supplied.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)
# ==============================================================================================
#  pixel shader
# ==============================================================================================
def shade(
        gb_pos,
        gb_geometric_normal,
        gb_normal,
        gb_tangent,
        gb_texc,
        gb_texc_deriv,
        view_pos,
        lgt,
        material,
        bsdf
    ):
    """Evaluate material textures and the selected BSDF.

    Args:
        gb_pos, gb_geometric_normal, gb_normal, gb_tangent: Interpolated
            position, face normal, shading normal and tangent buffers.
        gb_texc, gb_texc_deriv: Texture coordinates and their derivatives.
        view_pos: Camera position(s).
        lgt: Light source; must be a ``light.EnvironmentLight`` for the
            ``'pbr'`` and ``'diffuse'`` modes.
        material: Mapping with either a combined ``'kd_ks_normal'`` entry or
            separate ``'kd'`` / ``'ks'`` (and optional ``'normal'``) entries.
        bsdf: BSDF override; when None, ``material['bsdf']`` is used.

    Returns:
        Dict with ``'shaded'``, ``'kd_grad'`` and ``'occlusion'`` buffers,
        each with alpha appended as the last channel.
    """
    ################################################################################
    # Texture lookups
    ################################################################################
    perturbed_nrm = None
    if 'kd_ks_normal' in material:
        # Combined texture, used for MLPs because lookups are expensive
        combined = material['kd_ks_normal']
        # The jittered lookup comes first so the random draw happens before
        # the clean lookup, exactly as in the reference ordering.
        jitter = torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda")
        all_tex_jitter = combined.sample(gb_pos + jitter)
        all_tex = combined.sample(gb_pos)
        assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"
        kd = all_tex[..., :-6]
        ks = all_tex[..., -6:-3]
        perturbed_nrm = all_tex[..., -3:]
        # Compute albedo (kd) gradient, used for material regularizer
        kd_delta = torch.abs(all_tex_jitter[..., :-6] - all_tex[..., :-6])
        kd_grad = torch.sum(kd_delta, dim=-1, keepdim=True) / 3
    else:
        jitter = torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda")
        kd_jitter = material['kd'].sample(gb_texc + jitter, gb_texc_deriv)
        kd = material['kd'].sample(gb_texc, gb_texc_deriv)
        ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3]  # skip alpha
        if 'normal' in material:
            perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)
        kd_delta = torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3])
        kd_grad = torch.sum(kd_delta, dim=-1, keepdim=True) / 3

    # Separate kd into alpha and color, default alpha = 1
    if kd.shape[-1] == 4:
        alpha = kd[..., 3:4]
    else:
        alpha = torch.ones_like(kd[..., 0:1])
    kd = kd[..., 0:3]

    ################################################################################
    # Normal perturbation & normal bend
    ################################################################################
    if 'no_perturbed_nrm' in material and material['no_perturbed_nrm']:
        perturbed_nrm = None

    gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)

    ################################################################################
    # Evaluate BSDF
    ################################################################################

    assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
    if bsdf is None:
        bsdf = material['bsdf']

    if bsdf in ('pbr', 'diffuse'):
        # Both lit modes share the light call and differ only in the specular switch.
        if isinstance(lgt, light.EnvironmentLight):
            shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=(bsdf == 'pbr'))
        else:
            assert False, "Invalid light type"
    elif bsdf == 'normal':
        shaded_col = (gb_normal + 1.0)*0.5
    elif bsdf == 'tangent':
        shaded_col = (gb_tangent + 1.0)*0.5
    elif bsdf == 'kd':
        shaded_col = kd
    elif bsdf == 'ks':
        shaded_col = ks
    else:
        assert False, "Invalid BSDF '%s'" % bsdf

    # Return multiple buffers
    buffers = {
        'shaded'    : torch.cat((shaded_col, alpha), dim=-1),
        'kd_grad'   : torch.cat((kd_grad, alpha), dim=-1),
        'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1)
    }
    # ^cfa804

    return buffers
# ==============================================================================================
#  Render a depth slice of the mesh (scene), some limitations:
#  - Single mesh
#  - Single light
#  - Single material
# ==============================================================================================
def render_layer(
        rast,
        rast_deriv,
        mesh,
        view_pos,
        lgt,
        resolution,
        spp,
        msaa,
        bsdf
    ):
    """Shade one rasterized layer and return its buffers.

    Args:
        rast, rast_deriv: Rasterizer output and its derivatives for the layer.
        mesh: Mesh with positions, normals, tangents, texcoords and material.
        view_pos: Camera position(s).
        lgt: Light passed through to ``shade``.
        resolution: Shading resolution ``[h, w]``.
        spp: Samples per pixel along each axis.
        msaa: When truthy and ``spp > 1``, shade at ``resolution`` and scale
            the result back up to ``resolution * spp``.
        bsdf: BSDF override passed through to ``shade``.

    Returns:
        The dict of buffers produced by ``shade``.
    """
    full_res = [resolution[0]*spp, resolution[1]*spp]
    shade_at_low_res = spp > 1 and msaa

    ################################################################################
    # Rasterize
    ################################################################################

    # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
    if shade_at_low_res:
        rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest')
        rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp
    else:
        rast_out_s = rast
        rast_out_deriv_s = rast_deriv

    ################################################################################
    # Interpolate attributes
    ################################################################################

    # Interpolate world space position
    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())

    # Compute geometric normals. We need those because of bent normals trick (for bump mapping)
    v0, v1, v2 = (mesh.v_pos[mesh.t_pos_idx[:, corner], :] for corner in range(3))
    face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
    # Every corner of triangle i refers to face normal i.
    face_ids = torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')
    face_normal_indices = (face_ids[:, None]).repeat(1, 3)
    gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())

    # Compute tangent space
    assert mesh.v_nrm is not None and mesh.v_tng is not None
    gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
    gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int())  # Interpolate tangents

    # Texture coordinate
    assert mesh.v_tex is not None
    gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)

    ################################################################################
    # Shade
    ################################################################################

    buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv,
                    view_pos, lgt, mesh.material, bsdf)

    ################################################################################
    # Prepare output
    ################################################################################

    # Scale back up to visibility resolution if using MSAA
    if shade_at_low_res:
        for key in list(buffers):
            buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest')

    # Return buffers
    return buffers
# ==============================================================================================
#  Render a depth peeled mesh (scene), some limitations:
#  - Single mesh
#  - Single light
#  - Single material
# ==============================================================================================
def render_mesh(
        ctx,
        mesh,
        mtx_in,
        view_pos,
        lgt,
        resolution,
        spp         = 1,
        num_layers  = 1,
        msaa        = False,
        background  = None,
        bsdf        = None
    ):
    """Render ``mesh`` with depth peeling and composite the layers.

    Args:
        ctx: Rasterizer context.
        mesh: Mesh to render.
        mtx_in: Transform applied to the vertex positions to reach clip space
            (tensor or array-like).
        view_pos: Camera position(s) (tensor or array-like).
        lgt: Light source.
        resolution: Output resolution ``[h, w]``.
        spp: Samples per pixel along each axis.
        num_layers: Number of depth-peeled layers.
        msaa: Shade at ``resolution`` instead of ``resolution * spp``.
        background: Optional background image matching ``resolution``.
        bsdf: Optional BSDF override.

    Returns:
        Dict of composited buffers, one per key produced by ``render_layer``.
    """
    def prepare_input_vector(x):
        # Accept array-likes, and expand (N, C) vectors to (N, 1, 1, C).
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype=torch.float32, device='cuda')
        if len(x.shape) == 2:
            return x[:, None, None, :]
        return x

    def composite_buffer(key, layers, background, antialias):
        # Blend back-to-front, starting from the background.
        accum = background
        for buffers, rast in reversed(layers):
            layer = buffers[key]
            coverage = (rast[..., -1:] > 0).float()
            alpha = coverage * layer[..., -1:]
            opaque_layer = torch.cat((layer[..., :-1], torch.ones_like(layer[..., -1:])), dim=-1)
            accum = torch.lerp(accum, opaque_layer, alpha)
            if antialias:
                accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int())
        return accum

    assert mesh.t_pos_idx.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
    assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])

    full_res = [resolution[0]*spp, resolution[1]*spp]

    # Convert numpy arrays to torch tensors
    if not torch.is_tensor(mtx_in):
        mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda')
    view_pos = prepare_input_vector(view_pos)

    # clip space transform
    v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in)

    # Render all layers front-to-back
    layers = []
    with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler:
        for _ in range(num_layers):
            rast, db = peeler.rasterize_next_layer()
            shaded = render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf)
            layers.append((shaded, rast))

    # Setup background
    if background is None:
        background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda')
    else:
        if spp > 1:
            background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest')
        # Append a zero alpha channel to the supplied background.
        background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1)

    # Composite layers front-to-back
    out_buffers = {}
    front_buffers = layers[0][0]
    for key in front_buffers.keys():
        if key == 'shaded':
            accum = composite_buffer(key, layers, background, True)
        else:
            accum = composite_buffer(key, layers, torch.zeros_like(front_buffers[key]), False)

        # Downscale to framebuffer resolution. Use avg pooling
        out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum

    return out_buffers
# ==============================================================================================
#  Render UVs
# ==============================================================================================
def render_uv(ctx, mesh, resolution, mlp_texture):
    """Rasterize the mesh in UV space and sample ``mlp_texture`` per texel.

    Returns:
        Tuple ``(mask, kd, ks, normal)``: the coverage mask, all channels
        before the last six, the next three channels, and the normalized
        last three channels of the sampled texture.
    """
    # clip space transform
    uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0

    # pad to four component coordinate
    zeros = torch.zeros_like(uv_clip[...,0:1])
    ones = torch.ones_like(uv_clip[...,0:1])
    uv_clip4 = torch.cat((uv_clip, zeros, ones), dim = -1)

    # rasterize
    rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution)

    # Interpolate world space position
    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())

    # Sample out textures from MLP
    all_tex = mlp_texture.sample(gb_pos)
    assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"

    mask = (rast[..., -1:] > 0).float()
    kd = all_tex[..., :-6]
    ks = all_tex[..., -6:-3]
    perturbed_nrm = all_tex[..., -3:]
    return mask, kd, ks, util.safe_normalize(perturbed_nrm)
texture.py
import osimport numpy as npimport torchimport nvdiffrast.torch as dr
from . import util
########################################### Smooth pooling / mip computation with linear gradient upscaling##############################################
class texture2d_mip(torch.autograd.Function):
    """2D texture downsampling step with a custom backward pass.

    The forward pass pools with ``util.avg_pool_nhwc`` using a 2x2 window.
    The backward pass upsamples the incoming gradient to twice the
    resolution with a linearly filtered texture lookup.
    """

    @staticmethod
    def forward(ctx, texture):
        """Return ``texture`` pooled by ``util.avg_pool_nhwc`` with a (2, 2) window."""
        return util.avg_pool_nhwc(texture, (2,2))

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient for the finer level at twice the resolution of ``dout``."""
        height, width = dout.shape[1], dout.shape[2]
        # Sample positions of the finer-level texels in the [0, 1] UV range.
        rows = torch.linspace(0.0 + 0.25 / height, 1.0 - 0.25 / height, height*2, device="cuda")
        cols = torch.linspace(0.0 + 0.25 / width, 1.0 - 0.25 / width, width*2, device="cuda")
        gy, gx = torch.meshgrid(rows, cols, indexing='ij')
        uv = torch.stack((gx, gy), dim=-1)
        # 0.25 spreads each coarse gradient over the 2x2 fine texels it pooled.
        return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp')
^8b6725
##################################################### Simple texture class. A texture can be either# - A 3D tensor (using auto mipmaps)# - A list of 3D tensors (full custom mip hierarchy)####################################################
class Texture2D(torch.nn.Module):
    """Trainable 2D texture stored either as one tensor or as a mip list.

    ``self.data`` is a single ``torch.nn.Parameter`` (mipmaps are generated
    on the fly when sampling) or a plain Python list of parameters forming a
    custom mip hierarchy.
    """

    # Initializes a texture from image data.
    # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays)
    def __init__(self, init, min_max=None):
        """Create the texture.

        Args:
            init: numpy array, tensor of rank 1, 3 or 4, or a list of
                tensors (one per mip level).
            min_max: Optional pair ``(mins, maxs)`` of per-channel bounds
                used by ``clamp_``.
        """
        super(Texture2D, self).__init__()

        # numpy input is uploaded as a float32 CUDA tensor; a one-element
        # list is unwrapped and treated like a single texture.
        if isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')
        elif isinstance(init, list) and len(init) == 1:
            init = init[0]

        if isinstance(init, list):
            # NOTE(review): parameters kept in a plain list are not registered
            # with the module; other code checks isinstance(self.data, list),
            # so the container type is left unchanged here.
            self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=True) for mip in init)
        elif len(init.shape) == 4:
            self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=True)
        elif len(init.shape) == 3:
            self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=True)
        elif len(init.shape) == 1:
            self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=True)  # Convert constant to 1x1 tensor
        else:
            assert False, "Invalid texture object"

        self.min_max = min_max

    # Filtered (trilinear) sample texture at a given location
    def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'):
        """Sample the texture at ``texc`` with ``dr.texture``."""
        if isinstance(self.data, list):
            # A list holds a custom mip hierarchy: level 0 plus explicit mips.
            out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode)
        else:
            # Single texture: build the mip chain automatically.
            if self.data.shape[1] > 1 and self.data.shape[2] > 1:
                mips = [self.data]
                while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1:
                    mips += [texture2d_mip.apply(mips[-1])]  # generate the next mip level
                out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode)
            else:
                out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode)
        return out

    def getRes(self):
        """Return the (height, width) of mip level 0."""
        return self.getMips()[0].shape[1:3]

    def getChannels(self):
        """Return the channel count of mip level 0."""
        return self.getMips()[0].shape[3]

    def getMips(self):
        """Return the stored mip levels as a list (one entry for a single texture)."""
        if isinstance(self.data, list):
            return self.data
        else:
            return [self.data]

    # In-place clamp with no derivative to make sure values are in valid range after training
    def clamp_(self):
        """Clamp every channel of every mip to ``min_max`` in place.

        Does nothing when ``min_max`` is None.
        """
        if self.min_max is not None:
            # In-place edits on views of a leaf parameter that requires grad
            # are rejected by autograd, so gradient tracking is disabled here
            # rather than relying on the caller to do it.
            with torch.no_grad():
                for mip in self.getMips():
                    for i in range(mip.shape[-1]):
                        mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i])

    # In-place normalize with no derivative, e.g. to keep normal maps unit length after training
    def normalize_(self):
        """Replace every mip, in place, with ``util.safe_normalize`` of itself."""
        with torch.no_grad():
            for mip in self.getMips():
                # copy_ writes the result back into the parameter; rebinding
                # the loop variable would leave the texture untouched.
                # NOTE(review): assumes util.safe_normalize returns a tensor of the same shape - confirm.
                mip.copy_(util.safe_normalize(mip))
############################################ Helper function to create a trainable texture from a regular texture. The trainable weights are# initialized with texture data as an initial guess###########################################
# Convert a 2D image into differentiable PyTorch parameters
def create_trainable(init, res=None, auto_mipmaps=True, min_max=None):
    """Create a trainable ``Texture2D`` initialized from ``init``.

    Args:
        init: A ``Texture2D`` (single-tensor only), numpy array or tensor.
        res: Optional target resolution; the data is rescaled with
            ``util.scale_img_nhwc`` when given.
        auto_mipmaps: When False, an explicit mip chain down to 1x1 is built
            and stored; otherwise only the base level is stored.
        min_max: Optional clamp bounds; defaults to those of ``init`` when
            ``init`` is a ``Texture2D``.

    Returns:
        The new ``Texture2D``.
    """
    with torch.no_grad():
        if isinstance(init, Texture2D):
            assert isinstance(init.data, torch.Tensor)
            if min_max is None:
                min_max = init.min_max
            init = init.data
        elif isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')

        # Pad to NHWC if needed
        rank = len(init.shape)
        if rank == 1:  # Extend constant to NHWC tensor
            init = init[None, None, None, :]
        elif rank == 3:
            init = init[None, ...]

        # Scale input to desired resolution.
        if res is not None:
            init = util.scale_img_nhwc(init, res)

        if auto_mipmaps:
            return Texture2D(init, min_max=min_max)

        # Generate custom mipchain
        mip_chain = [init.clone().detach().requires_grad_(True)]
        while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1:
            finest = mip_chain[-1]
            # Halve each side, but never go below one texel.
            new_size = [max(finest.shape[1] // 2, 1), max(finest.shape[2] // 2, 1)]
            mip_chain.append(util.scale_img_nhwc(finest, new_size))
        return Texture2D(mip_chain, min_max=min_max)
########################################## Convert texture to and from SRGB########################################
def srgb_to_rgb(texture):
    """Return a new ``Texture2D`` with ``util.srgb_to_rgb`` applied to every mip.

    The result is created without ``min_max`` bounds.
    """
    converted = [util.srgb_to_rgb(mip) for mip in texture.getMips()]
    return Texture2D(converted)
def rgb_to_srgb(texture):
    """Return a new ``Texture2D`` with ``util.rgb_to_srgb`` applied to every mip.

    The result is created without ``min_max`` bounds.
    """
    converted = [util.rgb_to_srgb(mip) for mip in texture.getMips()]
    return Texture2D(converted)
######################################### Utility functions for loading / storing a texture########################################
def _load_mip2D(fn, lambda_fn=None, channels=None):
    """Load one image as a float32 CUDA tensor.

    Args:
        fn: Image path passed to ``util.load_image``.
        lambda_fn: Optional callable applied to the tensor after channel
            selection.
        channels: When given, keep only the first ``channels`` channels.

    Returns:
        A detached copy of the processed tensor.
    """
    mip = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')
    if channels is not None:
        mip = mip[..., 0:channels]
    if lambda_fn is not None:
        mip = lambda_fn(mip)
    return mip.detach().clone()
def load_texture2D(fn, lambda_fn=None, channels=None):
    """Load a texture, picking up a numbered mip chain when one exists.

    If ``<base>_0<ext>`` exists, files ``<base>_0<ext>``, ``<base>_1<ext>``,
    ... are loaded as consecutive mip levels until a number is missing.
    Otherwise ``fn`` itself is loaded as a single texture.
    """
    base, ext = os.path.splitext(fn)

    if not os.path.exists(base + "_0" + ext):
        return Texture2D(_load_mip2D(fn, lambda_fn, channels))

    mips = []
    level_path = base + ("_%d" % len(mips)) + ext
    while os.path.exists(level_path):
        mips.append(_load_mip2D(level_path, lambda_fn, channels))
        level_path = base + ("_%d" % len(mips)) + ext
    return Texture2D(mips)
def _save_mip2D(fn, mip, mipidx, lambda_fn):
    """Save one mip level with ``util.save_image``.

    Args:
        fn: Output path. When ``mipidx`` is not None, ``_<mipidx>`` is
            inserted before the extension.
        mip: Tensor to save.
        mipidx: Mip level number, or None for a single-level texture.
        lambda_fn: Optional callable applied to ``mip`` before saving.
    """
    source = mip if lambda_fn is None else lambda_fn(mip)
    data = source.detach().cpu().numpy()

    if mipidx is not None:
        base, ext = os.path.splitext(fn)
        fn = base + ("_%d" % mipidx) + ext
    util.save_image(fn, data)
def save_texture2D(fn, tex, lambda_fn=None):
    """Save ``tex`` to disk, one numbered file per mip for list-backed textures.

    The leading batch dimension of each level is dropped before saving.
    """
    if not isinstance(tex.data, list):
        _save_mip2D(fn, tex.data[0,...], None, lambda_fn)
        return

    for level, mip in enumerate(tex.data):
        _save_mip2D(fn, mip[0,...], level, lambda_fn)