Skip to content

code

dataset_mesh.py

import numpy as np
import torch
from render import util
from render import mesh
from render import render
from render import light
from .dataset import Dataset
###############################################################################
# Reference dataset using mesh & rendering
###############################################################################
class DatasetMesh(Dataset):
    """Dataset whose target images are rendered on the fly from a reference mesh.

    In validation mode every item uses a camera orbiting the mesh; otherwise
    every item uses a pose from ``util.random_rotation_translation``.
    """

    def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False):
        """Store the render configuration and prepare the reference mesh.

        Args:
            ref_mesh: Reference mesh. Its material is indexed with 'kd' and
                'ks' (and 'normal' when present) to query texture resolution.
            glctx: Rasterizer context forwarded to ``render.render_mesh``.
            cam_radius: Applied as ``translate(0, 0, -cam_radius)`` in the
                model-view matrix.
            FLAGS: Configuration object; this class reads train_res,
                display_res, texture_res, local_rank, envmap, env_scale,
                cam_near_far, spp, layers, iter and batch.
            validate: Selects the orbiting camera and a fixed length of 50.
        """
        # Init
        self.glctx = glctx
        self.cam_radius = cam_radius
        self.FLAGS = FLAGS
        self.validate = validate
        self.fovy = np.deg2rad(45)
        self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0]

        if self.FLAGS.local_rank == 0:
            print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))

        # Sanity test training texture resolution
        ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())
        if 'normal' in ref_mesh.material:
            ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())

        # `and` binds tighter than `or`: without the grouping below, every rank
        # printed the warning whenever the second dimension was too small.
        texture_too_small = (FLAGS.texture_res[0] < ref_texture_res[0]
                             or FLAGS.texture_res[1] < ref_texture_res[1])
        if self.FLAGS.local_rank == 0 and texture_too_small:
            print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))

        # Load environment map texture
        self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

        self.ref_mesh = mesh.compute_tangents(ref_mesh)

    def _rotate_scene(self, itr):
        """Return (mv, mvp, campos, resolution, spp) for display frame ``itr``.

        The three tensors get a leading batch dimension and are moved to CUDA.
        """
        proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])

        # Smooth rotation for display: one full turn every 50 frames.
        ang = (itr / 50) * np.pi * 2
        mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(0.4) @ util.rotate_y(ang))
        mvp = proj_mtx @ mv
        # Camera position is the translation column of the inverse model-view.
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp

    def _random_scene(self):
        """Return (mv, mvp, campos, resolution, spp) for a random training pose."""
        # ==============================================================================================
        #  Setup projection matrix
        # ==============================================================================================
        iter_res = self.FLAGS.train_res
        proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])

        # ==============================================================================================
        #  Random camera & light position
        # ==============================================================================================

        # Random rotation/translation matrix for optimization.
        mv = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.25)
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), iter_res, self.FLAGS.spp  # Add batch dimension

    def __len__(self):
        """Return 50 in validation mode, else (FLAGS.iter + 1) * FLAGS.batch."""
        return 50 if self.validate else (self.FLAGS.iter + 1) * self.FLAGS.batch

    def __getitem__(self, itr):
        """Render the reference mesh for item ``itr`` and return it with its camera.

        Returns:
            dict with keys 'mv', 'mvp', 'campos', 'resolution', 'spp' and 'img'
            (the 'shaded' buffer returned by ``render.render_mesh``).
        """
        # ==============================================================================================
        #  Randomize scene parameters
        # ==============================================================================================
        if self.validate:
            mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr)
        else:
            mv, mvp, campos, iter_res, iter_spp = self._random_scene()

        img = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp,
                                 num_layers=self.FLAGS.layers, msaa=True, background=None)['shaded']

        # ^eebb2a  (note anchor carried over from the pasted source)

        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : iter_res,
            'spp' : iter_spp,
            'img' : img
        }

light.py

