# -*- coding: utf-8 -*-
"""
Transformer 架构可视化

使用Manim绘制Transformer的详细结构图
"""
from manim import *
|
||
import numpy as np
|
||
|
||
|
||
class TransformerArchitecture(Scene):
    """Full walkthrough of the Transformer architecture, rendered as a Manim scene.

    The scene plays six sections in order; each section fades out the previous
    one using object references stashed on ``self`` by its predecessor.
    """

    def construct(self):
        # Sections must run in this order: each one fades out mobjects that
        # the previous section stored on self (e.g. self.title, self.arch_components).
        self.show_title()
        self.show_overall_architecture()
        self.show_encoder_details()
        self.show_decoder_details()
        self.show_attention_mechanism()
        self.show_attention_formula()
def show_title(self):
|
||
"""显示标题"""
|
||
title = Text(
|
||
"Transformer 架构详解",
|
||
font_size=48,
|
||
font="SimHei",
|
||
color=WHITE
|
||
)
|
||
title.to_edge(UP, buff=0.5)
|
||
|
||
subtitle = Text(
|
||
"现代大语言模型的核心基础",
|
||
font_size=24,
|
||
font="SimHei",
|
||
color=GRAY
|
||
)
|
||
subtitle.next_to(title, DOWN, buff=0.3)
|
||
|
||
self.play(Write(title), run_time=1.0)
|
||
self.play(Write(subtitle), run_time=0.8)
|
||
self.wait(1.0)
|
||
|
||
self.title = title
|
||
self.subtitle = subtitle
|
||
|
||
    def show_overall_architecture(self):
        """Show the high-level diagram: input -> 6 encoder layers -> 6 decoder layers -> output."""
        # Remove the title screen created by show_title().
        self.play(
            FadeOut(self.title),
            FadeOut(self.subtitle),
            run_time=0.5
        )

        # Section title.
        arch_title = Text(
            "Transformer 整体架构",
            font_size=36,
            font="SimHei",
            color=BLUE
        )
        arch_title.to_edge(UP, buff=0.5)
        self.play(Write(arch_title), run_time=0.8)

        # --- Build the diagram components (not yet animated) ---
        # Input block on the left.
        input_box = Rectangle(
            width=3.0, height=1.0,
            color=GREEN,
            fill_opacity=0.3
        )
        input_box.to_edge(LEFT, buff=2.0).shift(UP * 0.5)

        input_text = Text(
            "输入序列",
            font_size=20,
            font="SimHei",
            color=WHITE
        )
        input_text.move_to(input_box)

        # Stack of 6 encoder layers, laid out top-to-bottom.
        encoder_stack = VGroup()
        for i in range(6):
            encoder = Rectangle(
                width=3.0, height=0.5,
                color=BLUE_C,
                fill_opacity=0.2
            )
            encoder.shift(RIGHT * 2.5 + UP * (1.5 - i * 0.55))
            encoder_label = Text(
                f"编码器层 {i+1}",
                font_size=14,
                font="SimHei"
            )
            encoder_label.move_to(encoder)
            encoder_group = VGroup(encoder, encoder_label)
            encoder_stack.add(encoder_group)

        # Stack of 6 decoder layers to the right of the encoders.
        # NOTE(review): RIGHT * 6.5 puts this stack near the default frame edge
        # (~7.1 units) — confirm it is fully visible in the rendered output.
        decoder_stack = VGroup()
        for i in range(6):
            decoder = Rectangle(
                width=3.0, height=0.5,
                color=ORANGE,
                fill_opacity=0.2
            )
            decoder.shift(RIGHT * 6.5 + UP * (1.5 - i * 0.55))
            decoder_label = Text(
                f"解码器层 {i+1}",
                font_size=14,
                font="SimHei"
            )
            decoder_label.move_to(decoder)
            decoder_group = VGroup(decoder, decoder_label)
            decoder_stack.add(decoder_group)

        # Output block on the right.
        output_box = Rectangle(
            width=3.0, height=1.0,
            color=RED,
            fill_opacity=0.3
        )
        output_box.to_edge(RIGHT, buff=2.0).shift(UP * 0.5)

        output_text = Text(
            "输出序列",
            font_size=20,
            font="SimHei",
            color=WHITE
        )
        output_text.move_to(output_box)

        # Arrows: input -> encoders -> decoders -> output.
        arrow1 = Arrow(
            input_box.get_right(),
            encoder_stack.get_left(),
            buff=0.2,
            color=YELLOW
        )

        arrow2 = Arrow(
            encoder_stack.get_right(),
            decoder_stack.get_left(),
            buff=0.2,
            color=YELLOW
        )

        arrow3 = Arrow(
            decoder_stack.get_right(),
            output_box.get_left(),
            buff=0.2,
            color=YELLOW
        )

        # --- Animate the components ---
        self.play(
            Create(input_box),
            Write(input_text),
            run_time=0.8
        )
        self.wait(0.3)

        # Encoders one by one; draw arrow1 together with the first layer.
        for i, encoder in enumerate(encoder_stack):
            self.play(
                Create(encoder),
                run_time=0.2
            )
            if i == 0:
                self.play(Create(arrow1), run_time=0.5)

        self.wait(0.3)

        # Decoders one by one; draw arrow2 together with the first layer.
        for i, decoder in enumerate(decoder_stack):
            self.play(
                Create(decoder),
                run_time=0.2
            )
            if i == 0:
                self.play(Create(arrow2), run_time=0.5)

        self.wait(0.3)

        # Output block and its arrow.
        self.play(
            Create(output_box),
            Write(output_text),
            Create(arrow3),
            run_time=0.8
        )

        # Caption at the bottom.
        explanation = Text(
            "Transformer = 编码器(N层) + 解码器(N层)",
            font_size=22,
            font="SimHei",
            color=YELLOW
        )
        explanation.to_edge(DOWN, buff=1.0)

        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Keep references so show_encoder_details() can fade everything out.
        self.arch_title = arch_title
        self.arch_components = [input_box, input_text, encoder_stack, decoder_stack,
                                output_box, output_text, arrow1, arrow2, arrow3, explanation]
    def show_encoder_details(self):
        """Show the internal structure of a single encoder layer (pre-norm style)."""
        # Clear the overall architecture diagram from the previous section.
        self.play(
            FadeOut(self.arch_title),
            *[FadeOut(c) for c in self.arch_components],
            run_time=0.8
        )

        # Section title.
        encoder_title = Text(
            "编码器层内部结构",
            font_size=36,
            font="SimHei",
            color=BLUE
        )
        encoder_title.to_edge(UP, buff=0.5)
        self.play(Write(encoder_title), run_time=0.8)

        # Outer frame of the encoder layer.
        encoder_layer = Rectangle(
            width=5.0, height=6.0,
            color=BLUE,
            fill_opacity=0.1
        )
        encoder_layer.center()

        # Input arrow entering from the top.
        input_arrow = Arrow(
            encoder_layer.get_top() + UP * 0.5,
            encoder_layer.get_top(),
            color=GREEN
        )
        input_label = Text(
            "输入",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        input_label.next_to(input_arrow, UP, buff=0.1)

        # LayerNorm 1 (before the attention sub-block).
        ln1_box = Rectangle(
            width=4.0, height=0.8,
            color=PURPLE,
            fill_opacity=0.3
        )
        ln1_box.move_to(encoder_layer.get_top() + DOWN * 1.0)
        ln1_text = Text(
            "层归一化 (LayerNorm)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ln1_text.move_to(ln1_box)

        # Multi-head self-attention sub-block.
        mha_box = Rectangle(
            width=4.0, height=1.2,
            color=YELLOW,
            fill_opacity=0.3
        )
        mha_box.move_to(ln1_box.get_bottom() + DOWN * 1.0)
        mha_text = Text(
            "多头自注意力\n(Multi-Head Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        mha_text.move_to(mha_box)

        # Residual connection 1: a vertical arrow drawn to the right of the
        # LN1/attention column. Both endpoints share the same x so the arrow
        # is exactly vertical.
        right_side_x = ln1_box.get_right()[0] + 1.5
        residual1_start = np.array([right_side_x, ln1_box.get_top()[1] + 0.1, 0])
        residual1_end = np.array([right_side_x, mha_box.get_bottom()[1] - 0.1, 0])
        residual1 = Arrow(
            residual1_start,
            residual1_end,
            color=RED,
            buff=0.1
        )
        residual1_label = Text(
            "残差连接",
            font_size=14,
            font="SimHei",
            color=RED
        )
        residual1_label.next_to(residual1, RIGHT, buff=0.1)

        # LayerNorm 2 (before the feed-forward sub-block).
        ln2_box = Rectangle(
            width=4.0, height=0.8,
            color=PURPLE,
            fill_opacity=0.3
        )
        ln2_box.move_to(mha_box.get_bottom() + DOWN * 1.5)
        ln2_text = Text(
            "层归一化 (LayerNorm)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ln2_text.move_to(ln2_box)

        # Feed-forward network sub-block.
        ffn_box = Rectangle(
            width=4.0, height=1.2,
            color=GREEN,
            fill_opacity=0.3
        )
        ffn_box.move_to(ln2_box.get_bottom() + DOWN * 1.0)
        ffn_text = Text(
            "前馈神经网络\n(Feed Forward Network)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ffn_text.move_to(ffn_box)

        # Residual connection 2: same vertical-arrow construction, to the
        # right of the LN2/FFN column.
        right_side_x2 = ln2_box.get_right()[0] + 1.5
        residual2_start = np.array([right_side_x2, ln2_box.get_top()[1] + 0.1, 0])
        residual2_end = np.array([right_side_x2, ffn_box.get_bottom()[1] - 0.1, 0])
        residual2 = Arrow(
            residual2_start,
            residual2_end,
            color=RED,
            buff=0.1
        )
        residual2_label = Text(
            "残差连接",
            font_size=14,
            font="SimHei",
            color=RED
        )
        residual2_label.next_to(residual2, RIGHT, buff=0.1)

        # Output arrow leaving from the bottom.
        output_arrow = Arrow(
            encoder_layer.get_bottom(),
            encoder_layer.get_bottom() + DOWN * 0.5,
            color=GREEN
        )
        output_label = Text(
            "输出",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        output_label.next_to(output_arrow, DOWN, buff=0.1)

        # Plain white arrows connecting the sub-blocks top-to-bottom.
        arrow_ln1_mha = Arrow(
            ln1_box.get_bottom(),
            mha_box.get_top(),
            buff=0.1,
            color=WHITE
        )

        arrow_mha_ln2 = Arrow(
            mha_box.get_bottom(),
            ln2_box.get_top(),
            buff=0.1,
            color=WHITE
        )

        arrow_ln2_ffn = Arrow(
            ln2_box.get_bottom(),
            ffn_box.get_top(),
            buff=0.1,
            color=WHITE
        )

        # Animate every component in drawing order.
        components = [
            encoder_layer, input_arrow, input_label,
            ln1_box, ln1_text, arrow_ln1_mha,
            mha_box, mha_text, residual1, residual1_label,
            arrow_mha_ln2, ln2_box, ln2_text,
            arrow_ln2_ffn, ffn_box, ffn_text,
            residual2, residual2_label, output_arrow, output_label
        ]

        # Text mobjects animate better with Write; everything else with Create.
        for comp in components:
            self.play(Create(comp) if not isinstance(comp, Text) else Write(comp),
                      run_time=0.3)

        # Caption at the bottom.
        explanation = Text(
            "编码器层 = 层归一化 + 多头注意力 + 前馈网络(均有残差连接)",
            font_size=20,
            font="SimHei",
            color=YELLOW
        )
        explanation.to_edge(DOWN, buff=1.0)

        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Keep references so show_decoder_details() can fade everything out.
        self.encoder_title = encoder_title
        self.encoder_components = components + [explanation]
def show_decoder_details(self):
|
||
"""显示解码器层细节"""
|
||
# 清除编码器
|
||
self.play(
|
||
FadeOut(self.encoder_title),
|
||
*[FadeOut(c) for c in self.encoder_components],
|
||
run_time=0.8
|
||
)
|
||
|
||
# 解码器详细结构
|
||
decoder_title = Text(
|
||
"解码器层内部结构",
|
||
font_size=36,
|
||
font="SimHei",
|
||
color=ORANGE
|
||
)
|
||
decoder_title.to_edge(UP, buff=0.5)
|
||
self.play(Write(decoder_title), run_time=0.8)
|
||
|
||
# 解码器层框
|
||
decoder_layer = Rectangle(
|
||
width=6.0, height=8.0,
|
||
color=ORANGE,
|
||
fill_opacity=0.1
|
||
)
|
||
decoder_layer.center().shift(UP * 0.5)
|
||
|
||
# 输入箭头
|
||
input_arrow = Arrow(
|
||
decoder_layer.get_top() + UP * 0.5,
|
||
decoder_layer.get_top(),
|
||
color=GREEN
|
||
)
|
||
input_label = Text(
|
||
"输入",
|
||
font_size=18,
|
||
font="SimHei",
|
||
color=GREEN
|
||
)
|
||
input_label.next_to(input_arrow, UP, buff=0.1)
|
||
|
||
# 掩码多头注意力
|
||
masked_mha_box = Rectangle(
|
||
width=4.0, height=1.2,
|
||
color=YELLOW,
|
||
fill_opacity=0.3
|
||
)
|
||
masked_mha_box.move_to(decoder_layer.get_top() + DOWN * 1.5)
|
||
masked_mha_text = Text(
|
||
"掩码多头注意力\n(Masked Multi-Head Attention)",
|
||
font_size=16,
|
||
font="SimHei",
|
||
color=WHITE
|
||
)
|
||
masked_mha_text.move_to(masked_mha_box)
|
||
|
||
# 编码器-解码器注意力
|
||
enc_dec_box = Rectangle(
|
||
width=4.0, height=1.2,
|
||
color=PURPLE,
|
||
fill_opacity=0.3
|
||
)
|
||
enc_dec_box.move_to(masked_mha_box.get_bottom() + DOWN * 1.5)
|
||
enc_dec_text = Text(
|
||
"编码器-解码器注意力\n(Encoder-Decoder Attention)",
|
||
font_size=16,
|
||
font="SimHei",
|
||
color=WHITE
|
||
)
|
||
enc_dec_text.move_to(enc_dec_box)
|
||
|
||
# 前馈网络
|
||
ffn_box = Rectangle(
|
||
width=4.0, height=1.2,
|
||
color=GREEN,
|
||
fill_opacity=0.3
|
||
)
|
||
ffn_box.move_to(enc_dec_box.get_bottom() + DOWN * 1.5)
|
||
ffn_text = Text(
|
||
"前馈神经网络\n(Feed Forward Network)",
|
||
font_size=16,
|
||
font="SimHei",
|
||
color=WHITE
|
||
)
|
||
ffn_text.move_to(ffn_box)
|
||
|
||
# 层归一化(三个)
|
||
ln_positions = [
|
||
masked_mha_box.get_top() + DOWN * 0.2,
|
||
enc_dec_box.get_top() + DOWN * 0.2,
|
||
ffn_box.get_top() + DOWN * 0.2
|
||
]
|
||
|
||
ln_boxes = []
|
||
for i, pos in enumerate(ln_positions):
|
||
ln_box = Rectangle(
|
||
width=1.0, height=0.5,
|
||
color=BLUE,
|
||
fill_opacity=0.3
|
||
)
|
||
ln_box.move_to(pos + LEFT * 2.5)
|
||
ln_text = Text(
|
||
"LN",
|
||
font_size=12,
|
||
font="SimHei",
|
||
color=WHITE
|
||
)
|
||
ln_text.move_to(ln_box)
|
||
ln_boxes.append(VGroup(ln_box, ln_text))
|
||
|
||
# 残差连接
|
||
residual_arrows = []
|
||
for i, ln_box in enumerate(ln_boxes):
|
||
arrow = Arrow(
|
||
ln_box[0].get_left() + LEFT * 0.3,
|
||
[masked_mha_box, enc_dec_box, ffn_box][i].get_right() + RIGHT * 0.3,
|
||
color=RED,
|
||
buff=0.1
|
||
)
|
||
residual_arrows.append(arrow)
|
||
|
||
# 输出箭头
|
||
output_arrow = Arrow(
|
||
decoder_layer.get_bottom(),
|
||
decoder_layer.get_bottom() + DOWN * 0.5,
|
||
color=GREEN
|
||
)
|
||
output_label = Text(
|
||
"输出",
|
||
font_size=18,
|
||
font="SimHei",
|
||
color=GREEN
|
||
)
|
||
output_label.next_to(output_arrow, DOWN, buff=0.1)
|
||
|
||
# 连接箭头
|
||
arrows = []
|
||
arrow1 = Arrow(
|
||
input_arrow.get_end(),
|
||
ln_boxes[0][0].get_top(),
|
||
buff=0.1,
|
||
color=WHITE
|
||
)
|
||
|
||
arrow2 = Arrow(
|
||
masked_mha_box.get_bottom(),
|
||
ln_boxes[1][0].get_top(),
|
||
buff=0.1,
|
||
color=WHITE
|
||
)
|
||
|
||
arrow3 = Arrow(
|
||
enc_dec_box.get_bottom(),
|
||
ln_boxes[2][0].get_top(),
|
||
buff=0.1,
|
||
color=WHITE
|
||
)
|
||
|
||
arrow4 = Arrow(
|
||
ffn_box.get_bottom(),
|
||
output_arrow.get_start(),
|
||
buff=0.1,
|
||
color=WHITE
|
||
)
|
||
|
||
# 显示所有组件
|
||
components = [
|
||
decoder_layer, input_arrow, input_label, arrow1,
|
||
masked_mha_box, masked_mha_text,
|
||
enc_dec_box, enc_dec_text, ffn_box, ffn_text,
|
||
output_arrow, output_label
|
||
]
|
||
|
||
for ln in ln_boxes:
|
||
components.append(ln)
|
||
|
||
for arrow in [arrow2, arrow3, arrow4] + residual_arrows:
|
||
components.append(arrow)
|
||
|
||
for comp in components:
|
||
self.play(Create(comp) if not isinstance(comp, Text) and not isinstance(comp, VGroup) else Write(comp),
|
||
run_time=0.2)
|
||
|
||
# 添加说明
|
||
explanation = Text(
|
||
"解码器层 = 掩码注意力 + 编码器-解码器注意力 + 前馈网络",
|
||
font_size=20,
|
||
font="SimHei",
|
||
color=YELLOW
|
||
)
|
||
explanation.to_edge(DOWN, buff=1.0)
|
||
|
||
self.play(Write(explanation), run_time=1.0)
|
||
self.wait(2.0)
|
||
|
||
self.decoder_title = decoder_title
|
||
self.decoder_components = components + [explanation]
|
||
|
||
    def show_attention_mechanism(self):
        """Show the scaled dot-product and multi-head attention formulas."""
        # Clear the decoder diagram from the previous section.
        self.play(
            FadeOut(self.decoder_title),
            *[FadeOut(c) for c in self.decoder_components],
            run_time=0.8
        )

        # Section title.
        attn_title = Text(
            "注意力机制 (Attention)",
            font_size=36,
            font="SimHei",
            color=YELLOW
        )
        attn_title.to_edge(UP, buff=0.5)
        self.play(Write(attn_title), run_time=0.8)

        # Scaled dot-product attention formula.
        formula = MathTex(
            r"\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V",
            font_size=32
        )
        formula.shift(UP * 1.0)

        self.play(Write(formula), run_time=1.5)
        self.wait(1.0)

        # Multi-head attention: concatenation of per-head outputs.
        multi_head_formula = MathTex(
            r"\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O",
            font_size=28
        )
        multi_head_formula.shift(DOWN * 0.5)

        # Definition of a single attention head.
        head_formula = MathTex(
            r"\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)",
            font_size=24
        )
        head_formula.shift(DOWN * 1.5)

        self.play(Write(multi_head_formula), run_time=1.0)
        self.play(Write(head_formula), run_time=1.0)

        # Caption at the bottom.
        explanation = Text(
            "注意力允许模型关注输入序列中的所有位置",
            font_size=22,
            font="SimHei",
            color=BLUE
        )
        explanation.to_edge(DOWN, buff=1.0)

        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Keep references so show_attention_formula() can fade everything out.
        self.attention_components = [attn_title, formula, multi_head_formula, head_formula, explanation]
    def show_attention_formula(self):
        """Walk through the scaled dot-product attention formula step by step."""
        # Clear the previous attention section.
        self.play(
            *[FadeOut(c) for c in self.attention_components],
            run_time=0.8
        )

        # Section title.
        formula_title = Text(
            "缩放点积注意力详解",
            font_size=36,
            font="SimHei",
            color=PURPLE
        )
        formula_title.to_edge(UP, buff=0.5)
        self.play(Write(formula_title), run_time=0.8)

        # Step-by-step formulas; kept in plain math (no CJK in LaTeX) to
        # avoid LaTeX compilation errors.
        step1_formula = MathTex(
            r"1.\ QK^T",
            font_size=26
        )
        step1_formula.shift(UP * 2.5)

        step1_text = Text(
            "(计算相似度)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step1_text.next_to(step1_formula, DOWN, buff=0.1)

        step2_formula = MathTex(
            r"2.\ \frac{QK^T}{\sqrt{d_k}}",
            font_size=26
        )
        step2_formula.shift(UP * 1.0)

        step2_text = Text(
            "(缩放,稳定梯度)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step2_text.next_to(step2_formula, DOWN, buff=0.1)

        step3_formula = MathTex(
            r"3.\ \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)",
            font_size=26
        )
        step3_formula.shift(DOWN * 0.5)

        step3_text = Text(
            "(归一化为概率)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step3_text.next_to(step3_formula, DOWN, buff=0.1)

        step4_formula = MathTex(
            r"4.\ \text{softmax}(\cdots)V",
            font_size=26
        )
        step4_formula.shift(DOWN * 2.0)

        step4_text = Text(
            "(加权求和)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step4_text.next_to(step4_formula, DOWN, buff=0.1)

        # Reveal each step's formula, then its caption.
        formula_steps = [step1_formula, step2_formula, step3_formula, step4_formula]
        text_steps = [step1_text, step2_text, step3_text, step4_text]

        for i in range(4):
            self.play(Write(formula_steps[i]), run_time=0.8)
            self.play(Write(text_steps[i]), run_time=0.5)
            self.wait(0.3)

        # Closing slide: attention optimizations used in AstrAI.
        # NOTE(review): this title sits at DOWN * 2.5, overlapping step 4
        # (DOWN * 2.0), and the bullet list below extends further down —
        # likely past the bottom of the default frame. Verify the render.
        astrai_title = Text(
            "AstrAI中的注意力优化",
            font_size=28,
            font="SimHei",
            color=GREEN
        )
        astrai_title.shift(DOWN * 2.5)

        optimizations = [
            "• Flash Attention:内存高效实现",
            "• KV Cache:避免重复计算",
            "• 连续批处理:动态请求调度",
            "• 前缀缓存:Radix Tree管理",
        ]

        opt_texts = []
        for i, opt in enumerate(optimizations):
            text = Text(
                opt,
                font_size=20,
                font="SimHei",
                color=YELLOW
            )
            # Each bullet sits a fixed extra 0.4 units below the previous one.
            text.next_to(astrai_title, DOWN, buff=0.3 + i * 0.4)
            text.align_to(astrai_title, LEFT)
            opt_texts.append(text)

        self.play(Write(astrai_title), run_time=0.8)
        for text in opt_texts:
            self.play(Write(text), run_time=0.5)

        self.wait(3.0)
# Render command:
# python -m manim transformer_visualization.py TransformerArchitecture -pqh
if __name__ == "__main__":
    # Quick local test render (normally invoked through the manim CLI).
    TransformerArchitecture().render()