花期易逝
Go back

图像压缩中的串行和并行

Published:  at  08:00 AM

在端到端图像压缩(Learned Image Compression)中,自回归上下文模型(Autoregressive Context Model)显著提升了率失真性能,但也引入了“训练并行”与“推理串行”的工程悖论。ELIC (CVPR 2022) 通过通道分组(Channel Groups)和棋盘格掩膜(Checkerboard Masking)优化了这一权衡。本文结合源码,深入剖析其底层的数学差异:推理阶段的串行性源于“均值辅助量化(Mean-Shifted Quantization)”构建的递归死锁;而训练阶段的并行性则源于“开环量化代理(Open-Loop Quantization Proxy)”策略,该策略解耦了上下文输入的递归依赖,从而允许通过掩膜卷积实现全图并行计算。

推理阶段:均值辅助量化引发的递归死锁

在推理(编码)阶段,为了最大化压缩效率,熵编码器(Entropy Coder)要求离散符号的分布中心严格对齐至 0。因此,量化过程必须依赖上下文模型预测的均值 mu。

1.1 数学原理:递归依赖链

推理时的量化公式如下:
y_hat_i = Round(y_i - mu_i) + mu_i

这一公式构建了一个无法逾越的递归死锁(Recursive Deadlock):

  1. 量化依赖:计算离散隐变量 y_hat_i 必须先获得预测均值 mu_i
  2. 预测依赖:计算 mu_i 必须以之前的离散隐变量 y_hat_<i 为输入。
  3. 结论:在 t 时刻的 mu_t 计算完成前,t 时刻的输入数据 y_hat_t 物理上是不存在的。

1.2 源码实证:通道分组的强制串行

ELIC 通过显式的通道切分锁死了这种依赖。Elic2022Official 类的 channel_context 定义展示了 Group_kGroup_0...k-1 的严格依赖。

代码逻辑对应:
定义通道上下文模型时,第 k 组网络的输入通道数为 sum(self.groups[:k])。这表明必须先完成前 k-1 组的串行量化与解码,拼接成完整的张量后,才能作为当前网络的输入。若试图并行计算 y{1},此时 y{0} 尚未完成均值辅助量化,因此 channel_context['y1'] 的输入数据缺失,计算图无法构建。

训练阶段:开环量化代理下的并行机制

训练阶段的目标是计算损失函数 L = R + lambda * D。ELIC 通过“开环量化代理(Open-Loop Quantization Proxy)”策略打破了上述死锁。

2.1 核心逻辑:输入源的解耦

训练时,Context Model 的输入 y_hat_input 不再是递归计算出的 Round(y - mu) + mu,而是直接源自 Encoder 输出的 y(经过 STE 截断)。这种策略在学术上被称为 Ground-Truth Context Injection

公式表达为:
y_hat_input ≈ y (通过 STE)

关键推论:

  1. 无递归依赖y_hat_input 的数值仅取决于 Encoder 的参数,完全不依赖 Context Model 当前计算出的 mu。即 y_hat_inputmu_current 正交。
  2. 即时可用性:在计算图的前向传播初期,全图的 y_hat_input 张量已由 Encoder 一次性生成并驻留在显存中。

2.2 源码实证 A:STE 策略实现输入解耦

代码中 quantizer="ste" 的配置印证了输入源的改变。使用 “ste” (Straight-Through Estimator) 意味着训练时的量化输入直接对 Encoder 输出 y 进行截断并传递梯度。这一过程不涉及减去 mu (Centered),从而切断了对 mu 的递归依赖。

2.3 源码实证 B:掩膜卷积实现全图并行

由于 y_hat_input 是全图可用的,GPU 可以通过掩膜卷积(Masked Convolution)一次性计算所有位置的 musigma

代码逻辑对应:
实例化 CheckerboardMaskedConv2d 证明了输入是全图张量 (Full Tensor)。Mask 矩阵 (Toeplitz Matrix) 在并行矩阵乘法中,确保了计算位置 (h,w)mu 时,权重矩阵只能“看到” (h,w) 之前的数据。

解析:

  1. 数据准备:Encoder 输出全图 y_hat_ste
  2. 并行运算spatial_context (Masked Conv)y_hat_ste 执行一次卷积操作。
  3. 结果产出:输出张量包含了全图每个位置的 musigma
  4. 码率估计:基于这些并行算出的 mu, sigma 和已有的 y_hat_ste,直接代入高斯概率公式计算 Loss。

深度逻辑复盘:从源码看数据流差异

我们将源码逻辑拆解为以下因果链条,以回应关于并行与串行的核心疑问:

3.1 训练流 (Training Flow) 逻辑核心

由于训练时的量化代理(Quantization Proxy)不依赖于 mu,因此全图输入即时可用,可通过掩膜卷积一次性计算全图的 musigma,进而实现并行的码率估计。

  1. Step 1:全局代理生成 Encoder 输出 y,经 STE 得到 y_hat_ste。此过程无递归依赖。对应源码中的 GaussianConditionalLatentCodec(quantizer="ste")
  2. Step 2:并行参数估计 将完整的 y_hat_ste 喂入 Masked Conv,GPU 一次性算出全图参数。对应源码中 CheckerboardMaskedConv2d 的前向传播。
  3. Step 3:并行码率估计 直接计算概率密度,此过程为纯矩阵运算,完全并行。

3.2 推理流 (Inference Flow) 逻辑核心

推理中,解码端的串行源于数据缺失;编码端的串行源于量化闭环——y_hat 的生成依赖 mu,而 mu 的预测又依赖之前的 y_hat

结论