import os
import numpy as np
import torch
import nvdiffrast.torch as dr
from . import util
from . import renderutils as ru
##############################################
# Utility functions
#############################################
class cubemap_mip(torch.autograd.Function):
    """Cubemap downsample by 2x2 average pooling, with a filtered backward pass.

    The backward pass upsamples the incoming gradient by looking it up as a
    cube texture with linear filtering instead of replicating it per texel.
    """

    @staticmethod
    def forward(ctx, cubemap):
        """Return ``cubemap`` average-pooled with a 2x2 window."""
        return util.avg_pool_nhwc(cubemap, (2,2))

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient at twice the resolution of ``dout``."""
        res = dout.shape[1] * 2
        # Allocate on the gradient's own device rather than the current default CUDA device.
        device = dout.device
        out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device=device)

        # Texel-centre coordinates in [-1, 1]; identical for all six faces, so built once.
        # NOTE(review): assumes util.cube_to_dir does not modify gx/gy in place.
        coords = torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device=device)
        gy, gx = torch.meshgrid(coords, coords, indexing='ij')

        # Each fine texel receives a quarter of the filtered coarse gradient (2x2 average).
        scaled_dout = dout[None, ...] * 0.25
        for s in range(6):
            v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
            out[s, ...] = dr.texture(scaled_dout, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
        return out
###################################################
# Split-sum environment map light source with automatic mipmap generation
###################################################
class EnvironmentLight(torch.nn.Module):
    """Trainable cubemap environment light shaded with the split-sum approximation.

    ``base`` is the full-resolution cubemap. ``build_mips`` derives a chain of
    pre-filtered specular levels and one diffuse map from it; ``shade`` looks
    them up per pixel.
    """

    LIGHT_MIN_RES = 16     # mip generation stops once a level is this small
    MIN_ROUGHNESS = 0.08   # roughness assigned to specular level 0
    MAX_ROUGHNESS = 0.5    # roughness assigned to the second-to-last level

    def __init__(self, base):
        """Wrap a detached copy of ``base`` as the trainable cubemap."""
        super(EnvironmentLight, self).__init__()
        self.mtx = None
        self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=True)
        # The same parameter is additionally registered under the name 'env_base'.
        self.register_parameter('env_base', self.base)

    def xfm(self, mtx):
        """Set the matrix applied to lookup vectors in ``shade`` (None disables it)."""
        self.mtx = mtx

    def clone(self):
        """Return a new EnvironmentLight built from a detached copy of ``base``."""
        return EnvironmentLight(self.base.clone().detach())

    def clamp_(self, min=None, max=None):
        """Clamp ``base`` in place to [min, max] without recording gradients."""
        # An in-place op on a leaf parameter that requires grad raises unless
        # autograd is disabled, so do that here instead of relying on callers.
        with torch.no_grad():
            self.base.clamp_(min, max)

    def get_mip(self, roughness):
        """Map roughness to a fractional mip level of the specular chain.

        Roughness below MAX_ROUGHNESS is spread linearly over all levels but
        the last; the rest maps onto the final step up to roughness 1.0.
        """
        return torch.where(roughness < self.MAX_ROUGHNESS
                        , (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2)
                        , (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2)

    def build_mips(self, cutoff=0.99):
        """Rebuild ``self.specular`` (filtered mip chain) and ``self.diffuse`` from ``base``."""
        self.specular = [self.base]
        while self.specular[-1].shape[1] > self.LIGHT_MIN_RES:
            self.specular += [cubemap_mip.apply(self.specular[-1])]

        self.diffuse = ru.diffuse_cubemap(self.specular[-1])

        num_filtered = len(self.specular) - 1
        # With exactly two levels the original divisor (len - 2) is zero; clamp it
        # so the single filtered level gets MIN_ROUGHNESS instead of raising.
        span = max(num_filtered - 1, 1)
        for idx in range(num_filtered):
            roughness = (idx / span) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS
            self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff)
        self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff)

    def regularizer(self):
        """Return the mean absolute deviation of ``base`` from its per-texel grey value."""
        white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0
        return torch.mean(torch.abs(self.base - white))

    def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True):
        """Shade a G-buffer with this light.

        Args:
            gb_pos, gb_normal: Per-pixel position and shading normal; the
                rotate path below reshapes them as 4-D tensors.
            kd: Base colour.
            ks: Channel 0 scales the result by (1 - value), channel 1 is
                roughness, channel 2 is metalness.
            view_pos: Camera position.
            specular: When False only the diffuse term is evaluated.

        Requires ``build_mips`` to have been called (reads self.diffuse / self.specular).
        """
        wo = util.safe_normalize(view_pos - gb_pos)

        if specular:
            roughness = ks[..., 1:2] # y component
            metallic  = ks[..., 2:3] # z component
            spec_col  = (1.0 - metallic)*0.04 + kd * metallic
            diff_col  = kd * (1.0 - metallic)
        else:
            diff_col = kd

        reflvec = util.safe_normalize(util.reflect(wo, gb_normal))
        nrmvec = gb_normal
        if self.mtx is not None: # Rotate lookup
            mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')
            reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
            nrmvec  = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)

        # ^51f458  (note anchor carried over from the pasted source)

        # Diffuse lookup
        diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube')
        shaded_col = diffuse * diff_col

        if specular:
            # Lookup FG term from lookup texture
            NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4)
            fg_uv = torch.cat((NdotV, roughness), dim=-1)
            if not hasattr(self, '_FG_LUT'):
                # Loaded once on first use; the path is relative to the working directory.
                self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda')
            fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp')

            # Roughness adjusted specular env lookup
            miplevel = self.get_mip(roughness)
            spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube')

            # Compute aggregate lighting
            reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2]
            shaded_col += spec * reflectance

        return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility
###################################################
# Load and store
#####################################################
# Load from latlong .HDR file
def _load_env_hdr(fn, scale=1.0):
    """Load a lat-long image, scale it, and return it as an EnvironmentLight.

    The image is converted to a 512x512 cubemap and the light's mip chain is
    built before returning.
    """
    pixels = util.load_image(fn)
    latlong_img = torch.tensor(pixels, dtype=torch.float32, device='cuda') * scale
    env = EnvironmentLight(util.latlong_to_cubemap(latlong_img, [512, 512]))
    env.build_mips()
    return env
def load_env(fn, scale=1.0):
    """Load an environment light from ``fn``; only ``.hdr`` files are supported.

    Args:
        fn: Path of the environment map; the extension check is case-insensitive.
        scale: Multiplier forwarded to the loader.

    Raises:
        AssertionError: If the file extension is not ``.hdr``.
    """
    ext = os.path.splitext(fn)[1]
    if ext.lower() == ".hdr":
        return _load_env_hdr(fn, scale)
    # Raised explicitly (same type as the former `assert False`) so the check
    # is not stripped under `python -O`, which made this return None silently.
    raise AssertionError("Unknown envlight extension %s" % ext)
def save_env_map(fn, light):
    """Write the base cubemap of ``light`` to ``fn`` as a 512x1024 lat-long image.

    Raises:
        AssertionError: If ``light`` is not an EnvironmentLight.
    """
    # Explicit raise (same type as the former assert) so the check survives
    # `python -O`; previously an unsupported light was skipped without saving.
    if not isinstance(light, EnvironmentLight):
        raise AssertionError("Can only save EnvironmentLight currently")
    color = util.cubemap_to_latlong(light.base, [512, 1024])
    util.save_image_raw(fn, color.detach().cpu().numpy())
#################################################
# Create trainable env map with random initialization
##################################################
def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
    """Return an EnvironmentLight whose cubemap is uniform noise in [bias, bias + scale).

    The cubemap has six faces of ``base_res`` x ``base_res`` texels with 3 channels.
    """
    noise = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda')
    return EnvironmentLight(noise * scale + bias)

render.py

import torch
import nvdiffrast.torch as dr
from . import util
from . import renderutils as ru
from . import light
# ====================================================
# Helper functions
# ==================================================
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Interpolate vertex attributes over the rasterized image.

    Attribute derivatives are requested for all attributes only when
    ``rast_db`` is supplied.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)
# ======================================================
# pixel shader
# ====================================================
def shade(
        gb_pos,
        gb_geometric_normal,
        gb_normal,
        gb_tangent,
        gb_texc,
        gb_texc_deriv,
        view_pos,
        lgt,
        material,
        bsdf
    ):
    """Pixel shader: sample the material and evaluate the selected BSDF.

    Args:
        gb_pos, gb_geometric_normal, gb_normal, gb_tangent: Interpolated
            per-pixel geometry buffers.
        gb_texc, gb_texc_deriv: Texture coordinates and their derivatives.
        view_pos: Camera position.
        lgt: Light source; must be a light.EnvironmentLight for 'pbr'/'diffuse'.
        material: Mapping holding either 'kd_ks_normal' or 'kd'/'ks'
            (optionally 'normal', 'bsdf', 'no_perturbed_nrm').
        bsdf: BSDF name overriding material['bsdf'] when not None.

    Returns:
        dict with 'shaded', 'kd_grad' and 'occlusion', each with alpha appended
        as the last channel.
    """
    ################################################################################
    # Texture lookups
    ################################################################################
    perturbed_nrm = None
    if 'kd_ks_normal' in material:
        # Combined texture, used for MLPs because lookups are expensive
        combined = material['kd_ks_normal']
        pos_noise = torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda")
        tex_jittered = combined.sample(gb_pos + pos_noise)
        tex = combined.sample(gb_pos)
        assert tex.shape[-1] == 9 or tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"
        # Layout from the end: last 3 = normal, previous 3 = ks, the rest = kd
        kd = tex[..., :-6]
        ks = tex[..., -6:-3]
        perturbed_nrm = tex[..., -3:]
        # Compute albedo (kd) gradient, used for material regularizer
        kd_grad = torch.sum(torch.abs(tex_jittered[..., :-6] - kd), dim=-1, keepdim=True) / 3
    else:
        kd_tex = material['kd']
        uv_noise = torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda")
        kd_jittered = kd_tex.sample(gb_texc + uv_noise, gb_texc_deriv)
        kd = kd_tex.sample(gb_texc, gb_texc_deriv)
        ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha
        if 'normal' in material:
            perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)
        kd_grad = torch.sum(torch.abs(kd_jittered[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3

    # Separate kd into alpha and color, default alpha = 1
    if kd.shape[-1] == 4:
        alpha = kd[..., 3:4]
    else:
        alpha = torch.ones_like(kd[..., 0:1])
    kd = kd[..., 0:3]

    ################################################################################
    # Normal perturbation & normal bend
    ################################################################################
    if 'no_perturbed_nrm' in material and material['no_perturbed_nrm']:
        perturbed_nrm = None

    gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)

    ################################################################################
    # Evaluate BSDF
    ################################################################################
    assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
    if bsdf is None:
        bsdf = material['bsdf']

    if bsdf == 'pbr' or bsdf == 'diffuse':
        # Both lit modes share the light; only 'pbr' adds the specular term
        if isinstance(lgt, light.EnvironmentLight):
            shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=(bsdf == 'pbr'))
        else:
            assert False, "Invalid light type"
    elif bsdf == 'normal':
        shaded_col = (gb_normal + 1.0)*0.5
    elif bsdf == 'tangent':
        shaded_col = (gb_tangent + 1.0)*0.5
    elif bsdf == 'kd':
        shaded_col = kd
    elif bsdf == 'ks':
        shaded_col = ks
    else:
        assert False, "Invalid BSDF '%s'" % bsdf

    # Return multiple buffers
    buffers = {
        'shaded'    : torch.cat((shaded_col, alpha), dim=-1),
        'kd_grad'   : torch.cat((kd_grad, alpha), dim=-1),
        'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1)
    }

    # ^cfa804  (note anchor carried over from the pasted source)

    return buffers
# ===================================================
# Render a depth slice of the mesh (scene), some limitations:
# - Single mesh
# - Single light
# - Single material
# ===================================================
def render_layer(
        rast,
        rast_deriv,
        mesh,
        view_pos,
        lgt,
        resolution,
        spp,
        msaa,
        bsdf
    ):
    """Interpolate mesh attributes for one rasterized layer and shade it.

    With ``spp > 1`` and ``msaa`` set, shading runs at ``resolution`` and the
    resulting buffers are scaled back up to ``resolution * spp``; otherwise
    shading runs at the resolution of ``rast``.

    Returns:
        The buffer dict produced by ``shade``.
    """
    full_res = [resolution[0]*spp, resolution[1]*spp]
    shade_low_res = spp > 1 and msaa

    ################################################################################
    # Rasterize
    ################################################################################

    # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
    if shade_low_res:
        rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest')
        rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp
    else:
        rast_out_s, rast_out_deriv_s = rast, rast_deriv

    ################################################################################
    # Interpolate attributes
    ################################################################################

    # Interpolate world space position
    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())

    # Compute geometric normals. We need those because of bent normals trick (for bump mapping)
    tris = mesh.t_pos_idx
    corner0 = mesh.v_pos[tris[:, 0], :]
    corner1 = mesh.v_pos[tris[:, 1], :]
    corner2 = mesh.v_pos[tris[:, 2], :]
    face_normals = util.safe_normalize(torch.cross(corner1 - corner0, corner2 - corner0, dim=-1))
    # Every corner of face i indexes normal i, giving a constant normal per face
    face_ids = torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')
    face_normal_indices = face_ids[:, None].repeat(1, 3)
    gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())

    # Compute tangent space
    assert mesh.v_nrm is not None and mesh.v_tng is not None
    gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
    gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents

    # Texture coordinate
    assert mesh.v_tex is not None
    gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)

    ################################################################################
    # Shade
    ################################################################################
    buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv,
                    view_pos, lgt, mesh.material, bsdf)

    ################################################################################
    # Prepare output
    ################################################################################

    # Scale back up to visibility resolution if using MSAA
    if shade_low_res:
        buffers = {
            key: util.scale_img_nhwc(buf, full_res, mag='nearest', min='nearest')
            for key, buf in buffers.items()
        }

    # Return buffers
    return buffers
# ==================================================
# Render a depth peeled mesh (scene), some limitations:
# - Single mesh
# - Single light
# - Single material
# ==================================================
def render_mesh(
        ctx,
        mesh,
        mtx_in,
        view_pos,
        lgt,
        resolution,
        spp         = 1,
        num_layers  = 1,
        msaa        = False,
        background  = None,
        bsdf        = None
    ):
    """Render a mesh with depth peeling and composite the peeled layers.

    Args:
        ctx: Rasterizer context passed to ``dr.DepthPeeler``.
        mesh: Mesh to render.
        mtx_in: Transform applied to ``mesh.v_pos`` to obtain clip-space positions.
        view_pos: Camera position(s); a 2-D input gets two singleton axes inserted.
        lgt: Light passed through to shading.
        resolution: Output resolution; rasterization runs at ``resolution * spp``.
        spp: Supersampling factor per axis.
        num_layers: Number of depth-peeled layers.
        msaa: Forwarded to ``render_layer``.
        background: Optional image at ``resolution``; zeros are used when None.
        bsdf: Optional BSDF override forwarded to shading.

    Returns:
        dict mapping each shading buffer name to its composited image,
        average-pooled by ``spp`` when ``spp > 1``.
    """

    def prepare_input_vector(x):
        # Accept array-likes, and lift a 2-D input to 4-D by adding two middle axes
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype=torch.float32, device='cuda')
        if len(x.shape) == 2:
            return x[:, None, None, :]
        return x

    def composite_buffer(key, layers, background, antialias):
        # Blend back-to-front so that nearer layers end up on top
        accum = background
        for layer_buffers, layer_rast in reversed(layers):
            buf = layer_buffers[key]
            covered = (layer_rast[..., -1:] > 0).float()
            alpha = covered * buf[..., -1:]
            opaque = torch.cat((buf[..., :-1], torch.ones_like(buf[..., -1:])), dim=-1)
            accum = torch.lerp(accum, opaque, alpha)
            if antialias:
                accum = dr.antialias(accum.contiguous(), layer_rast, v_pos_clip, mesh.t_pos_idx.int())
        return accum

    assert mesh.t_pos_idx.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
    assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])

    full_res = [resolution[0]*spp, resolution[1]*spp]

    # Convert numpy arrays to torch tensors
    if not torch.is_tensor(mtx_in):
        mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda')
    view_pos = prepare_input_vector(view_pos)

    # clip space transform
    v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in)

    # Render all layers front-to-back
    layers = []
    with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler:
        for _ in range(num_layers):
            rast, db = peeler.rasterize_next_layer()
            shaded = render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf)
            layers.append((shaded, rast))

    # Setup background
    if background is None:
        background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda')
    else:
        if spp > 1:
            background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest')
        # Append a zero alpha channel
        background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1)

    # Composite layers front-to-back
    out_buffers = {}
    front_buffers = layers[0][0]
    for key in front_buffers.keys():
        if key == 'shaded':
            accum = composite_buffer(key, layers, background, True)
        else:
            accum = composite_buffer(key, layers, torch.zeros_like(front_buffers[key]), False)

        # Downscale to framebuffer resolution. Use avg pooling
        if spp > 1:
            accum = util.avg_pool_nhwc(accum, spp)
        out_buffers[key] = accum

    return out_buffers
# ====================================================
# Render UVs
# ================================================
def render_uv(ctx, mesh, resolution, mlp_texture):
    """Rasterize the mesh in UV space and sample ``mlp_texture`` at each texel.

    Returns:
        Tuple of (coverage mask, kd, ks, normalized perturbed normal), where the
        last three are slices of the sampled 9- or 10-channel texture.
    """
    # clip space transform
    uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0

    # pad to four component coordinate
    column = uv_clip[..., 0:1]
    uv_clip4 = torch.cat((uv_clip, torch.zeros_like(column), torch.ones_like(column)), dim=-1)

    # rasterize
    rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution)

    # Interpolate world space position
    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())

    # Sample out textures from MLP
    all_tex = mlp_texture.sample(gb_pos)
    assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"

    coverage = (rast[..., -1:] > 0).float()
    kd = all_tex[..., :-6]
    ks = all_tex[..., -6:-3]
    perturbed_nrm = all_tex[..., -3:]
    return coverage, kd, ks, util.safe_normalize(perturbed_nrm)

texture.py

import os
import numpy as np
import torch
import nvdiffrast.torch as dr
from . import util
##########################################
# Smooth pooling / mip computation with linear gradient upscaling
##############################################
class texture2d_mip(torch.autograd.Function):
    """2x2 average-pool mip step whose gradient is upsampled with linear filtering."""

    @staticmethod
    def forward(ctx, texture):
        """Return ``texture`` average-pooled with a 2x2 window."""
        return util.avg_pool_nhwc(texture, (2,2))

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient at twice the height and width of ``dout``."""
        height, width = dout.shape[1], dout.shape[2]
        # Build the lookup grid on the gradient's own device; a hard-coded "cuda"
        # resolves to the current default device, which may be a different GPU.
        device = dout.device
        rows = torch.linspace(0.0 + 0.25 / height, 1.0 - 0.25 / height, height * 2, device=device)
        cols = torch.linspace(0.0 + 0.25 / width, 1.0 - 0.25 / width, width * 2, device=device)
        gy, gx = torch.meshgrid(rows, cols, indexing='ij')
        uv = torch.stack((gx, gy), dim=-1)
        # Each fine texel receives a quarter of the filtered coarse gradient (2x2 average).
        return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp')

