Transformer Cheat Sheet

Easier transformer diagrams

Author

Enzo Shiraishi

Published

January 18, 2025

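\[ \text{Enc}(pos, i) = \frac{pos}{\theta^{2i/d}} \]

\[ \text{PE}(pos, j) = \begin{cases} \sin \text{Enc}\left(pos, \lfloor j/2 \rfloor\right) & \text{if } j \text{ is even}, \\ \cos \text{Enc}\left(pos, \lfloor j/2 \rfloor\right) & \text{if } j \text{ is odd}. \end{cases} \]

Here \(pos\) is the token position, \(j\) is the dimension index \((0 \le j < d)\), \(d\) is the model dimension, and \(\theta\) is a base constant (10000 in the original Transformer paper).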
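
The encoding can be built as a table with one row per position and one column per dimension. This is a minimal pure-Python sketch (no NumPy, so it stays self-contained); the function name `positional_encoding` and its parameters are illustrative, and `theta` defaults to 10000 as in the original paper.

```python
import math


def positional_encoding(num_positions: int, d: int, theta: float = 10000.0) -> list[list[float]]:
    """Return a num_positions x d table of sinusoidal positional encodings.

    Dimension j belongs to pair i = j // 2; even j uses sin, odd j uses cos,
    so each (sin, cos) pair shares the frequency 1 / theta^(2i/d).
    """
    table = []
    for pos in range(num_positions):
        row = []
        for j in range(d):
            i = j // 2
            angle = pos / theta ** (2 * i / d)  # Enc(pos, i)
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


pe = positional_encoding(num_positions=50, d=16)
print(len(pe), len(pe[0]))  # one row per position, one column per model dimension
```

Low dimensions oscillate quickly with position and high dimensions slowly, which is what lets nearby and distant positions both remain distinguishable.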

\[ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]
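
A minimal pure-Python sketch of the formula, using lists of rows instead of a tensor library so it runs without dependencies; `softmax` and `attention` are illustrative names, not a library API.

```python
import math


def softmax(row: list[float]) -> list[float]:
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q: list[list[float]], K: list[list[float]], V: list[list[float]]) -> list[list[float]]:
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Q is n_q x d_k, K is n_k x d_k, V is n_k x d_v.
    Returns n_q x d_v: one weighted average of V's rows per query.
    """
    d_k = len(K[0])
    out = []
    for q in Q:
        # One row of Q K^T, scaled by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out


Q = [[1.0, 0.0]]                 # one query
K = [[1.0, 0.0], [0.0, 1.0]]     # two keys
V = [[10.0, 0.0], [0.0, 10.0]]   # two values
print(attention(Q, K, V))        # leans toward the first value, whose key matches the query
```

Dividing by \(\sqrt{d_k}\) keeps the dot products from growing with the key dimension, which would otherwise push the softmax toward a near one-hot output.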

---
title: Transformer
---
flowchart LR
    input(Input Processing) --> encoder(Encoder) --> decoder(Decoder) --> output(Output Processing)
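
At inference time the encoder typically runs once per source sequence, while the decoder and output processing run once per generated token. A minimal sketch of that greedy loop, assuming hypothetical `encode` and `next_token` callables (not any library's API), hypothetical special-token ids, and a toy stand-in "model" that just reverses the source so the loop can run:

```python
from typing import Any, Callable, Sequence

BOS, EOS = 1, 2  # hypothetical special-token ids


def greedy_generate(
    src: Sequence[int],
    encode: Callable[[Sequence[int]], Any],
    next_token: Callable[[Any, list[int]], int],
    max_len: int = 20,
) -> list[int]:
    """Encode once, then append one argmax token at a time until EOS or max_len."""
    memory = encode(src)  # Input Processing + Encoder
    out = [BOS]
    for _ in range(max_len):
        tok = next_token(memory, out)  # Decoder + Output Processing on the prefix so far
        out.append(tok)
        if tok == EOS:
            break
    return out[1:]  # drop BOS; EOS is kept if it was produced


# Toy stand-ins (hypothetical): "encoding" reverses the source and the
# "decoder" copies memory one position per step, then emits EOS.
def toy_encode(src: Sequence[int]) -> list[int]:
    return list(reversed(src))


def toy_next_token(memory: list[int], prefix: list[int]) -> int:
    step = len(prefix) - 1
    return memory[step] if step < len(memory) else EOS


print(greedy_generate([5, 6, 7], toy_encode, toy_next_token))  # [7, 6, 5, 2]
```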

---
title: Input Processing
---
flowchart LR
    inputString(Input) --> oneHotEncoding(One-hot Encoding) --> embedding(Embedding) --> pose(Positional Encoding) --> modelInput(Output)
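
A one-hot vector has the vocabulary's length, so it is multiplied by a learned embedding matrix (vocabulary size x \(d\)) to reach the model dimension before the positional encoding is added; that product simply selects one row of the matrix. A minimal sketch with a hypothetical random embedding table standing in for learned weights (the original paper also scales embeddings by \(\sqrt{d}\), which is omitted here):

```python
import math
import random


def one_hot(token_id: int, vocab_size: int) -> list[float]:
    return [1.0 if j == token_id else 0.0 for j in range(vocab_size)]


def row_times_matrix(x: list[float], W: list[list[float]]) -> list[float]:
    """Row vector x (length n) times matrix W (n x d)."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]


def pe_row(pos: int, d: int, theta: float = 10000.0) -> list[float]:
    """Sinusoidal encoding of one position (sin on even dims, cos on odd dims)."""
    angles = [pos / theta ** (2 * (j // 2) / d) for j in range(d)]
    return [math.sin(a) if j % 2 == 0 else math.cos(a) for j, a in enumerate(angles)]


def process_input(token_ids: list[int], embedding: list[list[float]]) -> list[list[float]]:
    vocab_size, d = len(embedding), len(embedding[0])
    rows = []
    for pos, tok in enumerate(token_ids):
        # One-hot (length vocab_size) times embedding (vocab_size x d) selects row `tok`.
        emb = row_times_matrix(one_hot(tok, vocab_size), embedding)
        rows.append([e + p for e, p in zip(emb, pe_row(pos, d))])
    return rows


rng = random.Random(0)
E = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(5)]  # hypothetical 5-token vocab, d = 4
x = process_input([3, 1, 3], E)  # token 3 appears twice but gets two different vectors
```

In practice the one-hot multiplication is replaced by a direct row lookup, which gives the same result without building the one-hot vector.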

---
title: Transformer block
---
flowchart LR
    Value(Value)
    Query(Query)
    Key(Key)
    Sum1(Sum)
    Sum2(Sum)
    MHA(Multi-Head Attention)
    LayerNorm1(LayerNorm)
    LayerNorm2(LayerNorm)
    Linear1(Linear)
    Linear2(Linear)
    ReLU(ReLU)
    output(Output)
    Query --> MHA
    Key --> MHA
    Key ~~~ Sum1
    Value --> MHA
    Query --> Sum1
    MHA --> Sum1
    Sum1 --> LayerNorm1
    LayerNorm1 --> Linear1
    Linear1 --> ReLU
    ReLU --> Linear2
    LayerNorm1 --> Sum2
    Linear2 --> Sum2
    Sum2 --> LayerNorm2
    LayerNorm2 --> output
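
A minimal pure-Python sketch of this block in post-LN order (residual sum, then LayerNorm), with multi-head attention split into equal column slices per head. Assumptions: random weights stand in for trained ones, biases and dropout are omitted, LayerNorm has no learned gain or bias, and all function names are hypothetical. The residual is added to the query input, as in the original Transformer; in encoder self-attention the query, key and value are the same tensor, so the choice only matters for cross-attention.

```python
import math
import random

Matrix = list[list[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def add(A: Matrix, B: Matrix) -> Matrix:
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


def softmax(row: list[float]) -> list[float]:
    m = max(row)
    e = [math.exp(s - m) for s in row]
    total = sum(e)
    return [v / total for v in e]


def attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    scale = math.sqrt(len(K[0]))
    scores = [[sum(q * k for q, k in zip(qr, kr)) / scale for kr in K] for qr in Q]
    return matmul([softmax(r) for r in scores], V)


def layer_norm(X: Matrix, eps: float = 1e-5) -> Matrix:
    # Per-row zero mean / unit variance; learned gain and bias are omitted.
    out = []
    for row in X:
        mu = sum(row) / len(row)
        var = sum((v - mu) ** 2 for v in row) / len(row)
        out.append([(v - mu) / math.sqrt(var + eps) for v in row])
    return out


def multi_head_attention(query: Matrix, key: Matrix, value: Matrix, p: dict, n_heads: int) -> Matrix:
    Q, K, V = matmul(query, p["Wq"]), matmul(key, p["Wk"]), matmul(value, p["Wv"])
    dh = len(Q[0]) // n_heads
    heads = []
    for h in range(n_heads):
        cols = slice(h * dh, (h + 1) * dh)  # this head's slice of the projected features
        heads.append(attention([r[cols] for r in Q], [r[cols] for r in K], [r[cols] for r in V]))
    concat = [sum((head[i] for head in heads), []) for i in range(len(Q))]
    return matmul(concat, p["Wo"])


def transformer_block(query: Matrix, key: Matrix, value: Matrix, p: dict, n_heads: int = 2) -> Matrix:
    # Multi-Head Attention -> Sum (residual from the query) -> LayerNorm.
    h = layer_norm(add(query, multi_head_attention(query, key, value, p, n_heads)))
    # Linear -> ReLU -> Linear -> Sum (residual) -> LayerNorm.
    ff = matmul([[max(0.0, v) for v in row] for row in matmul(h, p["W1"])], p["W2"])
    return layer_norm(add(h, ff))


def random_params(d: int, d_ff: int, seed: int = 0) -> dict:
    """Hypothetical random weights standing in for trained ones (biases omitted)."""
    rng = random.Random(seed)

    def mat(rows: int, cols: int) -> Matrix:
        return [[rng.gauss(0.0, 1.0 / math.sqrt(rows)) for _ in range(cols)] for _ in range(rows)]

    return {"Wq": mat(d, d), "Wk": mat(d, d), "Wv": mat(d, d), "Wo": mat(d, d),
            "W1": mat(d, d_ff), "W2": mat(d_ff, d)}


x = [[0.1, -0.2, 0.3, 0.0], [0.5, 0.4, -0.1, 0.2], [-0.3, 0.1, 0.2, -0.4]]
y = transformer_block(x, x, x, random_params(d=4, d_ff=8))  # encoder-style: Q = K = V = x
```

The output keeps the query's shape (here 3 x 4), which is what allows blocks to be stacked.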

---
title: Encoder
---
flowchart LR
    input(Input)
    EncoderBlock(Transformer Block #1)
    EncoderBlock2(Transformer Block #2)
    EncoderBlockN(Transformer Block #M)
    outputN(Output)
    input --> EncoderBlock
    EncoderBlock --> EncoderBlock2
    EncoderBlock2 -- ... --> EncoderBlockN
    EncoderBlockN --> outputN
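
Stacking is repeated application: each block's output becomes the next block's query, key and value. A sketch with a hypothetical `toy_block` standing in for a real Transformer block; its arithmetic is made up purely so the example runs.

```python
from typing import Callable

Matrix = list[list[float]]
# A block maps (query, key, value) -> output with the same shape as query.
Block = Callable[[Matrix, Matrix, Matrix], Matrix]


def encoder(x: Matrix, blocks: list[Block]) -> Matrix:
    h = x
    for block in blocks:
        # Self-attention: every encoder block sees the previous output as Q, K and V.
        h = block(h, h, h)
    return h


def toy_block(query: Matrix, key: Matrix, value: Matrix) -> Matrix:
    """Hypothetical stand-in for a real block: adds the column means of `value` to `query`."""
    n = len(value)
    means = [sum(col) / n for col in zip(*value)]
    return [[q + m for q, m in zip(row, means)] for row in query]


x = [[1.0, 2.0], [3.0, 4.0]]
out = encoder(x, [toy_block] * 3)  # M = 3 stacked blocks (a real stack has separate weights per block)
```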

---
title: Decoder block
---
flowchart LR
    encoderOutput(Encoder Output)
    Query(Query)
    Key(Key)
    Value(Value)
    MHA(Masked Multi-Head Attention)
    Sum(Sum)
    LayerNorm(LayerNorm)
    DecoderBlock(Transformer Block)
    decoderBlockOutput(Output)
    Query --> MHA
    Key --> MHA
    Value --> MHA
    encoderOutput ~~~ MHA
    Query --> Sum
    MHA --> Sum
    Sum --> LayerNorm
    LayerNorm -- Query --> DecoderBlock
    encoderOutput -- Key, Value --> DecoderBlock
    DecoderBlock --> decoderBlockOutput
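
In the original Transformer's cross-attention, queries come from the decoder stream while keys and values come from the encoder output, so the block produces one row per decoder position regardless of source length. A single-head, projection-free sketch that omits residuals, LayerNorm and the feed-forward part; all names are hypothetical.

```python
import math

Matrix = list[list[float]]


def softmax(row: list[float]) -> list[float]:
    m = max(row)
    e = [math.exp(s - m) for s in row]  # exp(-inf) == 0.0, so masked keys get zero weight
    total = sum(e)
    return [v / total for v in e]


def attention(Q: Matrix, K: Matrix, V: Matrix, allowed=None) -> Matrix:
    """Single-head attention; allowed[i][j] == False hides key j from query i."""
    scale = math.sqrt(len(K[0]))
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in K]
        if allowed is not None:
            scores = [s if allowed[i][j] else -math.inf for j, s in enumerate(scores)]
        w = softmax(scores)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) for c in range(len(V[0]))])
    return out


def decoder_block(y: Matrix, enc: Matrix) -> Matrix:
    n = len(y)
    causal = [[j <= i for j in range(n)] for i in range(n)]
    # 1) Masked self-attention: decoder position i only sees positions 0..i.
    h = attention(y, y, y, causal)
    # 2) Cross-attention: queries from the decoder, keys and values from the encoder.
    return attention(h, enc, enc)


enc = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0], [0.2, 0.8]]  # 5 source positions
y = [[0.3, 0.1], [0.9, -0.2], [0.0, 0.4]]                            # 3 target positions
out = decoder_block(y, enc)  # one row per decoder position
```

Because of the mask, changing later target tokens leaves the output at earlier positions unchanged.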

---
title: Decoder
---
flowchart LR
    encoderOutput(Encoder Output)
    input(Input)
    attentionMask(Causal Attention Mask)
    DecoderBlock1(Decoder Block #1)
    DecoderBlock2(Decoder Block #2)
    DecoderBlockN(Decoder Block #N)
    output(Output)

    attentionMask ~~~ encoderOutput
    input ~~~ encoderOutput
    input -- Query, Key, Value --> DecoderBlock1
    attentionMask --> DecoderBlock1
    encoderOutput --> DecoderBlock1

    DecoderBlock1 -- Query, Key, Value --> DecoderBlock2
    attentionMask --> DecoderBlock2
    encoderOutput --> DecoderBlock2

    DecoderBlock2 -- ... --> DecoderBlockN
    attentionMask --> DecoderBlockN
    encoderOutput --> DecoderBlockN
    DecoderBlockN --> output
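
In the standard Transformer the causal mask acts on the attention scores inside each decoder block's self-attention: scores for future positions are set to -inf before the softmax, so those positions receive exactly zero weight. A minimal sketch of an additive mask and of a stack of N decoder blocks that all receive the same encoder output; the function names are hypothetical.

```python
import math
from typing import Callable

Matrix = list[list[float]]


def causal_mask(n: int) -> Matrix:
    """Additive mask: 0 where query i may see key j (j <= i), -inf for future keys."""
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]


def masked_softmax(scores: Matrix, mask: Matrix) -> Matrix:
    out = []
    for s_row, m_row in zip(scores, mask):
        z = [s + m for s, m in zip(s_row, m_row)]  # future positions become -inf
        top = max(z)
        e = [math.exp(v - top) for v in z]  # exp(-inf) == 0.0
        total = sum(e)
        out.append([v / total for v in e])
    return out


def decoder(y: Matrix, enc: Matrix, blocks: list[Callable[[Matrix, Matrix], Matrix]]) -> Matrix:
    """Run N decoder blocks in sequence; every block receives the same encoder output."""
    h = y
    for block in blocks:
        h = block(h, enc)
    return h


weights = masked_softmax([[1.0, 2.0, 3.0]] * 3, causal_mask(3))
print(weights[0])  # the first position can only attend to itself
```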

---
title: Output Processing
---
flowchart-elk LR
    DecoderBlockN(Input)
    linear(Linear)
    Softmax(Softmax)
    Argmax(Argmax)
    outputString(Output)
    DecoderBlockN --> linear
    linear --> Softmax
    Softmax --> Argmax
    Argmax --> outputString
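
A minimal sketch of these three steps for a single decoder output vector, assuming a hypothetical three-token vocabulary and hand-picked weights:

```python
import math


def output_processing(h_last: list[float], W: list[list[float]], b: list[float], vocab: list[str]):
    """Linear -> Softmax -> Argmax for one decoder output vector (greedy choice)."""
    # Linear: logits[t] = sum_k h_last[k] * W[k][t] + b[t], with W of shape d x len(vocab).
    logits = [sum(x * w for x, w in zip(h_last, col)) + bt for col, bt in zip(zip(*W), b)]
    # Softmax, shifted by the max for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Argmax: index of the most probable token.
    best = max(range(len(probs)), key=probs.__getitem__)
    return vocab[best], probs


vocab = ["<eos>", "hello", "world"]  # hypothetical 3-token vocabulary
W = [[0.2, 1.0, -0.5],
     [0.1, 0.3, 2.0]]                # d = 2 rows, one column per token
token, probs = output_processing([1.0, 0.5], W, [0.0, 0.0, 0.0], vocab)
print(token)  # hello
```

Because softmax is monotonic, the argmax of the probabilities equals the argmax of the logits; the softmax step matters when probabilities are needed, for example for sampling or for computing a training loss.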