feat: 增加transformer 基础

This commit is contained in:
ViperEkura 2026-04-11 14:40:11 +08:00
parent 09dc6688b0
commit 9bf4437226
1 changed files with 795 additions and 0 deletions

View File

@@ -0,0 +1,795 @@
"""
Transformer 架构可视化
使用Manim绘制Transformer的详细结构图
"""
from manim import *
import numpy as np
class TransformerArchitecture(Scene):
    """Step-by-step visualization of the complete Transformer architecture."""

    def construct(self):
        # Play the sections in order: title card, the big picture, then a
        # zoom into each component, finishing with the attention math.
        self.show_title()
        self.show_overall_architecture()
        self.show_encoder_details()
        self.show_decoder_details()
        self.show_attention_mechanism()
        self.show_attention_formula()
def show_title(self):
"""显示标题"""
title = Text(
"Transformer 架构详解",
font_size=48,
font="SimHei",
color=WHITE
)
title.to_edge(UP, buff=0.5)
subtitle = Text(
"现代大语言模型的核心基础",
font_size=24,
font="SimHei",
color=GRAY
)
subtitle.next_to(title, DOWN, buff=0.3)
self.play(Write(title), run_time=1.0)
self.play(Write(subtitle), run_time=0.8)
self.wait(1.0)
self.title = title
self.subtitle = subtitle
def show_overall_architecture(self):
"""显示整体架构图"""
# 清除标题
self.play(
FadeOut(self.title),
FadeOut(self.subtitle),
run_time=0.5
)
# 整体架构标题
arch_title = Text(
"Transformer 整体架构",
font_size=36,
font="SimHei",
color=BLUE
)
arch_title.to_edge(UP, buff=0.5)
self.play(Write(arch_title), run_time=0.8)
# 创建架构图的主要组件
# 输入部分
input_box = Rectangle(
width=3.0, height=1.0,
color=GREEN,
fill_opacity=0.3
)
input_box.to_edge(LEFT, buff=2.0).shift(UP * 0.5)
input_text = Text(
"输入序列",
font_size=20,
font="SimHei",
color=WHITE
)
input_text.move_to(input_box)
# 编码器堆叠
encoder_stack = VGroup()
for i in range(6):
encoder = Rectangle(
width=3.0, height=0.5,
color=BLUE_C,
fill_opacity=0.2
)
encoder.shift(RIGHT * 2.5 + UP * (1.5 - i * 0.55))
encoder_label = Text(
f"编码器层 {i+1}",
font_size=14,
font="SimHei"
)
encoder_label.move_to(encoder)
encoder_group = VGroup(encoder, encoder_label)
encoder_stack.add(encoder_group)
# 解码器堆叠
decoder_stack = VGroup()
for i in range(6):
decoder = Rectangle(
width=3.0, height=0.5,
color=ORANGE,
fill_opacity=0.2
)
decoder.shift(RIGHT * 6.5 + UP * (1.5 - i * 0.55))
decoder_label = Text(
f"解码器层 {i+1}",
font_size=14,
font="SimHei"
)
decoder_label.move_to(decoder)
decoder_group = VGroup(decoder, decoder_label)
decoder_stack.add(decoder_group)
# 输出部分
output_box = Rectangle(
width=3.0, height=1.0,
color=RED,
fill_opacity=0.3
)
output_box.to_edge(RIGHT, buff=2.0).shift(UP * 0.5)
output_text = Text(
"输出序列",
font_size=20,
font="SimHei",
color=WHITE
)
output_text.move_to(output_box)
# 箭头连接
arrow1 = Arrow(
input_box.get_right(),
encoder_stack.get_left(),
buff=0.2,
color=YELLOW
)
arrow2 = Arrow(
encoder_stack.get_right(),
decoder_stack.get_left(),
buff=0.2,
color=YELLOW
)
arrow3 = Arrow(
decoder_stack.get_right(),
output_box.get_left(),
buff=0.2,
color=YELLOW
)
# 显示所有组件
self.play(
Create(input_box),
Write(input_text),
run_time=0.8
)
self.wait(0.3)
# 显示编码器
for i, encoder in enumerate(encoder_stack):
self.play(
Create(encoder),
run_time=0.2
)
if i == 0:
self.play(Create(arrow1), run_time=0.5)
self.wait(0.3)
# 显示解码器
for i, decoder in enumerate(decoder_stack):
self.play(
Create(decoder),
run_time=0.2
)
if i == 0:
self.play(Create(arrow2), run_time=0.5)
self.wait(0.3)
# 显示输出
self.play(
Create(output_box),
Write(output_text),
Create(arrow3),
run_time=0.8
)
# 添加说明
explanation = Text(
"Transformer = 编码器(N层) + 解码器(N层)",
font_size=22,
font="SimHei",
color=YELLOW
)
explanation.to_edge(DOWN, buff=1.0)
self.play(Write(explanation), run_time=1.0)
self.wait(2.0)
# 保存引用
self.arch_title = arch_title
self.arch_components = [input_box, input_text, encoder_stack, decoder_stack,
output_box, output_text, arrow1, arrow2, arrow3, explanation]
def show_encoder_details(self):
"""显示编码器层细节"""
# 清除整体架构
self.play(
FadeOut(self.arch_title),
*[FadeOut(c) for c in self.arch_components],
run_time=0.8
)
# 编码器详细结构
encoder_title = Text(
"编码器层内部结构",
font_size=36,
font="SimHei",
color=BLUE
)
encoder_title.to_edge(UP, buff=0.5)
self.play(Write(encoder_title), run_time=0.8)
# 编码器层框
encoder_layer = Rectangle(
width=5.0, height=6.0,
color=BLUE,
fill_opacity=0.1
)
encoder_layer.center()
# 输入箭头
input_arrow = Arrow(
encoder_layer.get_top() + UP * 0.5,
encoder_layer.get_top(),
color=GREEN
)
input_label = Text(
"输入",
font_size=18,
font="SimHei",
color=GREEN
)
input_label.next_to(input_arrow, UP, buff=0.1)
# 层归一化1
ln1_box = Rectangle(
width=4.0, height=0.8,
color=PURPLE,
fill_opacity=0.3
)
ln1_box.move_to(encoder_layer.get_top() + DOWN * 1.0)
ln1_text = Text(
"层归一化 (LayerNorm)",
font_size=16,
font="SimHei",
color=WHITE
)
ln1_text.move_to(ln1_box)
# 多头注意力
mha_box = Rectangle(
width=4.0, height=1.2,
color=YELLOW,
fill_opacity=0.3
)
mha_box.move_to(ln1_box.get_bottom() + DOWN * 1.0)
mha_text = Text(
"多头自注意力\n(Multi-Head Attention)",
font_size=16,
font="SimHei",
color=WHITE
)
mha_text.move_to(mha_box)
# 残差连接1 - 从输入到多头注意力输出(右侧垂直箭头)
# 使用相同的x坐标确保箭头垂直
right_side_x = ln1_box.get_right()[0] + 1.5
residual1_start = np.array([right_side_x, ln1_box.get_top()[1] + 0.1, 0])
residual1_end = np.array([right_side_x, mha_box.get_bottom()[1] - 0.1, 0])
residual1 = Arrow(
residual1_start,
residual1_end,
color=RED,
buff=0.1
)
residual1_label = Text(
"残差连接",
font_size=14,
font="SimHei",
color=RED
)
residual1_label.next_to(residual1, RIGHT, buff=0.1)
# 层归一化2
ln2_box = Rectangle(
width=4.0, height=0.8,
color=PURPLE,
fill_opacity=0.3
)
ln2_box.move_to(mha_box.get_bottom() + DOWN * 1.5)
ln2_text = Text(
"层归一化 (LayerNorm)",
font_size=16,
font="SimHei",
color=WHITE
)
ln2_text.move_to(ln2_box)
# 前馈网络
ffn_box = Rectangle(
width=4.0, height=1.2,
color=GREEN,
fill_opacity=0.3
)
ffn_box.move_to(ln2_box.get_bottom() + DOWN * 1.0)
ffn_text = Text(
"前馈神经网络\n(Feed Forward Network)",
font_size=16,
font="SimHei",
color=WHITE
)
ffn_text.move_to(ffn_box)
# 残差连接2 - 从层归一化2输入到前馈网络输出右侧垂直箭头
# 使用相同的x坐标确保箭头垂直
right_side_x2 = ln2_box.get_right()[0] + 1.5
residual2_start = np.array([right_side_x2, ln2_box.get_top()[1] + 0.1, 0])
residual2_end = np.array([right_side_x2, ffn_box.get_bottom()[1] - 0.1, 0])
residual2 = Arrow(
residual2_start,
residual2_end,
color=RED,
buff=0.1
)
residual2_label = Text(
"残差连接",
font_size=14,
font="SimHei",
color=RED
)
residual2_label.next_to(residual2, RIGHT, buff=0.1)
# 输出箭头
output_arrow = Arrow(
encoder_layer.get_bottom(),
encoder_layer.get_bottom() + DOWN * 0.5,
color=GREEN
)
output_label = Text(
"输出",
font_size=18,
font="SimHei",
color=GREEN
)
output_label.next_to(output_arrow, DOWN, buff=0.1)
# 连接箭头
arrow_ln1_mha = Arrow(
ln1_box.get_bottom(),
mha_box.get_top(),
buff=0.1,
color=WHITE
)
arrow_mha_ln2 = Arrow(
mha_box.get_bottom(),
ln2_box.get_top(),
buff=0.1,
color=WHITE
)
arrow_ln2_ffn = Arrow(
ln2_box.get_bottom(),
ffn_box.get_top(),
buff=0.1,
color=WHITE
)
# 显示所有组件
components = [
encoder_layer, input_arrow, input_label,
ln1_box, ln1_text, arrow_ln1_mha,
mha_box, mha_text, residual1, residual1_label,
arrow_mha_ln2, ln2_box, ln2_text,
arrow_ln2_ffn, ffn_box, ffn_text,
residual2, residual2_label, output_arrow, output_label
]
for comp in components:
self.play(Create(comp) if not isinstance(comp, Text) else Write(comp),
run_time=0.3)
# 添加说明
explanation = Text(
"编码器层 = 层归一化 + 多头注意力 + 前馈网络(均有残差连接)",
font_size=20,
font="SimHei",
color=YELLOW
)
explanation.to_edge(DOWN, buff=1.0)
self.play(Write(explanation), run_time=1.0)
self.wait(2.0)
self.encoder_title = encoder_title
self.encoder_components = components + [explanation]
def show_decoder_details(self):
"""显示解码器层细节"""
# 清除编码器
self.play(
FadeOut(self.encoder_title),
*[FadeOut(c) for c in self.encoder_components],
run_time=0.8
)
# 解码器详细结构
decoder_title = Text(
"解码器层内部结构",
font_size=36,
font="SimHei",
color=ORANGE
)
decoder_title.to_edge(UP, buff=0.5)
self.play(Write(decoder_title), run_time=0.8)
# 解码器层框
decoder_layer = Rectangle(
width=6.0, height=8.0,
color=ORANGE,
fill_opacity=0.1
)
decoder_layer.center().shift(UP * 0.5)
# 输入箭头
input_arrow = Arrow(
decoder_layer.get_top() + UP * 0.5,
decoder_layer.get_top(),
color=GREEN
)
input_label = Text(
"输入",
font_size=18,
font="SimHei",
color=GREEN
)
input_label.next_to(input_arrow, UP, buff=0.1)
# 掩码多头注意力
masked_mha_box = Rectangle(
width=4.0, height=1.2,
color=YELLOW,
fill_opacity=0.3
)
masked_mha_box.move_to(decoder_layer.get_top() + DOWN * 1.5)
masked_mha_text = Text(
"掩码多头注意力\n(Masked Multi-Head Attention)",
font_size=16,
font="SimHei",
color=WHITE
)
masked_mha_text.move_to(masked_mha_box)
# 编码器-解码器注意力
enc_dec_box = Rectangle(
width=4.0, height=1.2,
color=PURPLE,
fill_opacity=0.3
)
enc_dec_box.move_to(masked_mha_box.get_bottom() + DOWN * 1.5)
enc_dec_text = Text(
"编码器-解码器注意力\n(Encoder-Decoder Attention)",
font_size=16,
font="SimHei",
color=WHITE
)
enc_dec_text.move_to(enc_dec_box)
# 前馈网络
ffn_box = Rectangle(
width=4.0, height=1.2,
color=GREEN,
fill_opacity=0.3
)
ffn_box.move_to(enc_dec_box.get_bottom() + DOWN * 1.5)
ffn_text = Text(
"前馈神经网络\n(Feed Forward Network)",
font_size=16,
font="SimHei",
color=WHITE
)
ffn_text.move_to(ffn_box)
# 层归一化(三个)
ln_positions = [
masked_mha_box.get_top() + DOWN * 0.2,
enc_dec_box.get_top() + DOWN * 0.2,
ffn_box.get_top() + DOWN * 0.2
]
ln_boxes = []
for i, pos in enumerate(ln_positions):
ln_box = Rectangle(
width=1.0, height=0.5,
color=BLUE,
fill_opacity=0.3
)
ln_box.move_to(pos + LEFT * 2.5)
ln_text = Text(
"LN",
font_size=12,
font="SimHei",
color=WHITE
)
ln_text.move_to(ln_box)
ln_boxes.append(VGroup(ln_box, ln_text))
# 残差连接
residual_arrows = []
for i, ln_box in enumerate(ln_boxes):
arrow = Arrow(
ln_box[0].get_left() + LEFT * 0.3,
[masked_mha_box, enc_dec_box, ffn_box][i].get_right() + RIGHT * 0.3,
color=RED,
buff=0.1
)
residual_arrows.append(arrow)
# 输出箭头
output_arrow = Arrow(
decoder_layer.get_bottom(),
decoder_layer.get_bottom() + DOWN * 0.5,
color=GREEN
)
output_label = Text(
"输出",
font_size=18,
font="SimHei",
color=GREEN
)
output_label.next_to(output_arrow, DOWN, buff=0.1)
# 连接箭头
arrows = []
arrow1 = Arrow(
input_arrow.get_end(),
ln_boxes[0][0].get_top(),
buff=0.1,
color=WHITE
)
arrow2 = Arrow(
masked_mha_box.get_bottom(),
ln_boxes[1][0].get_top(),
buff=0.1,
color=WHITE
)
arrow3 = Arrow(
enc_dec_box.get_bottom(),
ln_boxes[2][0].get_top(),
buff=0.1,
color=WHITE
)
arrow4 = Arrow(
ffn_box.get_bottom(),
output_arrow.get_start(),
buff=0.1,
color=WHITE
)
# 显示所有组件
components = [
decoder_layer, input_arrow, input_label, arrow1,
masked_mha_box, masked_mha_text,
enc_dec_box, enc_dec_text, ffn_box, ffn_text,
output_arrow, output_label
]
for ln in ln_boxes:
components.append(ln)
for arrow in [arrow2, arrow3, arrow4] + residual_arrows:
components.append(arrow)
for comp in components:
self.play(Create(comp) if not isinstance(comp, Text) and not isinstance(comp, VGroup) else Write(comp),
run_time=0.2)
# 添加说明
explanation = Text(
"解码器层 = 掩码注意力 + 编码器-解码器注意力 + 前馈网络",
font_size=20,
font="SimHei",
color=YELLOW
)
explanation.to_edge(DOWN, buff=1.0)
self.play(Write(explanation), run_time=1.0)
self.wait(2.0)
self.decoder_title = decoder_title
self.decoder_components = components + [explanation]
def show_attention_mechanism(self):
"""显示注意力机制"""
# 清除解码器
self.play(
FadeOut(self.decoder_title),
*[FadeOut(c) for c in self.decoder_components],
run_time=0.8
)
# 注意力机制标题
attn_title = Text(
"注意力机制 (Attention)",
font_size=36,
font="SimHei",
color=YELLOW
)
attn_title.to_edge(UP, buff=0.5)
self.play(Write(attn_title), run_time=0.8)
# 注意力公式
formula = MathTex(
r"\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V",
font_size=32
)
formula.shift(UP * 1.0)
self.play(Write(formula), run_time=1.5)
self.wait(1.0)
# 多头注意力公式
multi_head_formula = MathTex(
r"\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O",
font_size=28
)
multi_head_formula.shift(DOWN * 0.5)
head_formula = MathTex(
r"\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)",
font_size=24
)
head_formula.shift(DOWN * 1.5)
self.play(Write(multi_head_formula), run_time=1.0)
self.play(Write(head_formula), run_time=1.0)
# 注意力可视化
explanation = Text(
"注意力允许模型关注输入序列中的所有位置",
font_size=22,
font="SimHei",
color=BLUE
)
explanation.to_edge(DOWN, buff=1.0)
self.play(Write(explanation), run_time=1.0)
self.wait(2.0)
self.attention_components = [attn_title, formula, multi_head_formula, head_formula, explanation]
def show_attention_formula(self):
"""显示注意力公式的详细解释"""
# 清除之前的注意力
self.play(
*[FadeOut(c) for c in self.attention_components],
run_time=0.8
)
# 公式详细解释
formula_title = Text(
"缩放点积注意力详解",
font_size=36,
font="SimHei",
color=PURPLE
)
formula_title.to_edge(UP, buff=0.5)
self.play(Write(formula_title), run_time=0.8)
# 分步公式 - 使用英文避免LaTeX编译错误
step1_formula = MathTex(
r"1.\ QK^T",
font_size=26
)
step1_formula.shift(UP * 2.5)
step1_text = Text(
"(计算相似度)",
font_size=20,
font="SimHei",
color=GRAY
)
step1_text.next_to(step1_formula, DOWN, buff=0.1)
step2_formula = MathTex(
r"2.\ \frac{QK^T}{\sqrt{d_k}}",
font_size=26
)
step2_formula.shift(UP * 1.0)
step2_text = Text(
"(缩放,稳定梯度)",
font_size=20,
font="SimHei",
color=GRAY
)
step2_text.next_to(step2_formula, DOWN, buff=0.1)
step3_formula = MathTex(
r"3.\ \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)",
font_size=26
)
step3_formula.shift(DOWN * 0.5)
step3_text = Text(
"(归一化为概率)",
font_size=20,
font="SimHei",
color=GRAY
)
step3_text.next_to(step3_formula, DOWN, buff=0.1)
step4_formula = MathTex(
r"4.\ \text{softmax}(\cdots)V",
font_size=26
)
step4_formula.shift(DOWN * 2.0)
step4_text = Text(
"(加权求和)",
font_size=20,
font="SimHei",
color=GRAY
)
step4_text.next_to(step4_formula, DOWN, buff=0.1)
formula_steps = [step1_formula, step2_formula, step3_formula, step4_formula]
text_steps = [step1_text, step2_text, step3_text, step4_text]
for i in range(4):
self.play(Write(formula_steps[i]), run_time=0.8)
self.play(Write(text_steps[i]), run_time=0.5)
self.wait(0.3)
# AstrAI中的注意力优化
astrai_title = Text(
"AstrAI中的注意力优化",
font_size=28,
font="SimHei",
color=GREEN
)
astrai_title.shift(DOWN * 2.5)
optimizations = [
"• Flash Attention内存高效实现",
"• KV Cache避免重复计算",
"• 连续批处理:动态请求调度",
"• 前缀缓存Radix Tree管理",
]
opt_texts = []
for i, opt in enumerate(optimizations):
text = Text(
opt,
font_size=20,
font="SimHei",
color=YELLOW
)
text.next_to(astrai_title, DOWN, buff=0.3 + i * 0.4)
text.align_to(astrai_title, LEFT)
opt_texts.append(text)
self.play(Write(astrai_title), run_time=0.8)
for text in opt_texts:
self.play(Write(text), run_time=0.5)
self.wait(3.0)
# Render with:
#   python -m manim transformer_visualization.py TransformerArchitecture -pqh
if __name__ == "__main__":
    # Quick local test: render the scene directly with default settings.
    scene = TransformerArchitecture()
    scene.render()