^8b6725

####################################################
# Simple texture class. A texture can be either
# - A 3D tensor (using auto mipmaps)
# - A list of 3D tensors (full custom mip hierarchy)
####################################################
class Texture2D(torch.nn.Module):
    """Trainable 2D texture stored either as one tensor or as an explicit mip list.

    ``self.data`` is a single NHWC parameter (mipmaps are generated on the fly
    when sampling) or a Python list of parameters forming a custom mip chain.
    """

    # Initializes a texture from image data.
    # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays)
    def __init__(self, init, min_max=None):
        """Create the texture from ``init``.

        Args:
            init: numpy array, tensor of rank 1/3/4, or a list of mip tensors.
                A one-element list is treated as a single texture.
            min_max: Optional per-channel (min, max) bounds used by ``clamp_``.
        """
        super(Texture2D, self).__init__()

        # Normalize the input: numpy arrays become CUDA tensors, a one-element list is unwrapped
        if isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')
        elif isinstance(init, list) and len(init) == 1:
            init = init[0]

        if isinstance(init, list):
            self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=True) for mip in init)
        elif len(init.shape) == 4:
            self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=True)
        elif len(init.shape) == 3:
            self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=True)
        elif len(init.shape) == 1:
            self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=True) # Convert constant to 1x1 tensor
        else:
            assert False, "Invalid texture object"

        self.min_max = min_max

    # Filtered (trilinear) sample texture at a given location
    def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'):
        """Sample the texture at ``texc`` using ``texc_deriv`` for mip selection."""
        if isinstance(self.data, list):
            # A list holds an explicit mip chain: level 0 plus the remaining levels
            out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode)
        else:
            # A single texture: generate the mip chain automatically
            if self.data.shape[1] > 1 and self.data.shape[2] > 1:
                mips = [self.data]
                while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1:
                    mips += [texture2d_mip.apply(mips[-1])] # build the next mip level
                out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode)
            else:
                out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode)
        return out

    def getRes(self):
        """Return (height, width) of the base mip level."""
        return self.getMips()[0].shape[1:3]

    def getChannels(self):
        """Return the channel count of the base mip level."""
        return self.getMips()[0].shape[3]

    def getMips(self):
        """Return the mip levels as a list (a single-element list for one tensor)."""
        if isinstance(self.data, list):
            return self.data
        else:
            return [self.data]

    # In-place clamp with no derivative to make sure values are in valid range after training
    def clamp_(self):
        """Clamp every channel of every mip to ``min_max``; no-op when it is None."""
        if self.min_max is not None:
            # In-place edits of a leaf parameter raise while autograd is recording,
            # so disable it here rather than relying on the caller to do so.
            with torch.no_grad():
                for mip in self.getMips():
                    for i in range(mip.shape[-1]):
                        mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i])

    # In-place normalize with no derivative
    def normalize_(self):
        """Replace every mip level with its ``util.safe_normalize`` result, in place."""
        with torch.no_grad():
            for mip in self.getMips():
                # copy_ writes into the parameter; assigning to the loop
                # variable only rebound a local and left the texture unchanged.
                mip.copy_(util.safe_normalize(mip))