ELIC 架构的高效性建立在“训练并行,推理串行”的工程不对称性之上。这一不对称性的根源在于:

  1. 训练时:利用开环量化代理(Open-Loop Quantization Proxy)和 STE 策略,在数学上切断了量化过程对 mu 的递归依赖,使得全图数据在逻辑上即时可用,从而通过掩膜卷积实现全并行计算。
  2. 推理时:为了追求极致的熵编码效率,恢复了均值辅助量化(Mean-Shifted Quantization)的递归依赖,迫使编码器必须模拟解码器的时序行为,退化为串行模式。
@register_model("elic2022-official")
class Elic2022Official(SimpleVAECompressionModel):
    """ELIC 2022; uneven channel groups with checkerboard spatial context.



    .. code-block:: none

                  ┌───┐    y     ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
            x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯⋯·───►───┤h_s├─┐
                  └───┘    │     └───┘     └───┘        EB        └───┘ │
                           ▼                                            │
                         ┌─┴─┐                                          │
                         │ Q │                                   params ▼
                         └─┬─┘                                          │
                     y_hat ▼                  ┌─────┐                   │
                           ├──────────►───────┤  CP ├────────►──────────┤
                           │                  └─────┘                   │
                           ▼                                            ▼
                           │                                            │
                           ·                  ┌─────┐                   │
                        GC : ◄────────◄───────┤  EP ├────────◄──────────┘
                           ·     scales_hat   └─────┘
                           │      means_hat
                     y_hat ▼

                  ┌───┐    │
        x_hat ──◄─┤g_s├────┘
                  └───┘

        EB = Entropy bottleneck
        GC = Gaussian conditional
        EP = Entropy parameters network
        CP = Context prediction (masked convolution)


    Args:
        N (int): Number of main network channels
        M (int): Number of latent space channels
        groups (list[int]): Number of channels in each channel group
    """

    def __init__(self, N=192, M=320, groups=None, **kwargs):
        super().__init__(**kwargs)

        if groups is None:
            groups = [16, 16, 32, 64, M - 128]

        self.groups = list(groups)
        assert sum(self.groups) == M

        self.g_a = nn.Sequential(
            conv(3, N, kernel_size=5, stride=2),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            conv(N, N, kernel_size=5, stride=2),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            AttentionBlock(N),
            conv(N, N, kernel_size=5, stride=2),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            conv(N, M, kernel_size=5, stride=2),
            AttentionBlock(M),
        )

        self.g_s = nn.Sequential(
            AttentionBlock(M),
            deconv(M, N, kernel_size=5, stride=2),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            deconv(N, N, kernel_size=5, stride=2),
            AttentionBlock(N),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            deconv(N, N, kernel_size=5, stride=2),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            ResidualBottleneckBlock(N, N),
            deconv(N, 3, kernel_size=5, stride=2),
        )

        h_a = nn.Sequential(
            conv(M, N, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            conv(N, N, kernel_size=5, stride=2),
            nn.ReLU(inplace=True),
            conv(N, N, kernel_size=5, stride=2),
        )

        h_s = nn.Sequential(
            deconv(N, N, kernel_size=5, stride=2),
            nn.ReLU(inplace=True),
            deconv(N, N * 3 // 2, kernel_size=5, stride=2),
            nn.ReLU(inplace=True),
            deconv(N * 3 // 2, N * 2, kernel_size=3, stride=1),
        )

        # In [He2022], this is labeled "g_ch^(k)".
        channel_context = {
            f"y{k}": nn.Sequential(
                conv(sum(self.groups[:k]), 224, kernel_size=5, stride=1),
                nn.ReLU(inplace=True),
                conv(224, 128, kernel_size=5, stride=1),
                nn.ReLU(inplace=True),
                conv(128, self.groups[k] * 2, kernel_size=5, stride=1),
            )
            for k in range(1, len(self.groups))
        }

        # In [He2022], this is labeled "g_sp^(k)".
        spatial_context = [
            CheckerboardMaskedConv2d(
                self.groups[k],
                self.groups[k] * 2,
                kernel_size=5,
                stride=1,
                padding=2,
            )
            for k in range(len(self.groups))
        ]

        # In [He2022], this is labeled "Param Aggregation".
        param_aggregation = [
            sequential_channel_ramp(
                # Input: spatial context, channel context, and hyper params.
                self.groups[k] * 2 + (k > 0) * self.groups[k] * 2 + N * 2,
                self.groups[k] * 2,
                min_ch=N * 2,
                num_layers=3,
                interp="linear",
                make_layer=nn.Conv2d,
                make_act=lambda: nn.ReLU(inplace=True),
                kernel_size=1,
                stride=1,
                padding=0,
            )
            for k in range(len(self.groups))
        ]

        # In [He2022], this is labeled the space-channel context model (SCCTX).
        # The side params and channel context params are computed externally.
        scctx_latent_codec = {
            f"y{k}": CheckerboardLatentCodec(
                latent_codec={
                    "y": GaussianConditionalLatentCodec(quantizer="ste"),
                },
                context_prediction=spatial_context[k],
                entropy_parameters=param_aggregation[k],
            )
            for k in range(len(self.groups))
        }

        # [He2022] uses a "hyperprior" architecture, which reconstructs y using z.
        self.latent_codec = HyperpriorLatentCodec(
            latent_codec={
                # Channel groups with space-channel context model (SCCTX):
                "y": ChannelGroupsLatentCodec(
                    groups=self.groups,
                    channel_context=channel_context,
                    latent_codec=scctx_latent_codec,
                ),
                # Side information branch containing z:
                "hyper": HyperLatentCodec(
                    entropy_bottleneck=EntropyBottleneck(N),
                    h_a=h_a,
                    h_s=h_s,
                    quantizer="ste",
                ),
            },
        )

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict["g_a.0.weight"].size(0)
        net = cls(N)
        net.load_state_dict(state_dict)
        return net