"""
Transformer architecture visualization.

Renders a detailed structural walkthrough of the Transformer with Manim:
the overall encoder/decoder stack, the internals of a single encoder and
decoder layer, the attention mechanism formulas, and a step-by-step
breakdown of scaled dot-product attention.

Render command:
    python -m manim transformer_visualization.py TransformerArchitecture -pqh
"""
from manim import *
import numpy as np


class TransformerArchitecture(Scene):
    """Animated walkthrough of the complete Transformer architecture."""

    def construct(self):
        # The scene is a linear sequence of sections; each section fades
        # out the previous one via references stashed on `self`.
        self.show_title()
        self.show_overall_architecture()
        self.show_encoder_details()
        self.show_decoder_details()
        self.show_attention_mechanism()
        self.show_attention_formula()

    def show_title(self):
        """Show the opening title and subtitle."""
        title = Text(
            "Transformer 架构详解",
            font_size=48,
            font="SimHei",
            color=WHITE,
        )
        title.to_edge(UP, buff=0.5)

        subtitle = Text(
            "现代大语言模型的核心基础",
            font_size=24,
            font="SimHei",
            color=GRAY,
        )
        subtitle.next_to(title, DOWN, buff=0.3)

        self.play(Write(title), run_time=1.0)
        self.play(Write(subtitle), run_time=0.8)
        self.wait(1.0)

        # Kept so the next section can fade them out.
        self.title = title
        self.subtitle = subtitle

    def show_overall_architecture(self):
        """Show the high-level diagram: input -> encoder stack -> decoder stack -> output."""
        # Clear the title section.
        self.play(
            FadeOut(self.title),
            FadeOut(self.subtitle),
            run_time=0.5,
        )

        # Section heading.
        arch_title = Text(
            "Transformer 整体架构",
            font_size=36,
            font="SimHei",
            color=BLUE,
        )
        arch_title.to_edge(UP, buff=0.5)
        self.play(Write(arch_title), run_time=0.8)

        # Input block.
        input_box = Rectangle(width=3.0, height=1.0, color=GREEN, fill_opacity=0.3)
        input_box.to_edge(LEFT, buff=2.0).shift(UP * 0.5)
        input_text = Text("输入序列", font_size=20, font="SimHei", color=WHITE)
        input_text.move_to(input_box)

        # Encoder stack (6 layers, stacked vertically).
        encoder_stack = VGroup()
        for i in range(6):
            encoder = Rectangle(width=3.0, height=0.5, color=BLUE_C, fill_opacity=0.2)
            encoder.shift(RIGHT * 2.5 + UP * (1.5 - i * 0.55))
            encoder_label = Text(f"编码器层 {i+1}", font_size=14, font="SimHei")
            encoder_label.move_to(encoder)
            encoder_stack.add(VGroup(encoder, encoder_label))

        # Decoder stack (6 layers, stacked vertically).
        decoder_stack = VGroup()
        for i in range(6):
            decoder = Rectangle(width=3.0, height=0.5, color=ORANGE, fill_opacity=0.2)
            decoder.shift(RIGHT * 6.5 + UP * (1.5 - i * 0.55))
            decoder_label = Text(f"解码器层 {i+1}", font_size=14, font="SimHei")
            decoder_label.move_to(decoder)
            decoder_stack.add(VGroup(decoder, decoder_label))

        # Output block.
        output_box = Rectangle(width=3.0, height=1.0, color=RED, fill_opacity=0.3)
        output_box.to_edge(RIGHT, buff=2.0).shift(UP * 0.5)
        output_text = Text("输出序列", font_size=20, font="SimHei", color=WHITE)
        output_text.move_to(output_box)

        # Data-flow arrows between the stages.
        arrow1 = Arrow(
            input_box.get_right(), encoder_stack.get_left(), buff=0.2, color=YELLOW
        )
        arrow2 = Arrow(
            encoder_stack.get_right(), decoder_stack.get_left(), buff=0.2, color=YELLOW
        )
        arrow3 = Arrow(
            decoder_stack.get_right(), output_box.get_left(), buff=0.2, color=YELLOW
        )

        # Animate: input first.
        self.play(Create(input_box), Write(input_text), run_time=0.8)
        self.wait(0.3)

        # Reveal encoder layers one by one; draw the input arrow with the first.
        for i, encoder in enumerate(encoder_stack):
            self.play(Create(encoder), run_time=0.2)
            if i == 0:
                self.play(Create(arrow1), run_time=0.5)
        self.wait(0.3)

        # Reveal decoder layers one by one; draw the cross arrow with the first.
        for i, decoder in enumerate(decoder_stack):
            self.play(Create(decoder), run_time=0.2)
            if i == 0:
                self.play(Create(arrow2), run_time=0.5)
        self.wait(0.3)

        # Output block and final arrow.
        self.play(
            Create(output_box),
            Write(output_text),
            Create(arrow3),
            run_time=0.8,
        )

        # Caption.
        explanation = Text(
            "Transformer = 编码器(N层) + 解码器(N层)",
            font_size=22,
            font="SimHei",
            color=YELLOW,
        )
        explanation.to_edge(DOWN, buff=1.0)
        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Kept so the next section can fade everything out.
        self.arch_title = arch_title
        self.arch_components = [
            input_box, input_text, encoder_stack, decoder_stack,
            output_box, output_text, arrow1, arrow2, arrow3, explanation,
        ]

    def show_encoder_details(self):
        """Show the internal structure of a single encoder layer."""
        # Clear the overall-architecture section.
        self.play(
            FadeOut(self.arch_title),
            *[FadeOut(c) for c in self.arch_components],
            run_time=0.8,
        )

        # Section heading.
        encoder_title = Text(
            "编码器层内部结构",
            font_size=36,
            font="SimHei",
            color=BLUE,
        )
        encoder_title.to_edge(UP, buff=0.5)
        self.play(Write(encoder_title), run_time=0.8)

        # Outer frame for the encoder layer.
        encoder_layer = Rectangle(width=5.0, height=6.0, color=BLUE, fill_opacity=0.1)
        encoder_layer.center()

        # Input arrow entering the layer from above.
        input_arrow = Arrow(
            encoder_layer.get_top() + UP * 0.5,
            encoder_layer.get_top(),
            color=GREEN,
        )
        input_label = Text("输入", font_size=18, font="SimHei", color=GREEN)
        input_label.next_to(input_arrow, UP, buff=0.1)

        # LayerNorm #1.
        ln1_box = Rectangle(width=4.0, height=0.8, color=PURPLE, fill_opacity=0.3)
        ln1_box.move_to(encoder_layer.get_top() + DOWN * 1.0)
        ln1_text = Text("层归一化 (LayerNorm)", font_size=16, font="SimHei", color=WHITE)
        ln1_text.move_to(ln1_box)

        # Multi-head self-attention.
        mha_box = Rectangle(width=4.0, height=1.2, color=YELLOW, fill_opacity=0.3)
        mha_box.move_to(ln1_box.get_bottom() + DOWN * 1.0)
        mha_text = Text(
            "多头自注意力\n(Multi-Head Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE,
        )
        mha_text.move_to(mha_box)

        # Residual connection #1: from the attention input to its output,
        # drawn as a vertical arrow to the right of the boxes. Using a
        # shared x coordinate keeps the arrow perfectly vertical.
        right_side_x = ln1_box.get_right()[0] + 1.5
        residual1_start = np.array([right_side_x, ln1_box.get_top()[1] + 0.1, 0])
        residual1_end = np.array([right_side_x, mha_box.get_bottom()[1] - 0.1, 0])
        residual1 = Arrow(residual1_start, residual1_end, color=RED, buff=0.1)
        residual1_label = Text("残差连接", font_size=14, font="SimHei", color=RED)
        residual1_label.next_to(residual1, RIGHT, buff=0.1)

        # LayerNorm #2.
        ln2_box = Rectangle(width=4.0, height=0.8, color=PURPLE, fill_opacity=0.3)
        ln2_box.move_to(mha_box.get_bottom() + DOWN * 1.5)
        ln2_text = Text("层归一化 (LayerNorm)", font_size=16, font="SimHei", color=WHITE)
        ln2_text.move_to(ln2_box)

        # Feed-forward network.
        ffn_box = Rectangle(width=4.0, height=1.2, color=GREEN, fill_opacity=0.3)
        ffn_box.move_to(ln2_box.get_bottom() + DOWN * 1.0)
        ffn_text = Text(
            "前馈神经网络\n(Feed Forward Network)",
            font_size=16,
            font="SimHei",
            color=WHITE,
        )
        ffn_text.move_to(ffn_box)

        # Residual connection #2: from the FFN input to its output,
        # same vertical-arrow construction as residual #1.
        right_side_x2 = ln2_box.get_right()[0] + 1.5
        residual2_start = np.array([right_side_x2, ln2_box.get_top()[1] + 0.1, 0])
        residual2_end = np.array([right_side_x2, ffn_box.get_bottom()[1] - 0.1, 0])
        residual2 = Arrow(residual2_start, residual2_end, color=RED, buff=0.1)
        residual2_label = Text("残差连接", font_size=14, font="SimHei", color=RED)
        residual2_label.next_to(residual2, RIGHT, buff=0.1)

        # Output arrow leaving the layer below.
        output_arrow = Arrow(
            encoder_layer.get_bottom(),
            encoder_layer.get_bottom() + DOWN * 0.5,
            color=GREEN,
        )
        output_label = Text("输出", font_size=18, font="SimHei", color=GREEN)
        output_label.next_to(output_arrow, DOWN, buff=0.1)

        # Arrows connecting the internal sub-blocks.
        arrow_ln1_mha = Arrow(ln1_box.get_bottom(), mha_box.get_top(), buff=0.1, color=WHITE)
        arrow_mha_ln2 = Arrow(mha_box.get_bottom(), ln2_box.get_top(), buff=0.1, color=WHITE)
        arrow_ln2_ffn = Arrow(ln2_box.get_bottom(), ffn_box.get_top(), buff=0.1, color=WHITE)

        # Animate every component in drawing order.
        components = [
            encoder_layer, input_arrow, input_label,
            ln1_box, ln1_text, arrow_ln1_mha,
            mha_box, mha_text, residual1, residual1_label,
            arrow_mha_ln2, ln2_box, ln2_text,
            arrow_ln2_ffn, ffn_box, ffn_text,
            residual2, residual2_label,
            output_arrow, output_label,
        ]
        for comp in components:
            # Text is written stroke-by-stroke; shapes are drawn.
            self.play(
                Create(comp) if not isinstance(comp, Text) else Write(comp),
                run_time=0.3,
            )

        # Caption.
        explanation = Text(
            "编码器层 = 层归一化 + 多头注意力 + 前馈网络(均有残差连接)",
            font_size=20,
            font="SimHei",
            color=YELLOW,
        )
        explanation.to_edge(DOWN, buff=1.0)
        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Kept so the next section can fade everything out.
        self.encoder_title = encoder_title
        self.encoder_components = components + [explanation]

    def show_decoder_details(self):
        """Show the internal structure of a single decoder layer."""
        # Clear the encoder section.
        self.play(
            FadeOut(self.encoder_title),
            *[FadeOut(c) for c in self.encoder_components],
            run_time=0.8,
        )

        # Section heading.
        decoder_title = Text(
            "解码器层内部结构",
            font_size=36,
            font="SimHei",
            color=ORANGE,
        )
        decoder_title.to_edge(UP, buff=0.5)
        self.play(Write(decoder_title), run_time=0.8)

        # Outer frame for the decoder layer.
        decoder_layer = Rectangle(width=6.0, height=8.0, color=ORANGE, fill_opacity=0.1)
        decoder_layer.center().shift(UP * 0.5)

        # Input arrow entering the layer from above.
        input_arrow = Arrow(
            decoder_layer.get_top() + UP * 0.5,
            decoder_layer.get_top(),
            color=GREEN,
        )
        input_label = Text("输入", font_size=18, font="SimHei", color=GREEN)
        input_label.next_to(input_arrow, UP, buff=0.1)

        # Masked multi-head attention.
        masked_mha_box = Rectangle(width=4.0, height=1.2, color=YELLOW, fill_opacity=0.3)
        masked_mha_box.move_to(decoder_layer.get_top() + DOWN * 1.5)
        masked_mha_text = Text(
            "掩码多头注意力\n(Masked Multi-Head Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE,
        )
        masked_mha_text.move_to(masked_mha_box)

        # Encoder-decoder (cross) attention.
        enc_dec_box = Rectangle(width=4.0, height=1.2, color=PURPLE, fill_opacity=0.3)
        enc_dec_box.move_to(masked_mha_box.get_bottom() + DOWN * 1.5)
        enc_dec_text = Text(
            "编码器-解码器注意力\n(Encoder-Decoder Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE,
        )
        enc_dec_text.move_to(enc_dec_box)

        # Feed-forward network.
        ffn_box = Rectangle(width=4.0, height=1.2, color=GREEN, fill_opacity=0.3)
        ffn_box.move_to(enc_dec_box.get_bottom() + DOWN * 1.5)
        ffn_text = Text(
            "前馈神经网络\n(Feed Forward Network)",
            font_size=16,
            font="SimHei",
            color=WHITE,
        )
        ffn_text.move_to(ffn_box)

        # Three small LayerNorm boxes, one left of each sub-block.
        ln_positions = [
            masked_mha_box.get_top() + DOWN * 0.2,
            enc_dec_box.get_top() + DOWN * 0.2,
            ffn_box.get_top() + DOWN * 0.2,
        ]
        ln_boxes = []
        for i, pos in enumerate(ln_positions):
            ln_box = Rectangle(width=1.0, height=0.5, color=BLUE, fill_opacity=0.3)
            ln_box.move_to(pos + LEFT * 2.5)
            ln_text = Text("LN", font_size=12, font="SimHei", color=WHITE)
            ln_text.move_to(ln_box)
            ln_boxes.append(VGroup(ln_box, ln_text))

        # Residual connections: one red arrow per sub-block, from the LN
        # box on the left to the corresponding block on the right.
        residual_arrows = []
        for i, ln_box in enumerate(ln_boxes):
            arrow = Arrow(
                ln_box[0].get_left() + LEFT * 0.3,
                [masked_mha_box, enc_dec_box, ffn_box][i].get_right() + RIGHT * 0.3,
                color=RED,
                buff=0.1,
            )
            residual_arrows.append(arrow)

        # Output arrow leaving the layer below.
        output_arrow = Arrow(
            decoder_layer.get_bottom(),
            decoder_layer.get_bottom() + DOWN * 0.5,
            color=GREEN,
        )
        output_label = Text("输出", font_size=18, font="SimHei", color=GREEN)
        output_label.next_to(output_arrow, DOWN, buff=0.1)

        # Arrows connecting the internal sub-blocks.
        arrow1 = Arrow(
            input_arrow.get_end(), ln_boxes[0][0].get_top(), buff=0.1, color=WHITE
        )
        arrow2 = Arrow(
            masked_mha_box.get_bottom(), ln_boxes[1][0].get_top(), buff=0.1, color=WHITE
        )
        arrow3 = Arrow(
            enc_dec_box.get_bottom(), ln_boxes[2][0].get_top(), buff=0.1, color=WHITE
        )
        arrow4 = Arrow(
            ffn_box.get_bottom(), output_arrow.get_start(), buff=0.1, color=WHITE
        )

        # Animate every component in drawing order.
        components = [
            decoder_layer, input_arrow, input_label, arrow1,
            masked_mha_box, masked_mha_text,
            enc_dec_box, enc_dec_text,
            ffn_box, ffn_text,
            output_arrow, output_label,
        ]
        for ln in ln_boxes:
            components.append(ln)
        for arrow in [arrow2, arrow3, arrow4] + residual_arrows:
            components.append(arrow)

        for comp in components:
            # Text and LN groups (shape + text) are written; bare shapes are drawn.
            self.play(
                Create(comp)
                if not isinstance(comp, Text) and not isinstance(comp, VGroup)
                else Write(comp),
                run_time=0.2,
            )

        # Caption.
        explanation = Text(
            "解码器层 = 掩码注意力 + 编码器-解码器注意力 + 前馈网络",
            font_size=20,
            font="SimHei",
            color=YELLOW,
        )
        explanation.to_edge(DOWN, buff=1.0)
        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Kept so the next section can fade everything out.
        self.decoder_title = decoder_title
        self.decoder_components = components + [explanation]

    def show_attention_mechanism(self):
        """Show the attention and multi-head attention formulas."""
        # Clear the decoder section.
        self.play(
            FadeOut(self.decoder_title),
            *[FadeOut(c) for c in self.decoder_components],
            run_time=0.8,
        )

        # Section heading.
        attn_title = Text(
            "注意力机制 (Attention)",
            font_size=36,
            font="SimHei",
            color=YELLOW,
        )
        attn_title.to_edge(UP, buff=0.5)
        self.play(Write(attn_title), run_time=0.8)

        # Scaled dot-product attention formula.
        formula = MathTex(
            r"\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V",
            font_size=32,
        )
        formula.shift(UP * 1.0)
        self.play(Write(formula), run_time=1.5)
        self.wait(1.0)

        # Multi-head attention formulas.
        multi_head_formula = MathTex(
            r"\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O",
            font_size=28,
        )
        multi_head_formula.shift(DOWN * 0.5)

        head_formula = MathTex(
            r"\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)",
            font_size=24,
        )
        head_formula.shift(DOWN * 1.5)

        self.play(Write(multi_head_formula), run_time=1.0)
        self.play(Write(head_formula), run_time=1.0)

        # Caption.
        explanation = Text(
            "注意力允许模型关注输入序列中的所有位置",
            font_size=22,
            font="SimHei",
            color=BLUE,
        )
        explanation.to_edge(DOWN, buff=1.0)
        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Kept so the next section can fade everything out.
        self.attention_components = [
            attn_title, formula, multi_head_formula, head_formula, explanation,
        ]

    def show_attention_formula(self):
        """Show a step-by-step breakdown of scaled dot-product attention."""
        # Clear the previous attention section.
        self.play(
            *[FadeOut(c) for c in self.attention_components],
            run_time=0.8,
        )

        # Section heading.
        formula_title = Text(
            "缩放点积注意力详解",
            font_size=36,
            font="SimHei",
            color=PURPLE,
        )
        formula_title.to_edge(UP, buff=0.5)
        self.play(Write(formula_title), run_time=0.8)

        # Step-by-step formulas; the math stays in English/TeX to avoid
        # LaTeX compilation issues with CJK text.
        step1_formula = MathTex(r"1.\ QK^T", font_size=26)
        step1_formula.shift(UP * 2.5)
        step1_text = Text("(计算相似度)", font_size=20, font="SimHei", color=GRAY)
        step1_text.next_to(step1_formula, DOWN, buff=0.1)

        step2_formula = MathTex(r"2.\ \frac{QK^T}{\sqrt{d_k}}", font_size=26)
        step2_formula.shift(UP * 1.0)
        step2_text = Text("(缩放,稳定梯度)", font_size=20, font="SimHei", color=GRAY)
        step2_text.next_to(step2_formula, DOWN, buff=0.1)

        step3_formula = MathTex(
            r"3.\ \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)",
            font_size=26,
        )
        step3_formula.shift(DOWN * 0.5)
        step3_text = Text("(归一化为概率)", font_size=20, font="SimHei", color=GRAY)
        step3_text.next_to(step3_formula, DOWN, buff=0.1)

        step4_formula = MathTex(r"4.\ \text{softmax}(\cdots)V", font_size=26)
        step4_formula.shift(DOWN * 2.0)
        step4_text = Text("(加权求和)", font_size=20, font="SimHei", color=GRAY)
        step4_text.next_to(step4_formula, DOWN, buff=0.1)

        formula_steps = [step1_formula, step2_formula, step3_formula, step4_formula]
        text_steps = [step1_text, step2_text, step3_text, step4_text]
        for i in range(4):
            self.play(Write(formula_steps[i]), run_time=0.8)
            self.play(Write(text_steps[i]), run_time=0.5)
            self.wait(0.3)

        # Attention optimizations used in AstrAI.
        astrai_title = Text(
            "AstrAI中的注意力优化",
            font_size=28,
            font="SimHei",
            color=GREEN,
        )
        astrai_title.shift(DOWN * 2.5)

        optimizations = [
            "• Flash Attention:内存高效实现",
            "• KV Cache:避免重复计算",
            "• 连续批处理:动态请求调度",
            "• 前缀缓存:Radix Tree管理",
        ]
        opt_texts = []
        for i, opt in enumerate(optimizations):
            text = Text(opt, font_size=20, font="SimHei", color=YELLOW)
            text.next_to(astrai_title, DOWN, buff=0.3 + i * 0.4)
            text.align_to(astrai_title, LEFT)
            opt_texts.append(text)

        self.play(Write(astrai_title), run_time=0.8)
        for text in opt_texts:
            self.play(Write(text), run_time=0.5)

        self.wait(3.0)


# Render command:
# python -m manim transformer_visualization.py TransformerArchitecture -pqh
if __name__ == "__main__":
    # Quick local test render.
    scene = TransformerArchitecture()
    scene.render()