###########################################
# Helper function to create a trainable texture from a regular texture. The trainable weights are
# initialized with texture data as an initial guess
###########################################
# Convert a 2D image into differentiable PyTorch parameters
def create_trainable(init, res=None, auto_mipmaps=True, min_max=None):
    """Build a trainable Texture2D initialized from ``init``.

    Args:
        init: A Texture2D holding a single tensor, a numpy array, or a tensor.
            Rank-1 and rank-3 inputs are padded to NHWC.
        res: Optional target resolution; the data is rescaled to it when given.
        auto_mipmaps: When False, an explicit mip chain is generated by
            repeatedly halving the resolution down to 1x1.
        min_max: Clamp bounds; defaults to those of ``init`` when it is a Texture2D.
    """
    with torch.no_grad():
        if isinstance(init, Texture2D):
            assert isinstance(init.data, torch.Tensor)
            if min_max is None:
                min_max = init.min_max
            init = init.data
        elif isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')

        # Pad to NHWC if needed
        rank = len(init.shape)
        if rank == 1: # Extend constant to NHWC tensor
            init = init[None, None, None, :]
        elif rank == 3:
            init = init[None, ...]

        # Scale input to desired resolution.
        if res is not None:
            init = util.scale_img_nhwc(init, res)

        if auto_mipmaps:
            return Texture2D(init, min_max=min_max)

        # Generate custom mipchain
        mip_chain = [init.clone().detach().requires_grad_(True)]
        while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1:
            prev = mip_chain[-1]
            # Halve each axis but never go below one texel
            new_size = [max(prev.shape[1] // 2, 1), max(prev.shape[2] // 2, 1)]
            mip_chain.append(util.scale_img_nhwc(prev, new_size))
        return Texture2D(mip_chain, min_max=min_max)
#########################################
# Convert texture to and from SRGB
########################################
def srgb_to_rgb(texture):
    """Return a new Texture2D with ``util.srgb_to_rgb`` applied to every mip level."""
    converted = [util.srgb_to_rgb(mip) for mip in texture.getMips()]
    return Texture2D(converted)
def rgb_to_srgb(texture):
    """Return a new Texture2D with ``util.rgb_to_srgb`` applied to every mip level."""
    converted = [util.rgb_to_srgb(mip) for mip in texture.getMips()]
    return Texture2D(converted)
########################################
# Utility functions for loading / storing a texture
########################################
def _load_mip2D(fn, lambda_fn=None, channels=None):
    """Load one image as a float32 CUDA tensor.

    The image is optionally truncated to the first ``channels`` channels and
    then passed through ``lambda_fn``. A detached copy is returned.
    """
    pixels = util.load_image(fn)
    imgdata = torch.tensor(pixels, dtype=torch.float32, device='cuda')
    if channels is not None:
        imgdata = imgdata[..., 0:channels]
    if lambda_fn is not None:
        imgdata = lambda_fn(imgdata)
    return imgdata.detach().clone()
def load_texture2D(fn, lambda_fn=None, channels=None):
    """Load a Texture2D from ``fn``, or from a numbered mip chain next to it.

    If ``<base>_0<ext>`` exists, files ``<base>_0<ext>``, ``<base>_1<ext>``, ...
    are loaded as successive mip levels until one is missing; otherwise
    ``fn`` itself is loaded as a single texture.
    """
    base, ext = os.path.splitext(fn)

    def mip_path(level):
        return base + ("_%d" % level) + ext

    if not os.path.exists(mip_path(0)):
        return Texture2D(_load_mip2D(fn, lambda_fn, channels))

    mips = []
    while os.path.exists(mip_path(len(mips))):
        mips.append(_load_mip2D(mip_path(len(mips)), lambda_fn, channels))
    return Texture2D(mips)
def _save_mip2D(fn, mip, mipidx, lambda_fn):
    """Save one mip level; ``mipidx`` (when not None) is appended to the file stem as ``_<idx>``.

    ``lambda_fn``, when given, is applied to the mip before it is converted to numpy.
    """
    source = mip if lambda_fn is None else lambda_fn(mip)
    data = source.detach().cpu().numpy()

    if mipidx is not None:
        base, ext = os.path.splitext(fn)
        fn = base + ("_%d" % mipidx) + ext
    util.save_image(fn, data)
def save_texture2D(fn, tex, lambda_fn=None):
    """Save ``tex`` to ``fn``; a mip-list texture is written as one numbered file per level.

    The leading batch dimension of each level is dropped before saving.
    """
    if not isinstance(tex.data, list):
        _save_mip2D(fn, tex.data[0,...], None, lambda_fn)
        return

    for level, mip in enumerate(tex.data):
        _save_mip2D(fn, mip[0,...], level, lambda_fn)