# AstrAI-video-repo/transformer_visualization.py

"""
Transformer 架构可视化
使用Manim绘制Transformer的详细结构图
"""
from manim import *
import numpy as np
class TransformerArchitecture(Scene):
    """Step-by-step visualization of the full Transformer architecture.

    The scene plays six sections in order: title, overall encoder/decoder
    stack, encoder-layer internals, decoder-layer internals, the attention
    mechanism formulas, and a step-by-step breakdown of scaled dot-product
    attention. Each ``show_*`` method fades out the mobjects created by the
    previous section before drawing its own, passing references forward via
    instance attributes (``self.title``, ``self.arch_components``, ...).

    NOTE(review): the original file had all indentation stripped (copy-paste
    artifact) and would not parse; this restores the structure. All rendered
    Chinese strings are preserved byte-for-byte.
    """

    def construct(self):
        # Sections run strictly in sequence; each one cleans up after the
        # previous via the instance-attribute handoff described above.
        self.show_title()
        self.show_overall_architecture()
        self.show_encoder_details()
        self.show_decoder_details()
        self.show_attention_mechanism()
        self.show_attention_formula()

    def show_title(self):
        """Show the opening title and subtitle."""
        title = Text(
            "Transformer 架构详解",
            font_size=48,
            font="SimHei",
            color=WHITE
        )
        title.to_edge(UP, buff=0.5)
        subtitle = Text(
            "现代大语言模型的核心基础",
            font_size=24,
            font="SimHei",
            color=GRAY
        )
        subtitle.next_to(title, DOWN, buff=0.3)
        self.play(Write(title), run_time=1.0)
        self.play(Write(subtitle), run_time=0.8)
        self.wait(1.0)
        # Keep references so the next section can fade these out.
        self.title = title
        self.subtitle = subtitle

    def show_overall_architecture(self):
        """Show the high-level diagram: input -> 6 encoders -> 6 decoders -> output."""
        # Clear the title screen.
        self.play(
            FadeOut(self.title),
            FadeOut(self.subtitle),
            run_time=0.5
        )
        # Section heading.
        arch_title = Text(
            "Transformer 整体架构",
            font_size=36,
            font="SimHei",
            color=BLUE
        )
        arch_title.to_edge(UP, buff=0.5)
        self.play(Write(arch_title), run_time=0.8)
        # Input block on the left.
        input_box = Rectangle(
            width=3.0, height=1.0,
            color=GREEN,
            fill_opacity=0.3
        )
        input_box.to_edge(LEFT, buff=2.0).shift(UP * 0.5)
        input_text = Text(
            "输入序列",
            font_size=20,
            font="SimHei",
            color=WHITE
        )
        input_text.move_to(input_box)
        # Stack of 6 encoder layers, laid out top-to-bottom.
        encoder_stack = VGroup()
        for i in range(6):
            encoder = Rectangle(
                width=3.0, height=0.5,
                color=BLUE_C,
                fill_opacity=0.2
            )
            encoder.shift(RIGHT * 2.5 + UP * (1.5 - i * 0.55))
            encoder_label = Text(
                f"编码器层 {i+1}",
                font_size=14,
                font="SimHei"
            )
            encoder_label.move_to(encoder)
            encoder_group = VGroup(encoder, encoder_label)
            encoder_stack.add(encoder_group)
        # Stack of 6 decoder layers.
        # NOTE(review): RIGHT * 6.5 may place these near/past the frame edge
        # at the default 14.2-unit frame width — confirm when rendering.
        decoder_stack = VGroup()
        for i in range(6):
            decoder = Rectangle(
                width=3.0, height=0.5,
                color=ORANGE,
                fill_opacity=0.2
            )
            decoder.shift(RIGHT * 6.5 + UP * (1.5 - i * 0.55))
            decoder_label = Text(
                f"解码器层 {i+1}",
                font_size=14,
                font="SimHei"
            )
            decoder_label.move_to(decoder)
            decoder_group = VGroup(decoder, decoder_label)
            decoder_stack.add(decoder_group)
        # Output block on the right.
        output_box = Rectangle(
            width=3.0, height=1.0,
            color=RED,
            fill_opacity=0.3
        )
        output_box.to_edge(RIGHT, buff=2.0).shift(UP * 0.5)
        output_text = Text(
            "输出序列",
            font_size=20,
            font="SimHei",
            color=WHITE
        )
        output_text.move_to(output_box)
        # Arrows: input -> encoders -> decoders -> output.
        arrow1 = Arrow(
            input_box.get_right(),
            encoder_stack.get_left(),
            buff=0.2,
            color=YELLOW
        )
        arrow2 = Arrow(
            encoder_stack.get_right(),
            decoder_stack.get_left(),
            buff=0.2,
            color=YELLOW
        )
        arrow3 = Arrow(
            decoder_stack.get_right(),
            output_box.get_left(),
            buff=0.2,
            color=YELLOW
        )
        # Animate the input block first.
        self.play(
            Create(input_box),
            Write(input_text),
            run_time=0.8
        )
        self.wait(0.3)
        # Encoders one by one; draw the first arrow alongside layer 1.
        for i, encoder in enumerate(encoder_stack):
            self.play(
                Create(encoder),
                run_time=0.2
            )
            if i == 0:
                self.play(Create(arrow1), run_time=0.5)
        self.wait(0.3)
        # Decoders one by one; draw the second arrow alongside layer 1.
        for i, decoder in enumerate(decoder_stack):
            self.play(
                Create(decoder),
                run_time=0.2
            )
            if i == 0:
                self.play(Create(arrow2), run_time=0.5)
        self.wait(0.3)
        # Output block plus the final arrow.
        self.play(
            Create(output_box),
            Write(output_text),
            Create(arrow3),
            run_time=0.8
        )
        # Caption under the diagram.
        explanation = Text(
            "Transformer = 编码器(N层) + 解码器(N层)",
            font_size=22,
            font="SimHei",
            color=YELLOW
        )
        explanation.to_edge(DOWN, buff=1.0)
        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)
        # Keep references so the next section can fade everything out.
        self.arch_title = arch_title
        self.arch_components = [input_box, input_text, encoder_stack, decoder_stack,
                                output_box, output_text, arrow1, arrow2, arrow3, explanation]

    def show_encoder_details(self):
        """Show the internals of one encoder layer (LN + MHA + FFN, residuals)."""
        # Clear the overall-architecture diagram.
        self.play(
            FadeOut(self.arch_title),
            *[FadeOut(c) for c in self.arch_components],
            run_time=0.8
        )
        # Section heading.
        encoder_title = Text(
            "编码器层内部结构",
            font_size=36,
            font="SimHei",
            color=BLUE
        )
        encoder_title.to_edge(UP, buff=0.5)
        self.play(Write(encoder_title), run_time=0.8)
        # Outer frame for the encoder layer.
        encoder_layer = Rectangle(
            width=5.0, height=6.0,
            color=BLUE,
            fill_opacity=0.1
        )
        encoder_layer.center()
        # Input arrow entering from the top.
        input_arrow = Arrow(
            encoder_layer.get_top() + UP * 0.5,
            encoder_layer.get_top(),
            color=GREEN
        )
        input_label = Text(
            "输入",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        input_label.next_to(input_arrow, UP, buff=0.1)
        # First layer norm.
        ln1_box = Rectangle(
            width=4.0, height=0.8,
            color=PURPLE,
            fill_opacity=0.3
        )
        ln1_box.move_to(encoder_layer.get_top() + DOWN * 1.0)
        ln1_text = Text(
            "层归一化 (LayerNorm)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ln1_text.move_to(ln1_box)
        # Multi-head self-attention.
        mha_box = Rectangle(
            width=4.0, height=1.2,
            color=YELLOW,
            fill_opacity=0.3
        )
        mha_box.move_to(ln1_box.get_bottom() + DOWN * 1.0)
        mha_text = Text(
            "多头自注意力\n(Multi-Head Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        mha_text.move_to(mha_box)
        # Residual connection 1: vertical arrow on the right-hand side,
        # from above LN1 down past the attention output. A fixed x keeps
        # the arrow perfectly vertical.
        right_side_x = ln1_box.get_right()[0] + 1.5
        residual1_start = np.array([right_side_x, ln1_box.get_top()[1] + 0.1, 0])
        residual1_end = np.array([right_side_x, mha_box.get_bottom()[1] - 0.1, 0])
        residual1 = Arrow(
            residual1_start,
            residual1_end,
            color=RED,
            buff=0.1
        )
        residual1_label = Text(
            "残差连接",
            font_size=14,
            font="SimHei",
            color=RED
        )
        residual1_label.next_to(residual1, RIGHT, buff=0.1)
        # Second layer norm.
        ln2_box = Rectangle(
            width=4.0, height=0.8,
            color=PURPLE,
            fill_opacity=0.3
        )
        ln2_box.move_to(mha_box.get_bottom() + DOWN * 1.5)
        ln2_text = Text(
            "层归一化 (LayerNorm)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ln2_text.move_to(ln2_box)
        # Feed-forward network.
        ffn_box = Rectangle(
            width=4.0, height=1.2,
            color=GREEN,
            fill_opacity=0.3
        )
        ffn_box.move_to(ln2_box.get_bottom() + DOWN * 1.0)
        ffn_text = Text(
            "前馈神经网络\n(Feed Forward Network)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ffn_text.move_to(ffn_box)
        # Residual connection 2: same vertical-arrow construction, from
        # above LN2 down past the FFN output.
        right_side_x2 = ln2_box.get_right()[0] + 1.5
        residual2_start = np.array([right_side_x2, ln2_box.get_top()[1] + 0.1, 0])
        residual2_end = np.array([right_side_x2, ffn_box.get_bottom()[1] - 0.1, 0])
        residual2 = Arrow(
            residual2_start,
            residual2_end,
            color=RED,
            buff=0.1
        )
        residual2_label = Text(
            "残差连接",
            font_size=14,
            font="SimHei",
            color=RED
        )
        residual2_label.next_to(residual2, RIGHT, buff=0.1)
        # Output arrow leaving from the bottom.
        output_arrow = Arrow(
            encoder_layer.get_bottom(),
            encoder_layer.get_bottom() + DOWN * 0.5,
            color=GREEN
        )
        output_label = Text(
            "输出",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        output_label.next_to(output_arrow, DOWN, buff=0.1)
        # Internal connecting arrows between the sub-blocks.
        arrow_ln1_mha = Arrow(
            ln1_box.get_bottom(),
            mha_box.get_top(),
            buff=0.1,
            color=WHITE
        )
        arrow_mha_ln2 = Arrow(
            mha_box.get_bottom(),
            ln2_box.get_top(),
            buff=0.1,
            color=WHITE
        )
        arrow_ln2_ffn = Arrow(
            ln2_box.get_bottom(),
            ffn_box.get_top(),
            buff=0.1,
            color=WHITE
        )
        # Animate all components in drawing order.
        components = [
            encoder_layer, input_arrow, input_label,
            ln1_box, ln1_text, arrow_ln1_mha,
            mha_box, mha_text, residual1, residual1_label,
            arrow_mha_ln2, ln2_box, ln2_text,
            arrow_ln2_ffn, ffn_box, ffn_text,
            residual2, residual2_label, output_arrow, output_label
        ]
        for comp in components:
            # Text is written stroke-by-stroke; shapes are drawn with Create.
            self.play(Create(comp) if not isinstance(comp, Text) else Write(comp),
                      run_time=0.3)
        # Caption.
        explanation = Text(
            "编码器层 = 层归一化 + 多头注意力 + 前馈网络(均有残差连接)",
            font_size=20,
            font="SimHei",
            color=YELLOW
        )
        explanation.to_edge(DOWN, buff=1.0)
        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)
        # Hand everything to the next section for cleanup.
        self.encoder_title = encoder_title
        self.encoder_components = components + [explanation]

    def show_decoder_details(self):
        """Show the internals of one decoder layer (masked MHA + cross-attention + FFN)."""
        # Clear the encoder diagram.
        self.play(
            FadeOut(self.encoder_title),
            *[FadeOut(c) for c in self.encoder_components],
            run_time=0.8
        )
        # Section heading.
        decoder_title = Text(
            "解码器层内部结构",
            font_size=36,
            font="SimHei",
            color=ORANGE
        )
        decoder_title.to_edge(UP, buff=0.5)
        self.play(Write(decoder_title), run_time=0.8)
        # Outer frame for the decoder layer.
        decoder_layer = Rectangle(
            width=6.0, height=8.0,
            color=ORANGE,
            fill_opacity=0.1
        )
        decoder_layer.center().shift(UP * 0.5)
        # Input arrow entering from the top.
        input_arrow = Arrow(
            decoder_layer.get_top() + UP * 0.5,
            decoder_layer.get_top(),
            color=GREEN
        )
        input_label = Text(
            "输入",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        input_label.next_to(input_arrow, UP, buff=0.1)
        # Masked multi-head self-attention.
        masked_mha_box = Rectangle(
            width=4.0, height=1.2,
            color=YELLOW,
            fill_opacity=0.3
        )
        masked_mha_box.move_to(decoder_layer.get_top() + DOWN * 1.5)
        masked_mha_text = Text(
            "掩码多头注意力\n(Masked Multi-Head Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        masked_mha_text.move_to(masked_mha_box)
        # Encoder-decoder (cross) attention.
        enc_dec_box = Rectangle(
            width=4.0, height=1.2,
            color=PURPLE,
            fill_opacity=0.3
        )
        enc_dec_box.move_to(masked_mha_box.get_bottom() + DOWN * 1.5)
        enc_dec_text = Text(
            "编码器-解码器注意力\n(Encoder-Decoder Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        enc_dec_text.move_to(enc_dec_box)
        # Feed-forward network.
        ffn_box = Rectangle(
            width=4.0, height=1.2,
            color=GREEN,
            fill_opacity=0.3
        )
        ffn_box.move_to(enc_dec_box.get_bottom() + DOWN * 1.5)
        ffn_text = Text(
            "前馈神经网络\n(Feed Forward Network)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ffn_text.move_to(ffn_box)
        # Three small LayerNorm boxes, one above each sub-block, offset left.
        ln_positions = [
            masked_mha_box.get_top() + DOWN * 0.2,
            enc_dec_box.get_top() + DOWN * 0.2,
            ffn_box.get_top() + DOWN * 0.2
        ]
        ln_boxes = []
        for i, pos in enumerate(ln_positions):
            ln_box = Rectangle(
                width=1.0, height=0.5,
                color=BLUE,
                fill_opacity=0.3
            )
            ln_box.move_to(pos + LEFT * 2.5)
            ln_text = Text(
                "LN",
                font_size=12,
                font="SimHei",
                color=WHITE
            )
            ln_text.move_to(ln_box)
            ln_boxes.append(VGroup(ln_box, ln_text))
        # Residual arrows from each LN box across to its sub-block.
        residual_arrows = []
        for i, ln_box in enumerate(ln_boxes):
            arrow = Arrow(
                ln_box[0].get_left() + LEFT * 0.3,
                [masked_mha_box, enc_dec_box, ffn_box][i].get_right() + RIGHT * 0.3,
                color=RED,
                buff=0.1
            )
            residual_arrows.append(arrow)
        # Output arrow leaving from the bottom.
        output_arrow = Arrow(
            decoder_layer.get_bottom(),
            decoder_layer.get_bottom() + DOWN * 0.5,
            color=GREEN
        )
        output_label = Text(
            "输出",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        output_label.next_to(output_arrow, DOWN, buff=0.1)
        # Internal connecting arrows.
        arrows = []
        arrow1 = Arrow(
            input_arrow.get_end(),
            ln_boxes[0][0].get_top(),
            buff=0.1,
            color=WHITE
        )
        arrow2 = Arrow(
            masked_mha_box.get_bottom(),
            ln_boxes[1][0].get_top(),
            buff=0.1,
            color=WHITE
        )
        arrow3 = Arrow(
            enc_dec_box.get_bottom(),
            ln_boxes[2][0].get_top(),
            buff=0.1,
            color=WHITE
        )
        arrow4 = Arrow(
            ffn_box.get_bottom(),
            output_arrow.get_start(),
            buff=0.1,
            color=WHITE
        )
        # Collect everything in drawing order.
        components = [
            decoder_layer, input_arrow, input_label, arrow1,
            masked_mha_box, masked_mha_text,
            enc_dec_box, enc_dec_text, ffn_box, ffn_text,
            output_arrow, output_label
        ]
        for ln in ln_boxes:
            components.append(ln)
        for arrow in [arrow2, arrow3, arrow4] + residual_arrows:
            components.append(arrow)
        for comp in components:
            # Text and the LN VGroups (box + label) are written; shapes are drawn.
            self.play(Create(comp) if not isinstance(comp, Text) and not isinstance(comp, VGroup) else Write(comp),
                      run_time=0.2)
        # Caption.
        explanation = Text(
            "解码器层 = 掩码注意力 + 编码器-解码器注意力 + 前馈网络",
            font_size=20,
            font="SimHei",
            color=YELLOW
        )
        explanation.to_edge(DOWN, buff=1.0)
        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)
        # Hand everything to the next section for cleanup.
        self.decoder_title = decoder_title
        self.decoder_components = components + [explanation]

    def show_attention_mechanism(self):
        """Show the attention and multi-head attention formulas."""
        # Clear the decoder diagram.
        self.play(
            FadeOut(self.decoder_title),
            *[FadeOut(c) for c in self.decoder_components],
            run_time=0.8
        )
        # Section heading.
        attn_title = Text(
            "注意力机制 (Attention)",
            font_size=36,
            font="SimHei",
            color=YELLOW
        )
        attn_title.to_edge(UP, buff=0.5)
        self.play(Write(attn_title), run_time=0.8)
        # Scaled dot-product attention formula.
        formula = MathTex(
            r"\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V",
            font_size=32
        )
        formula.shift(UP * 1.0)
        self.play(Write(formula), run_time=1.5)
        self.wait(1.0)
        # Multi-head attention formulas.
        multi_head_formula = MathTex(
            r"\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O",
            font_size=28
        )
        multi_head_formula.shift(DOWN * 0.5)
        head_formula = MathTex(
            r"\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)",
            font_size=24
        )
        head_formula.shift(DOWN * 1.5)
        self.play(Write(multi_head_formula), run_time=1.0)
        self.play(Write(head_formula), run_time=1.0)
        # Caption.
        explanation = Text(
            "注意力允许模型关注输入序列中的所有位置",
            font_size=22,
            font="SimHei",
            color=BLUE
        )
        explanation.to_edge(DOWN, buff=1.0)
        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)
        # Hand everything to the next section for cleanup.
        self.attention_components = [attn_title, formula, multi_head_formula, head_formula, explanation]

    def show_attention_formula(self):
        """Break the scaled dot-product attention formula into four steps."""
        # Clear the previous attention section.
        self.play(
            *[FadeOut(c) for c in self.attention_components],
            run_time=0.8
        )
        # Section heading.
        formula_title = Text(
            "缩放点积注意力详解",
            font_size=36,
            font="SimHei",
            color=PURPLE
        )
        formula_title.to_edge(UP, buff=0.5)
        self.play(Write(formula_title), run_time=0.8)
        # Step-by-step formulas; math kept in LaTeX, explanations as Text
        # to avoid CJK issues in the LaTeX compiler.
        step1_formula = MathTex(
            r"1.\ QK^T",
            font_size=26
        )
        step1_formula.shift(UP * 2.5)
        step1_text = Text(
            "(计算相似度)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step1_text.next_to(step1_formula, DOWN, buff=0.1)
        step2_formula = MathTex(
            r"2.\ \frac{QK^T}{\sqrt{d_k}}",
            font_size=26
        )
        step2_formula.shift(UP * 1.0)
        step2_text = Text(
            "(缩放,稳定梯度)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step2_text.next_to(step2_formula, DOWN, buff=0.1)
        step3_formula = MathTex(
            r"3.\ \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)",
            font_size=26
        )
        step3_formula.shift(DOWN * 0.5)
        step3_text = Text(
            "(归一化为概率)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step3_text.next_to(step3_formula, DOWN, buff=0.1)
        step4_formula = MathTex(
            r"4.\ \text{softmax}(\cdots)V",
            font_size=26
        )
        step4_formula.shift(DOWN * 2.0)
        step4_text = Text(
            "(加权求和)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step4_text.next_to(step4_formula, DOWN, buff=0.1)
        formula_steps = [step1_formula, step2_formula, step3_formula, step4_formula]
        text_steps = [step1_text, step2_text, step3_text, step4_text]
        for i in range(4):
            self.play(Write(formula_steps[i]), run_time=0.8)
            self.play(Write(text_steps[i]), run_time=0.5)
            self.wait(0.3)
        # AstrAI-specific attention optimizations.
        # NOTE(review): astrai_title at DOWN * 2.5 sits close to step 4 at
        # DOWN * 2.0 — confirm there is no overlap when rendering.
        astrai_title = Text(
            "AstrAI中的注意力优化",
            font_size=28,
            font="SimHei",
            color=GREEN
        )
        astrai_title.shift(DOWN * 2.5)
        optimizations = [
            "• Flash Attention内存高效实现",
            "• KV Cache避免重复计算",
            "• 连续批处理:动态请求调度",
            "• 前缀缓存Radix Tree管理",
        ]
        opt_texts = []
        for i, opt in enumerate(optimizations):
            text = Text(
                opt,
                font_size=20,
                font="SimHei",
                color=YELLOW
            )
            text.next_to(astrai_title, DOWN, buff=0.3 + i * 0.4)
            text.align_to(astrai_title, LEFT)
            opt_texts.append(text)
        self.play(Write(astrai_title), run_time=0.8)
        for text in opt_texts:
            self.play(Write(text), run_time=0.5)
        self.wait(3.0)
# Render command:
# python -m manim transformer_visualization.py TransformerArchitecture -pqh
if __name__ == "__main__":
    # Quick local test: render the scene directly, bypassing the manim CLI.
    scene = TransformerArchitecture()
    scene.render()