花期易逝
Go back

浅谈压缩与Compressai库的学习

Published:  at  04:00 PM

记录下Compressai库的学习流程(基于Gemini和开源代码)

InterDigitalInc/CompressAI: A PyTorch library and evaluation platform for end-to-end compression research

模型整体架构

我们的学习遵循自顶向下的模式,先看整体的结构设计,结构树形图如下:

```text
compressai/

├── models/                   # 模型总装车间 (职责不变)
│   ├── base.py               #    - 核心基类 `CompressionModel`
│   ├── google.py             #    - Ballé, Minnen, Cheng 等经典图像压缩模型
│   └── video/                #    - 新增:视频压缩模型子目录
│       └── google.py         #        - SSF, DCVC 等视频压缩模型实现

├── latent_codecs/            # 隐层编解码器 (核心升级)
│   ├── base.py               #    - `LatentCodec` 基类,定义统一接口
│   ├── entropy_bottleneck.py #    - 全分解模型 (Factorized Prior)
│   ├── gaussian_conditional.py #  - 高斯条件模型
│   └── hyperprior.py         #    - 将超先验中 z 的处理逻辑封装成一个 LatentCodec

├── layers/                   # 自定义神经网络层 (职责不变)
│   └── gdn.py                #    - GDN / IGDN 实现

├── entropy_coding/           # 熵编码抽象层 (新!)
│   ├── __init__.py           #    - 封装了对不同后端 (ans, rangecoder) 的调用
│   ├── ans.py                #    - Python 端对 C++ ANS 编码器的封装
│   └── ...

├── cpp_exts/                 # C++ 性能扩展 (底层核心)
│   └── rans/                 #    - rANS (range Asymmetric Numeral Systems) 的 C++ 源码

├── optimizers/               # 优化器与参数化
│   └── parametrizers.py      #    - 例如 LowerBound,用于确保 GDN 参数为正

├── losses/                   # 损失函数
│   └── rate_distortion.py    #    - 核心的率-失真损失 `RateDistortionLoss`

└── zoo/                      # 预训练模型库
    └── __init__.py           #    - `load_state_dict_from_url` 加载网络权重的核心逻辑

通过抽象基类定义统一接口,利用多态性和动态绑定机制,实现调用方与具体实现方的解耦,从而构建一个可扩展、可维护的系统。

**models/base.py**实现抽象基类

entropy_coding/: 明确它是对底层的 C++ ansrange_coderPythonic 抽象,统一了 encode / decode 接口。

latent_codecs/: 强调它实现了量化、概率建模、熵编码的三位一体。

重构的目的:解耦与多态

剖析完整数据流动

使用核心训练脚本 examples/train.py追踪一次完整的数据流动:从原始图片到压缩比特流,再到损失计算与模型优化的全过程。

1.参数解析 (argparse)

脚本的 main 函数首先通过 Python 的 argparse 模块定义了所有可配置的超参数。几个关键参数包括:

2.核心组件的实例化

参数确定后,脚本开始实例化训练所需的四大核心组件:

3.train_one_epoch 中的数据流动

train_one_epoch 函数 是整个训练过程的核心。让我们跟随一批数据 d (一个 batch 的图像张量),看看它经历的完整旅程:

输入 (d): d 是一个形状为 [B, C, H, W] 的张量,代表 B 张图像。

手写Latent Codec

class AutoregressiveLatentCodec(LatentCodec):(继承基类)
    """
    一个简单的自回归隐层编解码器实现
    职责:管理量化、概率预测(上下文模型)以及熵编码
    """
    def __init__(self, latent_dim, **kwargs):
        super().__init__(**kwargs)
        # 上下文预测网络:通过 y_hat 预测概率参数 (mu, sigma)
        self.context_prediction = MaskedConv2d(
            latent_dim, latent_dim * 2, kernel_size=5, padding=2, stride=1
        )
        # 熵模型:高斯条件熵模型
        self.gaussian_conditional = GaussianConditional(None)

    def forward(self, y):
        # 训练阶段:利用 MaskedConv 的并行性
        # y: 原始隐层张量 [B, C, H, W]
        
        # 1. 模拟量化
        y_hat = self.quantize(y, mode="noise" if self.training else "dequantize")
        
        # 2. 预测概率分布参数
        # 掩模卷积保证了当前像素只能“看到”之前扫描到的像素
        params = self.context_prediction(y_hat)
        scales, means = params.chunk(2, 1)
        
        # 3. 计算似然概率(用于计算 Rate)
        _, likelihoods = self.gaussian_conditional(y, scales, means=means)
        
        return {
            "y_hat": y_hat,
            "likelihoods": {"y": likelihoods}
        }
        
def compress(self, y):
    # y: [1, C, H, W]
    # 1. 预先量化 (推理阶段必须是确定的 round)
    y_hat = torch.round(y)
    
    # 2. 准备状态:对于自回归,我们需要逐像素(或逐块)处理
    # 在这个例子中,我们假设使用最经典的 Raster-scan (光栅扫描) 顺序
    B, C, H, W = y_hat.shape
    
    # 这里的关键是:虽然 y_hat 已知,但为了模拟解码端的行为,
    # 我们通常需要维持一个上下文缓存
    # 下面是一个抽象逻辑,CompressAI 内部通常会优化这一过程
    
    # 3. 概率估计:在自回归下,每个像素点的概率取决于之前的像素
    # 由于 compress 时 y_hat 全已知,我们可以一次性通过 MaskedConv 得到所有均值和方差
    params = self.context_prediction(y_hat)
    scales, means = params.chunk(2, 1)
    
    # 4. 调用 C++ 后端进行编码
    # 注意:这里的编码器(ans)会将隐层打成字符串
    # y_hat: 实际符号, scales/means: 每个符号对应的概率模型参数
    cdfs = self.gaussian_conditional.get_cdf(scales, means)
    indices = self.gaussian_conditional.build_indexes(scales)
    
    stream = self.entropy_coder.encode_with_indexes(
        y_hat.view(-1).int(), indices.view(-1).int(), cdfs, ...
    )
    return stream

def decompress(self, strings, shape):
    # shape: (H, W) 隐层分辨率
    H, W = shape
    y_hat = torch.zeros((1, self.latent_dim, H, W)).to(device)
    
    # 🔓 核心痛点:必须写循环!
    # 因为 MaskedConv 的输出取决于 y_hat 中已经填充的部分
    for h in range(H):
        for w in range(W):
            # 1. 提取当前已解码部分,预测当前位置的分布参数
            # 这是一个典型的“因果预测”
            params = self.context_prediction(y_hat)
            scales, means = params.chunk(2, 1)
            
            # 2. 只取当前 (h, w) 位置的参数去解码一个符号
            s = scales[:, :, h:h+1, w:w+1]
            m = means[:, :, h:h+1, w:w+1]
            
            # 3. 从比特流中解码出当前像素值
            rv = self.gaussian_conditional.decode_with_indexes(...)
            
            # 4. 更新 y_hat,供下一个位置的预测使用
            y_hat[:, :, h, w] = rv + m
            
    return y_hat

写在训模型之外(核心算子)

先聊算术编码

核心思想是将整个消息序列映射为 $[0, 1)$ 区间上的一个实数

但是,尽管 AC 在理论上无限接近香农极限,但它在高性能端到端压缩(如 CompressAI)中存在致命伤:

  1. 精度爆炸(The Precision Crisis):随着序列变长,区间变得无限小,必须使用高精度浮点数或复杂的重正规化,计算开销巨大。
  2. 串行瓶颈:AC 是严格 FIFO(先进先出)的,且编码每一位都需要进行复杂的乘除法,难以利用 CPU/GPU 的并行特性。
  3. 状态耦合:编码下一位必须等待上一位区间更新完成。
AC到rANS

将 AC 的“区间收缩”变成了一个整数状态 $x$ 的累加

rANS 的核心思想是将整个消息序列编码为一个单一的正整数 $x$。这个 $x$ 即是我们的“编码状态”。

假设我们的隐层符号(Symbols)服从离散分布,字母表为 $\mathcal{A}$:

编码过程是将当前状态 $x$ 和新符号 $s$ 映射为新状态 $x’$。数学公式如下:

$$x’ = C(x, s) = \lfloor \frac{x}{f_s} \rfloor \cdot M + C_s + (x \pmod{f_s})$$

解码是编码的逆过程。给定状态 $x’$,我们首先要识别出它对应哪个符号 $s$:

  1. 找出符号:找到满足 $C_s \leq (x’ \pmod M) < C_{s+1}$ 的 $s$。

  2. 还原状态:

    $$ x = D(x’, s) = f_s \cdot \lfloor \frac{x’}{M} \rfloor + (x’ \pmod M) - C_s $$

补充:计算效率分析

rANS 的编码公式:

$$ x’ = \underbrace{\lfloor x / f_s \rfloor \cdot M}{\text{项 A}} + \underbrace{C_s}{\text{项 B}} + \underbrace{(x \pmod{f_s})}_{\text{项 C}} $$ 在 C++ 实现中,这行代码的执行逻辑如下:

  1. 除法 (/):计算 q = x / f_s。这一步同时得到了q余数 r = x % f_s。这是 CPU 的一条指令(DIV)。
  2. 位移 (<< L):因为 $M$ 通常设为 $2^{16}$,所以 q * M 变成了 q << 16
  3. 加法 (+):将 (q << 16)C_sr 相加。

结论:对于 CPU 来说,这确实就是 1 次除法 + 几次位运算/加法。相比 AC 复杂的区间缩放和多次判断,速度提升了几个数量级。

解码:$s = \text{find_s}(x \pmod M)$$

$$x_{next} = f_s \cdot \lfloor x / M \rfloor + (x \pmod M) - C_s$$

  1. 查找 (find_s):由于 $M=2^{16}$,x % M 就是 x & 0xFFFF。我们拿着这个 16 位的索引去 CDF 表里查,一步就能定位到符号 $s$。这就是你说的一次查找
  2. 乘法 (\*):公式中的 $\lfloor x / M \rfloor$ 是 x >> 16。剩下的操作就是 f_s * (x >> 16) + offset
  3. 结论:解码只需 1 次查表 + 1 次乘法
struct AnsSymbol {
    uint32_t start;      // C_s (累积频率)
    uint32_t freq;       // f_s (符号频率)
    // --- 工业级优化项 ---
    uint32_t rcp_freq;   // 频率的倒数定点数:(1ULL << 31) / freq
    uint32_t rcp_shift;  // 用于位移的偏移量
};
// 对应公式:x' = (x/f)*M + Cs + (x%f)
void RansEncPut(uint32_t* x, uint16_t** pptr, const AnsSymbol& sym) {
    uint32_t x_curr = *x;

    // 1. 重正规化 (Renormalization):防止状态 x 溢出 32 位
    // 工业库通常会在 x 达到一个阈值时,吐出 16 bits 到比特流
    uint32_t x_max = ((RANS_L >> 16) << 16) / sym.freq; 
    while (x_curr >= x_max) {
        *(*pptr)-- = (uint16_t)(x_curr & 0xffff); // 写入比特流(逆序)
        x_curr >>= 16;
    }

    // 2. 核心状态转移:消灭除法
    // 理论公式:q = x / sym.freq; r = x % sym.freq;
    // 工业优化:使用预计算的倒数进行定点数乘法位移
    uint32_t q = (uint32_t)(((uint64_t)x_curr * sym.rcp_freq) >> sym.rcp_shift);
    uint32_t r = x_curr - q * sym.freq; 

    // 3. 更新状态:M = 1 << 16
    *x = (q << 16) + sym.start + r;
}
// 对应公式:x = f * floor(x/M) + (x%M) - Cs
uint32_t RansDecGet(uint32_t* x, uint16_t** pptr, const AnsSymbol* sym_table) {
    uint32_t x_curr = *x;
    
    // 1. 获取当前插槽位置 (slot = x % M)
    uint32_t slot = x_curr & 0xffff;

    // 2. 一次查找获取符号及其元数据
    const AnsSymbol& sym = sym_table[slot]; 

    // 3. 还原状态:x = f * (x >> 16) + slot - Cs
    x_curr = sym.freq * (x_curr >> 16) + slot - sym.start;

    // 4. 重正规化(回填):从比特流中读取 16 bits 恢复状态
    while (x_curr < RANS_L) {
        x_curr = (x_curr << 16) | *(*pptr)++;
    }

    *x = x_curr;
    return slot; // 返回解码出的值
}

加餐1:聊一聊熵模型与高斯假设

熵建模的本质:从概率到码率

在压缩系统中,**熵模型(Entropy Model)**的任务是预测 Latent $y$ 的概率分布 $q(y)$。

码流中传输的到底是什么?

这是一个关于“空间切分”的物理过程,而非简单的数值传输。

为了应对真实图像中复杂的、非高斯的 Latent 分布,GMM 引入了“多重押注”机制。

加餐2:真实的编码落地

均值归一化 (Zero-Mean Normalization)

图中的 Normalization 步骤是整个工程优化的精髓。

编码器内的“受损视角”同步

观察图中 Encoder 内部的 Z_string -> Z_dec 链路。

Update 产生的 CDF 表是“物理契约”

Update 阶段生成的 CDF_Table 是跨平台、跨设备编解码的唯一准则。

码率惩罚项的“愿景”与“实操”