Trainable Parallel Decoding
Last updated: 2025-11-10
Trainable Parallel Decoding accelerates Diffusion Large Language Models (DLLMs) by fine-tuning them to decode multiple tokens simultaneously, reducing inference latency while maintaining generation quality.
Overview
Traditional DLLMs suffer from high inference latency due to their iterative, multi-step sampling process. Trainable Parallel Decoding addresses this limitation by introducing a second-stage fine-tuning paradigm that teaches the model to predict multiple future tokens in a single forward pass. This approach transforms the sequential generation process into a more parallelizable one, significantly reducing the number of required sampling steps.
The framework currently supports two complementary techniques:
Path Distillation (Trajectory Compression): Learning to jump between non-consecutive states in optimal generation trajectories
DPARALLEL: Entropy-based loss regularization to accelerate parallel decoding learning
Path Distillation (Trajectory Compression)
Path Distillation is motivated by the key observation from Song et al., 2025 that training on high-quality generation paths can significantly improve model efficiency. The method consists of two main stages:
High-Quality Trajectory Distillation
The first stage involves creating a dataset of “golden” trajectories through the following process:
Trajectory Generation: Use a pre-trained DLLM to sample generation paths on a domain-specific dataset (e.g., 200,000 math problems)
Quality Filtering: Apply an external verifier to filter trajectories that produce correct outputs
Dataset Construction: Retain only high-quality trajectories that pass verification
Mathematically, given a trajectory \(\tau = (s_N, s_{N-1}, \dots, s_0)\) representing states from fully masked to final output, we retain
\[
\mathcal{D}_{\text{golden}} = \{\, \tau \mid V(s_0) = 1 \,\},
\]
where \(V(\cdot)\) is the external verifier function.
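The filtering stage above can be sketched in a few lines. Everything here (the `verify` check, `build_golden_dataset`, and the toy sampler) is a hypothetical stand-in for the real DLLM sampler and external verifier, not the repo's actual code:

```python
def verify(final_state: str, reference: str) -> bool:
    # External verifier V(.): here a toy exact-match check on the final output.
    return final_state.strip() == reference.strip()

def build_golden_dataset(problems, sample_trajectory):
    # Keep only trajectories whose final state s_0 passes verification.
    golden = []
    for prompt, reference in problems:
        trajectory = sample_trajectory(prompt)  # (s_N, ..., s_0)
        if verify(trajectory[-1], reference):   # retain iff V(s_0) = 1
            golden.append(trajectory)
    return golden

# Toy sampler that "denoises" the masked answer in a single step.
answers = {"2+2=": "4", "3+5=": "8"}
toy_sampler = lambda p: (p + "[MASK]", p + answers[p])

# The second reference is deliberately wrong, so its trajectory is filtered out.
problems = [("2+2=", "2+2=4"), ("3+5=", "3+5=9")]
dataset = build_golden_dataset(problems, toy_sampler)
print(len(dataset))  # 1
```

In the real pipeline the sampler would run the pre-trained DLLM over the domain dataset and the verifier would be a task-specific checker (e.g., a math answer grader).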
Compressed Transition Learning
The second stage fine-tunes the model to predict multi-step transitions instead of single-step ones:
Training Instance Construction: For each trajectory, randomly sample timesteps \(i\) and \(j\) with \(N \ge i > j \ge 0\)
Target Identification: The model learns to predict tokens that are [MASK] in \(s_i\) but revealed in \(s_j\)
Loss Optimization: Minimize the negative log-likelihood of compressed transitions
The fine-tuning objective is
\[
\mathcal{L}(\theta) = -\,\mathbb{E}_{\tau,\; i > j}\Big[ \sum_{k \in \Delta_{i \to j}} \log p_\theta\big((s_j)_k \mid s_i\big) \Big],
\]
where \(\Delta_{i \to j} = M_i \setminus M_j\) represents the indices of tokens to be predicted, with \(M_i\) and \(M_j\) the sets of masked positions in \(s_i\) and \(s_j\).
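One way to realize the target-identification step, assuming trajectories are stored as lists of token strings with `"[MASK]"` placeholders (a sketch under those assumptions, not the repo's data structures):

```python
import random

MASK = "[MASK]"

def compressed_transition(trajectory):
    """Sample N >= i > j >= 0 and return (s_i, s_j, Delta_{i->j})."""
    N = len(trajectory) - 1          # trajectory = (s_N, ..., s_0)
    i = random.randint(1, N)
    j = random.randint(0, i - 1)
    s_i = trajectory[N - i]          # earlier state: more masks
    s_j = trajectory[N - j]          # later state: fewer masks
    m_i = {k for k, tok in enumerate(s_i) if tok == MASK}
    m_j = {k for k, tok in enumerate(s_j) if tok == MASK}
    # Delta_{i->j}: positions masked in s_i but already revealed in s_j.
    return s_i, s_j, sorted(m_i - m_j)

# Toy 3-state trajectory over a 2-token answer.
traj = [[MASK, MASK], ["7", MASK], ["7", "3"]]
s_i, s_j, delta = compressed_transition(traj)
# Every sampled index is [MASK] in s_i and revealed in s_j.
assert all(s_i[k] == MASK and s_j[k] != MASK for k in delta)
```

During fine-tuning, the model is conditioned on \(s_i\) and supervised on exactly the positions in `delta`, which is what lets it learn multi-step jumps instead of single-step transitions.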
Implementation Details
The data preparation process involves:
Offline Dataset Creation: Generate and filter trajectories offline
Data Format: Prepare input_ids, noisy_input_ids, and labels for training
Training Configuration: Use standard SFT training with the compressed transition objective
The training data format should include:
input_ids: The starting state \(s_i\) with appropriate masking
noisy_input_ids: The noised version of \(s_i\)
labels: The target tokens to predict (tokens in \(s_j\) that differ from \(s_i\))
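A minimal sketch of how one compressed transition might be packed into these fields. The token ids, the `mask_id`, and the conventional `-100` ignore index are all illustrative assumptions, not the repo's exact format:

```python
IGNORE_INDEX = -100  # conventional "no loss at this position" label in SFT pipelines

def build_training_instance(s_i_ids, s_j_ids, mask_id):
    # In this simplified sketch, s_i already carries its [MASK] noise, so
    # input_ids and noisy_input_ids coincide; a real pipeline may re-noise.
    labels = [
        tgt if src == mask_id and tgt != mask_id else IGNORE_INDEX
        for src, tgt in zip(s_i_ids, s_j_ids)
    ]
    return {
        "input_ids": list(s_i_ids),
        "noisy_input_ids": list(s_i_ids),
        "labels": labels,  # supervise only positions revealed between s_i and s_j
    }

# Toy ids with mask_id = 0: positions 1 and 3 are revealed in s_j.
inst = build_training_instance([5, 0, 7, 0], [5, 9, 7, 4], mask_id=0)
print(inst["labels"])  # [-100, 9, -100, 4]
```

With labels built this way, a standard SFT cross-entropy loss with `ignore_index=-100` optimizes exactly the compressed-transition objective above.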
DPARALLEL: Learnable Parallel Decoding
Chen et al., 2025 introduce dParallel, a novel approach that incorporates an entropy-based regularization term into the training loss to encourage parallel decoding capabilities.
Methodology
The key insight is that by adding a confidence-based loss term during supervised fine-tuning, we can guide the model toward making confident, parallel predictions. This is achieved through:
Entropy Regularization: Add a loss term based on the entropy of the model’s predictions
Confidence Scoring: Use prediction confidence as a signal for parallel decoding quality
Loss Balancing: Combine the standard cross-entropy loss with the confidence-based term
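The combined loss can be illustrated with plain-Python probability rows. This is a sketch of the idea (mean cross-entropy plus a `confidence_beta`-weighted mean entropy term), not dParallel's exact formulation:

```python
import math

def token_entropy(probs):
    # Shannon entropy of one predictive distribution (in nats).
    return -sum(p * math.log(p) for p in probs if p > 0)

def entropy_regularized_loss(prob_rows, target_ids, confidence_beta=2.0):
    """Cross-entropy plus an entropy penalty that pushes the model toward
    confident, parallel-decodable predictions."""
    n = len(prob_rows)
    ce = -sum(math.log(row[t]) for row, t in zip(prob_rows, target_ids)) / n
    ent = sum(token_entropy(row) for row in prob_rows) / n
    return ce + confidence_beta * ent

# Confident predictions incur a lower total loss than uncertain ones.
confident = [[0.9, 0.05, 0.05], [0.05, 0.9, 0.05]]
uncertain = [[0.4, 0.3, 0.3], [0.3, 0.4, 0.3]]
targets = [0, 1]
print(entropy_regularized_loss(confident, targets)
      < entropy_regularized_loss(uncertain, targets))  # True
```

Setting `confidence_beta = 0` recovers plain cross-entropy; larger values penalize high-entropy (hedging) predictions more strongly, which is the signal that enables decoding more tokens per step.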
Configuration
To enable DPARALLEL, use the following training configuration:
sh train.sh tasks/train_llada2_bd_with_dparallel.py configs/sft/llada2_mini_bd_sft.yaml --train.confidence_beta {confidence_beta}
Where:
confidence_beta controls the strength of the entropy regularization (recommended value: 2.0). Higher values encourage more aggressive parallel decoding; the parameter trades off decoding speed-up against generation quality.
Training Process
The DPARALLEL training process:
Standard SFT Setup: Begin with standard supervised fine-tuning
Loss Modification: Add the confidence-based regularization term
Hyperparameter Tuning: Adjust confidence_beta based on the desired speed-quality trade-off
Evaluation: Monitor both generation quality and inference speed metrics