记录下Compressai库的学习流程(基于Gemini和开源代码)
模型整体架构
我们的学习遵循自顶向下的模式,先看整体的结构设计,结构树形图如下:
```text
compressai/
│
├── models/ # 模型总装车间 (职责不变)
│ ├── base.py # - 核心基类 `CompressionModel`
│ ├── google.py # - Ballé, Minnen, Cheng 等经典图像压缩模型
│ └── video/ # - 新增:视频压缩模型子目录
│ └── google.py # - SSF, DCVC 等视频压缩模型实现
│
├── latent_codecs/ # 隐层编解码器 (核心升级)
│ ├── base.py # - `LatentCodec` 基类,定义统一接口
│ ├── entropy_bottleneck.py # - 全分解模型 (Factorized Prior)
│ ├── gaussian_conditional.py # - 高斯条件模型
│ └── hyperprior.py # - 将超先验中 z 的处理逻辑封装成一个 LatentCodec
│
├── layers/ # 自定义神经网络层 (职责不变)
│ └── gdn.py # - GDN / IGDN 实现
│
├── entropy_coding/ # 熵编码抽象层 (新!)
│ ├── __init__.py # - 封装了对不同后端 (ans, rangecoder) 的调用
│ ├── ans.py # - Python 端对 C++ ANS 编码器的封装
│ └── ...
│
├── cpp_exts/ # C++ 性能扩展 (底层核心)
│ └── rans/ # - rANS (range Asymmetric Numeral Systems) 的 C++ 源码
│
├── optimizers/ # 优化器与参数化
│ └── parametrizers.py # - 例如 LowerBound,用于确保 GDN 参数为正
│
├── losses/ # 损失函数
│ └── rate_distortion.py # - 核心的率-失真损失 `RateDistortionLoss`
│
└── zoo/ # 预训练模型库
└── __init__.py # - `load_state_dict_from_url` 加载网络权重的核心逻辑
通过抽象基类定义统一接口,利用多态性和动态绑定机制,实现调用方与具体实现方的解耦,从而构建一个可扩展、可维护的系统。
**models/base.py**实现抽象基类
entropy_coding/: 明确它是对底层的 C++ ans 或 range_coder 的 Pythonic 抽象,统一了 encode / decode 接口。
latent_codecs/: 强调它实现了量化、概率建模、熵编码的三位一体。
重构的目的:解耦与多态
- 单一职责原则 (SRP):主模型只负责特征提取($g_a$)和重建($g_s$)。
- 黑盒化:主模型只需要问
latent_codec:“给你这个张量 $y$,请帮我把它变成比特流,并告诉我训练时的 $y_{hat}$ 和概率。”
剖析完整数据流动
使用核心训练脚本 examples/train.py追踪一次完整的数据流动:从原始图片到压缩比特流,再到损失计算与模型优化的全过程。
1.参数解析 (argparse)
脚本的 main 函数首先通过 Python 的 argparse 模块定义了所有可配置的超参数。几个关键参数包括:
--model: 选择要训练的模型,例如bmshj2018-hyperprior。--dataset: 指定训练用的数据集路径。--lambda: 率-失真权衡因子λ。(在损失项中乘在Distortion前)这是最重要的参数之一,一个高λ会产出高码率高质量的模型,反之亦然。--learning-rate: 学习率。--batch-size: 批处理大小。--patch-size: 训练时从原图随机裁剪的图块大小。
2.核心组件的实例化
参数确定后,脚本开始实例化训练所需的四大核心组件:
-
数据加载器 (
Dataloader):- 代码溯源 (
mainL190-L208): 脚本根据数据集路径创建ImageFolder实例 (compressai/datasets/image.py),并应用transforms(如随机裁剪)。最后用 PyTorch 的DataLoader进行封装。 - 作用:负责从磁盘读取图片,进行预处理,并按
batch_size组装成一批批的张量,送入训练循环。
- 代码溯源 (
-
压缩模型 (
CompressionModel):- 代码溯源 (
mainL211):net = models[args.model](**vars(args)) - 解读:这行代码非常灵活。它利用
compressai.models注册表,通过命令行传入的字符串名字 (args.model),直接从models包中找到对应的模型类并实例化。这再次体现了框架良好的可扩展性。
- 代码溯源 (
-
优化器 (
Optimizer) 和 学习率调度器 (Scheduler):- 代码溯源 (
mainL218-L232): 脚本实例化了configure_optimizers函数返回的优化器和调度器。通常,作者会为主网络参数和LatentCodec中的量化参数设置不同的学习率。 - 作用:优化器负责根据损失函数的梯度更新模型参数,而调度器则负责在训练过程中动态调整学习率。
- 代码溯源 (
-
损失函数 (
Criterion):- 代码溯源 (
mainL235):criterion = RateDistortionLoss(lmbda=args.lmbda) - 解读:这里实例化了位于
compressai/losses/rate_distortion.py的核心损失函数RateDistortionLoss。它的唯一关键参数就是λ,完美体现了率-失真优化的核心思想。
- 代码溯源 (
3.train_one_epoch 中的数据流动
train_one_epoch 函数 是整个训练过程的核心。让我们跟随一批数据 d (一个 batch 的图像张量),看看它经历的完整旅程:
输入 (d): d 是一个形状为 [B, C, H, W] 的张量,代表 B 张图像。
-
代码:
out_net = net(d) -
内部流程: a.
d首先进入模型的g_a(分析变换),被编码成隐层表示y。 b.y接着被送入latent_codec。在训练阶段(forward方法),latent_codec对y添加均匀噪声模拟量化得到y_hat,并同时估计出y_hat中每个元素的概率likelihoods。 c.y_hat被送入模型的g_s(合成变换),解码出重建图像x_hat。 -
输出 (
out_net):net(d)的返回值是一个字典,包含{ "x_hat": x_hat, "likelihoods": likelihoods }。这个精心设计的数据结构,刚好是RateDistortionLoss所需的全部信息。 -
代码:
out_criterion = criterion(out_net, d) -
内部流程 (
RateDistortionLoss.forward): a. 失真 (Distortion): 计算out_net["x_hat"](重建图像) 和d(原始图像) 之间的差异,通常是MSE、也可以换成ssim或者主观指标如LPIPS。 b. 率 (Rate): 根据信息论,率是信息量的期望。这里通过对out_net["likelihoods"]取负对数并求和,精确地计算出压缩y_hat所需的比特数(bpp_loss)。 c. 总损失: 最终损失为loss = args.lmbda * distortion_loss + rate_loss。 -
输出 (
out_criterion): 返回一个字典,包含{ "loss": loss, "mse_loss": mse_loss, "bpp_loss": bpp_loss }。 -
代码:
out_criterion["loss"].backward() clip_max_norm(net.parameters(), 1.0) # 梯度裁剪,防止梯度爆炸 optimizer.step() -
流程:
train.py调用model(x)。
model.forward(x):- 调用
self.g_a(x)得到y。 - 调用
self.latent_codec(y)。
latent_codec.forward(y):- 如果是
Hyperprior结构,它会先调用self.entropy_bottleneck处理 $z$。 - 然后利用 $z$ 解码出的信息作为 $y$ 的先验(Prior)。
- 返回
y_hat和likelihoods。
train.py拿到out_net,调用criterion(out_net, x)。RateDistortionLoss:- 按照前文解析的逻辑,根据
likelihoods算出 $R$,根据x_hat算出 $D$。 - 返回最终
loss。
loss.backward():- 梯度从
Loss反向流经g_s->latent_codec->g_a。
手写Latent Codec
class AutoregressiveLatentCodec(LatentCodec):(继承基类)
"""
一个简单的自回归隐层编解码器实现
职责:管理量化、概率预测(上下文模型)以及熵编码
"""
def __init__(self, latent_dim, **kwargs):
super().__init__(**kwargs)
# 上下文预测网络:通过 y_hat 预测概率参数 (mu, sigma)
self.context_prediction = MaskedConv2d(
latent_dim, latent_dim * 2, kernel_size=5, padding=2, stride=1
)
# 熵模型:高斯条件熵模型
self.gaussian_conditional = GaussianConditional(None)
def forward(self, y):
# 训练阶段:利用 MaskedConv 的并行性
# y: 原始隐层张量 [B, C, H, W]
# 1. 模拟量化
y_hat = self.quantize(y, mode="noise" if self.training else "dequantize")
# 2. 预测概率分布参数
# 掩模卷积保证了当前像素只能“看到”之前扫描到的像素
params = self.context_prediction(y_hat)
scales, means = params.chunk(2, 1)
# 3. 计算似然概率(用于计算 Rate)
_, likelihoods = self.gaussian_conditional(y, scales, means=means)
return {
"y_hat": y_hat,
"likelihoods": {"y": likelihoods}
}
def compress(self, y):
# y: [1, C, H, W]
# 1. 预先量化 (推理阶段必须是确定的 round)
y_hat = torch.round(y)
# 2. 准备状态:对于自回归,我们需要逐像素(或逐块)处理
# 在这个例子中,我们假设使用最经典的 Raster-scan (光栅扫描) 顺序
B, C, H, W = y_hat.shape
# 这里的关键是:虽然 y_hat 已知,但为了模拟解码端的行为,
# 我们通常需要维持一个上下文缓存
# 下面是一个抽象逻辑,CompressAI 内部通常会优化这一过程
# 3. 概率估计:在自回归下,每个像素点的概率取决于之前的像素
# 由于 compress 时 y_hat 全已知,我们可以一次性通过 MaskedConv 得到所有均值和方差
params = self.context_prediction(y_hat)
scales, means = params.chunk(2, 1)
# 4. 调用 C++ 后端进行编码
# 注意:这里的编码器(ans)会将隐层打成字符串
# y_hat: 实际符号, scales/means: 每个符号对应的概率模型参数
cdfs = self.gaussian_conditional.get_cdf(scales, means)
indices = self.gaussian_conditional.build_indexes(scales)
stream = self.entropy_coder.encode_with_indexes(
y_hat.view(-1).int(), indices.view(-1).int(), cdfs, ...
)
return stream
def decompress(self, strings, shape):
# shape: (H, W) 隐层分辨率
H, W = shape
y_hat = torch.zeros((1, self.latent_dim, H, W)).to(device)
# 🔓 核心痛点:必须写循环!
# 因为 MaskedConv 的输出取决于 y_hat 中已经填充的部分
for h in range(H):
for w in range(W):
# 1. 提取当前已解码部分,预测当前位置的分布参数
# 这是一个典型的“因果预测”
params = self.context_prediction(y_hat)
scales, means = params.chunk(2, 1)
# 2. 只取当前 (h, w) 位置的参数去解码一个符号
s = scales[:, :, h:h+1, w:w+1]
m = means[:, :, h:h+1, w:w+1]
# 3. 从比特流中解码出当前像素值
rv = self.gaussian_conditional.decode_with_indexes(...)
# 4. 更新 y_hat,供下一个位置的预测使用
y_hat[:, :, h, w] = rv + m
return y_hat
写在训模型之外(核心算子)
先聊算术编码
核心思想是将整个消息序列映射为 $[0, 1)$ 区间上的一个实数。
-
原理:根据符号出现的概率 $P(s)$,将当前区间不断细分。
-
数学表示:假设当前区间为 $[L, R)$,符号 $s$ 占据子区间 $[C_s, C_s + f_s)$,则更新规则为:
$$ L_{new} = L + (R - L) \cdot C_s $$
$$ R_{new} = L + (R - L) \cdot (C_s + f_s) $$
但是,尽管 AC 在理论上无限接近香农极限,但它在高性能端到端压缩(如 CompressAI)中存在致命伤:
- 精度爆炸(The Precision Crisis):随着序列变长,区间变得无限小,必须使用高精度浮点数或复杂的重正规化,计算开销巨大。
- 串行瓶颈:AC 是严格 FIFO(先进先出)的,且编码每一位都需要进行复杂的乘除法,难以利用 CPU/GPU 的并行特性。
- 状态耦合:编码下一位必须等待上一位区间更新完成。
AC到rANS
将 AC 的“区间收缩”变成了一个整数状态 $x$ 的累加
rANS 的核心思想是将整个消息序列编码为一个单一的正整数 $x$。这个 $x$ 即是我们的“编码状态”。
假设我们的隐层符号(Symbols)服从离散分布,字母表为 $\mathcal{A}$:
- $f_s$:符号 $s \in \mathcal{A}$ 出现的频次(Frequency)。
- $M = \sum_{s \in \mathcal{A}} f_s$:总频次,通常取 $2^L$(如 $2^{16}$),代表概率建模的精度。
- $C_s = \sum_{i < s} f_i$:累积频次(CDF)。
- $x$:当前的编码器状态(State)。
编码过程是将当前状态 $x$ 和新符号 $s$ 映射为新状态 $x’$。数学公式如下:
$$x’ = C(x, s) = \lfloor \frac{x}{f_s} \rfloor \cdot M + C_s + (x \pmod{f_s})$$
- $\lfloor \frac{x}{f_s} \rfloor \cdot M$:为新信息腾出空间,跨度为 $M$。
- $C_s$:定位到该符号在概率区间中的起始位置。
- $(x \pmod{f_s})$:将旧状态 $x$ 的“残余信息”编码进当前符号占用的频率区间内。
解码是编码的逆过程。给定状态 $x’$,我们首先要识别出它对应哪个符号 $s$:
-
找出符号:找到满足 $C_s \leq (x’ \pmod M) < C_{s+1}$ 的 $s$。
-
还原状态:
$$ x = D(x’, s) = f_s \cdot \lfloor \frac{x’}{M} \rfloor + (x’ \pmod M) - C_s $$
补充:计算效率分析
rANS 的编码公式:
$$ x’ = \underbrace{\lfloor x / f_s \rfloor \cdot M}{\text{项 A}} + \underbrace{C_s}{\text{项 B}} + \underbrace{(x \pmod{f_s})}_{\text{项 C}} $$ 在 C++ 实现中,这行代码的执行逻辑如下:
- 除法 (
/):计算q = x / f_s。这一步同时得到了商q和 余数r = x % f_s。这是 CPU 的一条指令(DIV)。 - 位移 (
<< L):因为 $M$ 通常设为 $2^{16}$,所以q * M变成了q << 16。 - 加法 (
+):将(q << 16)、C_s和r相加。
结论:对于 CPU 来说,这确实就是 1 次除法 + 几次位运算/加法。相比 AC 复杂的区间缩放和多次判断,速度提升了几个数量级。
解码:$s = \text{find_s}(x \pmod M)$$
$$x_{next} = f_s \cdot \lfloor x / M \rfloor + (x \pmod M) - C_s$$
- 查找 (
find_s):由于 $M=2^{16}$,x % M就是x & 0xFFFF。我们拿着这个 16 位的索引去 CDF 表里查,一步就能定位到符号 $s$。这就是你说的一次查找。 - 乘法 (
\*):公式中的 $\lfloor x / M \rfloor$ 是x >> 16。剩下的操作就是f_s * (x >> 16) + offset。 - 结论:解码只需 1 次查表 + 1 次乘法。
struct AnsSymbol {
uint32_t start; // C_s (累积频率)
uint32_t freq; // f_s (符号频率)
// --- 工业级优化项 ---
uint32_t rcp_freq; // 频率的倒数定点数:(1ULL << 31) / freq
uint32_t rcp_shift; // 用于位移的偏移量
};
// 对应公式:x' = (x/f)*M + Cs + (x%f)
void RansEncPut(uint32_t* x, uint16_t** pptr, const AnsSymbol& sym) {
uint32_t x_curr = *x;
// 1. 重正规化 (Renormalization):防止状态 x 溢出 32 位
// 工业库通常会在 x 达到一个阈值时,吐出 16 bits 到比特流
uint32_t x_max = ((RANS_L >> 16) << 16) / sym.freq;
while (x_curr >= x_max) {
*(*pptr)-- = (uint16_t)(x_curr & 0xffff); // 写入比特流(逆序)
x_curr >>= 16;
}
// 2. 核心状态转移:消灭除法
// 理论公式:q = x / sym.freq; r = x % sym.freq;
// 工业优化:使用预计算的倒数进行定点数乘法位移
uint32_t q = (uint32_t)(((uint64_t)x_curr * sym.rcp_freq) >> sym.rcp_shift);
uint32_t r = x_curr - q * sym.freq;
// 3. 更新状态:M = 1 << 16
*x = (q << 16) + sym.start + r;
}
// 对应公式:x = f * floor(x/M) + (x%M) - Cs
uint32_t RansDecGet(uint32_t* x, uint16_t** pptr, const AnsSymbol* sym_table) {
uint32_t x_curr = *x;
// 1. 获取当前插槽位置 (slot = x % M)
uint32_t slot = x_curr & 0xffff;
// 2. 一次查找获取符号及其元数据
const AnsSymbol& sym = sym_table[slot];
// 3. 还原状态:x = f * (x >> 16) + slot - Cs
x_curr = sym.freq * (x_curr >> 16) + slot - sym.start;
// 4. 重正规化(回填):从比特流中读取 16 bits 恢复状态
while (x_curr < RANS_L) {
x_curr = (x_curr << 16) | *(*pptr)++;
}
*x = x_curr;
return slot; // 返回解码出的值
}
加餐1:聊一聊熵模型与高斯假设
熵建模的本质:从概率到码率
在压缩系统中,**熵模型(Entropy Model)**的任务是预测 Latent $y$ 的概率分布 $q(y)$。
-
核心逻辑: 码率 $R$ 取决于观测值 $y$ 在预测分布中的概率。公式表达为 $R \approx -\log_2 q(y)$。
-
数学建模(单高斯假设):
我们将码率拆解为两个部分:
$$ R(y) \propto \underbrace{\ln(\sigma)}{\text{基础开销(分布平坦度)}} + \underbrace{\frac{(y - \mu)^2}{2\sigma^2}}{\text{偏差代价(预测准确度)}} $$
- $\mu$ 的作用: 决定了预测的中心。$\mu$ 越准,偏差项越小。
- $\sigma$ 的作用: 决定了容错率。$\sigma$ 越小,虽然基础开销低,但对预测偏差的惩罚极度敏感。
码流中传输的到底是什么?
这是一个关于“空间切分”的物理过程,而非简单的数值传输。
- 不是残差,也不是原值: 码流传输的是一个区间指针(通常是一个二进制小数)。
- 编解码同步: 编码端和解码端利用相同的超先验信息,在本地生成完全一致 $\mu$ 和 $\sigma$。
- 算术编码机制: 编码端根据 $y$ 在地图上的位置确定区间并输出码流;解码端根据码流指向的区间,从地图上反查出 $y$。
- 条件编码 vs. 残差编码: 条件编码通过动态调整 $\mu$ 和 $\sigma$,在数学上包含了残差编码,并能处理非平稳的方差变化,因此性能上限更高。
为了应对真实图像中复杂的、非高斯的 Latent 分布,GMM 引入了“多重押注”机制。
- 数学表达: $p(y) = \sum_{k=1}^K \pi_k \mathcal{N}(y | \mu_k, \sigma_k^2)$。
- 直觉理解: 模型不再只给出一个预测中心,而是给出多个可能的中心($\mu_k$)及其置信度($\pi_k$)。
- 优势: 极大地增强了对异方差性(Heteroscedasticity)和多峰分布的拟合能力,只要 $y$ 落在任何一个高斯分量的中心附近,码率就不会爆炸。
加餐2:真实的编码落地
均值归一化 (Zero-Mean Normalization)
图中的 Normalization 步骤是整个工程优化的精髓。
- 数学表达:送入 rANS 的符号不是 $y$,而是 $v = \text{round}(y - \mu)$。
- 物理意义:这一步将无数种可能的均值 $\mu$ 全部平移到了坐标原点。这使得我们在
Update阶段只需要针对 $\sigma$(分布的胖瘦)建表,而不需要管 $\mu$(分布的位置)。
编码器内的“受损视角”同步
观察图中 Encoder 内部的 Z_string -> Z_dec 链路。
- 强制要求:编码器必须使用量化后的 $\hat{z}$ 来生成 $\mu$ 和 $\sigma$。
- 视角:这是为了确保编码器看到的概率分布与解码器拿到的完全一致。如果你跳过这一步直接用原始的 $z$,由于 $z \neq \hat{z}$,推导出的 $\sigma$ 索引就会跳变,直接导致解码端的 rANS 状态机崩溃,产生彩色旋涡。
Update 产生的 CDF 表是“物理契约”
Update 阶段生成的 CDF_Table 是跨平台、跨设备编解码的唯一准则。
- 误差定量:由于
Scale Table只有 64 档,任何落在两档之间的 $\sigma$ 都会被强行归类到最近的一档。这种 Scale Discretization 会带来约 $1%$ 的码率损失(KL 散度开销),但它是换取跨平台确定性的必要牺牲。 - 16-bit 约束:CDF 的值域被限制在 $[0, 65536]$,这是为了适配高性能 C++ rANS 实现中的位移运算。
码率惩罚项的“愿景”与“实操”
- 训练时:Loss 里的 $R$ 是基于公式 $P(v) = \Phi(v+0.5) - \Phi(v-0.5)$ 计算的连续积分。
- 推理时:码率是 rANS 根据
CDF_Table的区间宽度,实际写出的二进制位。 - 对齐:只要 $h_s$ 网络足够稳定,且
update时的采样足够密,这两个码率在数值上几乎完全重合。