Hybrid-aware speculative decoding with O(1) GDN state restore GDN 状態を O(1) 復元する hybrid 対応 speculative decoding
A speculative-decoding implementation that correctly handles Qwen 3.5's 24 GatedDeltaNet + 8 Full-Attention hybrid. We have not seen this approach in another open inference engine we tested. Qwen 3.5 の GDN 24 層 + FA 8 層 hybrid 構成を、黙って誤った出力に落とさずに正しく扱える speculative-decoding 実装。私たちが検証した範囲では、このアプローチを採用している open な推論エンジンを他に見ていません。
The problem 問題
Standard speculative decoding (Leviathan 2023, Medusa) targets pure transformers — on reject, you truncate the KV cache by N entries and continue. GDN layers carry a recurrent state and convolutional buffer that have already advanced through the entire draft window. Truncating only KV leaves GDN state corrupted, and the model silently produces divergent output. This is the implementation difficulty that, in the engines we checked, has kept speculative decoding off by default for Qwen 3.5 hybrid. 通常の speculative decoding (Leviathan 2023、Medusa) は pure transformer が前提 — reject 時に KV cache を N entry 切り詰めるだけで続行できる。しかし GDN 層は recurrent state と convolutional buffer が draft window 全体を進んでしまっている。KV だけ truncate しても GDN state は壊れたまま残り、モデルは clash せず静かに誤った出力を続ける。私たちが確認した engine では、この実装上の難しさゆえに Qwen 3.5 hybrid 用の speculative decoding はデフォルト無効となっています。
Our approach 私たちの方法
Before each verify call, snapshot every GDN layer's (recurrent_state, conv_buf) pair into a pre-allocated tensor pool. On rejection, restore from snapshot in O(1) — zero allocations in the hot path. GDN state is ~tens of KB per layer; 24 layers' snapshot completes well under 1 ms per verify. No additional MLX graph nodes, no memory fragmentation.
各 verify 呼び出しの前に、全 GDN 層の (recurrent_state, conv_buf) ペアを 事前確保した tensor pool に snapshot。reject 時は snapshot から O(1) で復元 — hot path で alloc は発生しない。GDN state は layer あたり数十 KB、24 層合計の snapshot でも verify あたり 1 ms 未満。MLX graph に追加ノードは不要、メモリ断片化なし。