Train a diffusion model
How to train a UNet2DModel from scratch on a subset of …dataset
要对应看jupyter notebook,这篇post中只记录记忆点的code
Dataset
dataclass
configuration train
datasets
load_dataset
Configuration
利用装饰器快速生成类的初始化方法
@dataclassclass TrainConfig: image_size = 128 #the generated image resolution train_batch_size =16 eval_batch_size = 16 num_epochs = 50 gradient_accumulation_steps = 1 learning_rate = 1e-4 save_image_epochs = 10 save_model_epochs = 30 mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision output_dir = "ddpm-butterflies-128" # the model name locally and on the HF Hub push_to_hub = True # whether to upload the saved model to the HF Hub hub_model_id = "<your-username>/<my-awesome-model>" # the name of the repository to create on the HF Hub hub_private_repo = None overwrite_output_dir = True # overwrite the old model when re-running the notebook seed = 0
config = TrainConfig()
Load and Visualize the data
#from datasets import load_datasetconfig.dataset_name = "huggan/smithsonian_butterflies_subset"dataset = load_dataset(config.dataset_name, split="train")
fig, axs = plt.subplots(1, 4, figsize=(16, 4))for i, image in enumerate(dataset[:4]["image"]): axs[i].imshow(image) axs[i].set_axis_off()
fig.show()
- 注意:
dataset[:4]["image"]
这里的后面的”image”,对应的是数据集里面的类别,可以打开数据看:The dataset - for i, image in …: 对应的是index and value
Data pre-process
数据预处理:使用torchvision的transforms对图像进行预处理
#from torchvision import transformspreprocess = transforms.Compose( [ transforms.Resize( (config.image_size, config.image_size) ), transforms.RandomHorizontalFlip(), transforms.ToTensor (), transforms.Normalize( [0.5] , [0.5]), ])
def transform(examples): images = [preprocess(image.convert("RGB")) for image in examples["image"]] return {"images" : images}
dataset.set_transform(transform)train_dataloader = torch.utils.data.DataLoader(dataset, batch_size = config.train_batch_size, shuffle = True)
- transform.compose是把多个预处理组合起来的
- 先将图片重新整理大小,和config的image_size对上
- 随机水平翻转
- transfer image to tensor, for PyTorch to calculate
- normalize tensor, 常用于加速模型训练。这里均值和标准差都设为 0.5,将像素值范围从 [0, 1] 映射到 [-1, 1]。
Issues:
When after loading matplotlib, kernel crash
sometimes the version between libraries not match, upgrade the library to solve the issue
solution:
pip install --upgrade matplotlib numpy
Two modes in PyTorch: eval() and train()
- eval(): turn off Dropouts layers, BatchNorm layers
- another evaluation/validation is using
torch.no_grad()
in pair withmodel.eval()
to turn off gradients computation - Purpose: set to eval() because you want ot train part of the parameters ?
- another evaluation/validation is using
latent_scale_factor : VAE latent space
it defines the pixels space and VAE latent space ratio relationship
- if latent_scale_factor = 8, means you have
512*512
image, and it will be64*64
in latent space
Dataset的处理
- it is recommend using Dataset from PyTorch (use it as super class)
1. 数据准备
-
加载数据:首先你需要准备好训练用的图像数据,通常会把图像读入内存或者利用数据加载器(例如 PyTorch 的 DataLoader)。
-
转换为 Tensor:图像需要转换成 tensor,并根据模型要求进行归一化或其他预处理操作。
2. 模型组件和冻结参数
-
选择需要微调的部分:如果你只希望调整 UNet,而其他部分(例如 VAE 或文本编码器)保持不变,可以在加载预训练模型后,将其他模块的参数冻结(即不计算梯度,不更新参数)。
-
冻结方法:通常对不需要更新的部分调用
requires_grad=False
。
3. 处理文本提示(Tokenizer)
-
文本条件:对于像 Stable Diffusion 这样的文本到图像生成模型,文本条件很重要。如果你的任务不依赖于具体的文本提示,你可以设计一个“空白”或固定的提示。但通常建议还是给出合理的文本描述,因为这会影响模型生成的结果。
-
Tokenizer 的作用:Tokenizer 的作用是把文本转换成模型能够理解的 token 序列。即使你用的是一个固定提示,也需要保证输入形式符合模型要求。
4. 定义损失函数
-
噪声预测损失:在 Diffusion 模型中,常见的做法是在潜空间中加入噪声,让模型学习从噪声中恢复原始数据,常用的损失是均方误差(MSE)。
-
只计算 UNet 的损失:由于其他部分被冻结,所以梯度只会在 UNet 上进行传播,损失计算过程和标准训练类似。
5. 训练循环与 Epoch 概念
-
Epoch 的定义:一个 epoch 指的是模型已经遍历了一次整个训练数据集。例如,如果数据集中有 10,000 张图片,经过一次完整的前向和反向传播后,这个过程就称为一个 epoch。
-
训练循环内部主要操作:
-
批次处理:数据集被划分成若干个 batch。在每个 batch 中,你会:
-
从数据加载器中读取一批图像和对应的文本提示(如果有的话)。
-
对图像进行必要的预处理,转换成 tensor。(同时做normalization等预处理)
-
对文本进行 token 化(如果需要)。
-
-
前向传播:将 batch 数据输入到模型中,通常包括添加噪声、根据当前时间步和噪声进行 UNet 的预测。
-
计算损失:根据预测和真实的噪声计算损失。
-
反向传播和参数更新:反向传播梯度并仅更新 UNet 的参数,其他参数因被冻结而不变。
-
记录和监控:保存损失值,定期保存模型 checkpoint,以便观察训练效果或防止训练中断时丢失进度。
-
图像 PNG → tensor(预处理) → latent(编码) → tensor(扩散/处理) → latent(反扩散) → tensor(解码后处理) → PNG
扩散过程是在latent空间里面做的