Mixture-of-Channels: Exploiting Sparse FFNs for Efficient LLMs Pre-Training and Inference

stato della ricerca deep learning

Perché “Mixture-of-Channels”

Il paper parte da un’osservazione concreta: con FlashAttention, il vero collo di bottiglia di memoria non è più l’attention ma le attivazioni della FFN, soprattutto in pre-training con batch grandi e sequenze lunghe. Mixture-of-Channels (MoC) modifica la FFN “classica” usando il gating nativo di SwiGLU per attivare solo i canali Top-K per token, riducendo drasticamente le attivazioni da memorizzare e l’accesso a memoria in inferenza. In sintesi, MoC porta a sostanziali risparmi di memoria e velocità di decoding più alte, mantenendo prestazioni competitive su benchmark standard e su diverse architetture di LLM.

Paper

Indice

Dentro il “mixture”: tecniche e scelte

MoC osserva che molte uscite della SwiGLU sono prossime a zero, quindi selezionare per ogni token solo i canali più “attivi” (Top-K) evita di memorizzare attivazioni inutili per il backward e di accedere a pesi inutili in decoding. In addestramento, si memorizzano solo le parti mascherate e si ricostruisce il resto quando serve; in inferenza, si caricano da HBM a SRAM solo righe/colonne dei pesi relative ai canali attivi. Implementazioni hardware-aware (kernel RAFT per Top-K batched e Triton fusi) e l’opzione di structured sparsity 2:8 su GPU Ampere/Hopper concretizzano i guadagni teorici in speedup reali.

Cosa cambia nei numeri

Gli autori profilano LLaMA-like con FlashAttention e mostrano che la FFN domina la memoria attivazioni, motivando il focus sul suo redesign. MoC riduce l’attivazione FFN da un ordine “dense” a uno “sparse” con Top-K per token, ottenendo riduzioni significative della memoria totale di training su modelli 60M-1B e migliorando l’end-to-end decoding throughput di circa 1.13x, con accelerazioni del singolo layer FFN di circa 1.38x grazie al minor I/O sui pesi. Su zero-shot (MMLU, ARC, PIQA, TruthfulQA) le performance restano vicine o competitive con LLaMA baseline a parità di scala, confermando che l’efficienza non compromette la qualità in modo sostanziale.

Leggere il paper senza inciampare

  • Attivazioni vs pesi/stati ottimizzatore: qui il problema chiave è la memoria delle attivazioni, non solo la riduzione dei pesi o degli stati di Adam, ed è per questo che MoC mira direttamente alla FFN.
  • Sparsità per-token nella FFN: a differenza delle MoE che selezionano “esperti”, MoC seleziona canali interni della stessa FFN per ogni token, sfruttando il gating di SwiGLU come guida naturale alla rilevanza.
  • Orthogonalità a FlashAttention e gradient checkpointing: MoC si combina con FlashAttention e con gradient checkpointing, sommando benefici senza richiedere cambiamenti all’attention o alla pipeline di training.

Scomporre i risultati, con contesto

  • Profiling: con FlashAttention, la quota di memoria della FFN supera nettamente quella dell’attention, rendendo la FFN il target più efficace per interventi di efficienza.
  • Pre-training: su C4, MoC e MoC 2:8 abbassano significativamente la memoria (inclusi pesi, gradienti, stati dell’ottimizzatore e attivazioni), mantenendo perplexity in linea con strong baselines su diverse scale di modello.
  • Inference: il carico MAC e, soprattutto, gli accessi a memoria si riducono perché si moltiplicano e si accumulano solo K canali invece di tutta la dimensione FFN, accelerando la latenza di decoding e l’end-to-end throughput.

“Dentro MoC”: dettagli di approccio

  • Top-K guidato da SwiGLU: il gating nativo ordina i canali per importanza per token, e la maschera binaria attiva solo i canali che contano, senza euristiche post-hoc o threshold dinamici.
  • Kernel e strutture: kernel per Top-K batched e FFN fusi in Triton minimizzano overhead, mentre la 2:8 structured sparsity sfrutta il supporto hardware delle GPU moderne per sbloccare speedup aggiuntivi.
  • Compatibilità ampia: MoC si integra con LLaMA, GQA, Qwen3 e anche con MoE come Mixtral sostituendo la MLP dell’esperto con la variante MoC, mantenendo qualità e riducendo memoria.

Quiz lampo su “Mixture-of-Channels”

  • In una pipeline con FlashAttention, qual è il principale collo di bottiglia di memoria che MoC prova a ridurre? Risposta: la memoria delle attivazioni della FFN nei Transformer LLM.
  • Qual è l’idea chiave di MoC in una riga? Risposta: attivare solo i canali Top-K per token usando il gating di SwiGLU, memorizzando/accendendo solo ciò che serve.
  • Perché MoC accelera l’inferenza? Risposta: riduce gli accessi a memoria caricando solo pesi dei canali attivi e diminuendo i MAC effettivi nella FFN.
  • MoC sostituisce FlashAttention? Risposta: no, è ortogonale e si somma a FlashAttention perché agisce sulla FFN, non sull’attention.
  • Quale vantaggio offre la structured sparsity 2:8? Risposta: allinea la sparsità di MoC con primitive hardware delle GPU per trasformare la sparsità in speedup pratici.

Studi vicini: chi fa cosa

  • Grouped-Query Attention (GQA): riduce il costo dell’attention condividendo key/value tra gruppi di query heads, complementare a MoC che agisce nella FFN; utile per KV-cache e latenza in decoding.
  • GPTQ: quantizzazione post-training “weight-only” a 3-4 bit ad alta fedeltà, utile in deployment per comprimere i pesi senza riaddestramento esteso, complementare alla riduzione attivazioni di MoC.
  • AWQ: quantizzazione “activation-aware” per LLM che preserva i pesi più critici, dimostrata efficace e hardware-friendly, anch’essa complementare all’approccio di sparsità di MoC.

“Perché è interessante” – takeaway

MoC mostra che la sparsità guidata dal gating interno di SwiGLU è un principio semplice ma potente per colpire il vero collo di bottiglia di memoria in era FlashAttention, con benefici concreti sia in training sia in inference. La proposta è pratica, si integra con tecniche di sistema e hardware moderne, ed è compatibile con altre linee di efficienza come quantization e attention ottimizzata, rendendola rilevante per la progettazione di LLM efficienti nel 2025.

Torna in alto