"""Transformer architecture visualization.

Draws a step-by-step diagram of the Transformer architecture using Manim:
the overall encoder/decoder stack, the internals of one encoder layer and
one decoder layer, and the (multi-head) attention formulas.

All on-screen text is Chinese and is rendered with the "SimHei" font.
NOTE(review): SimHei must be installed on the rendering machine or Pango
will silently fall back to another font — confirm on the render host.

Render with:
    python -m manim transformer_visualization.py TransformerArchitecture -pqh
"""
from manim import *
import numpy as np


class TransformerArchitecture(Scene):
    """Full Transformer architecture walkthrough, played as six sequential sections."""

    def construct(self) -> None:
        """Scene entry point: play the six sections in order.

        Each ``show_*`` method stores references to the mobjects it created
        on ``self`` so that the *next* section can fade them out before
        drawing its own content.
        """
        self.show_title()
        self.show_overall_architecture()
        self.show_encoder_details()
        self.show_decoder_details()
        self.show_attention_mechanism()
        self.show_attention_formula()

    def show_title(self) -> None:
        """Section 1: show the main title and subtitle, then keep them on screen.

        Saves ``self.title`` / ``self.subtitle`` for the next section to fade out.
        """
        title = Text(
            "Transformer 架构详解",
            font_size=48,
            font="SimHei",
            color=WHITE
        )
        title.to_edge(UP, buff=0.5)

        subtitle = Text(
            "现代大语言模型的核心基础",
            font_size=24,
            font="SimHei",
            color=GRAY
        )
        subtitle.next_to(title, DOWN, buff=0.3)

        self.play(Write(title), run_time=1.0)
        self.play(Write(subtitle), run_time=0.8)
        self.wait(1.0)

        # Kept on self so show_overall_architecture() can fade them out.
        self.title = title
        self.subtitle = subtitle

    def show_overall_architecture(self) -> None:
        """Section 2: overall diagram — input -> 6 encoder layers -> 6 decoder layers -> output."""
        # Clear the title section.
        self.play(
            FadeOut(self.title),
            FadeOut(self.subtitle),
            run_time=0.5
        )

        # Section heading.
        arch_title = Text(
            "Transformer 整体架构",
            font_size=36,
            font="SimHei",
            color=BLUE
        )
        arch_title.to_edge(UP, buff=0.5)
        self.play(Write(arch_title), run_time=0.8)

        # --- Main components of the diagram ---
        # Input box, pinned near the left edge.
        input_box = Rectangle(
            width=3.0, height=1.0,
            color=GREEN,
            fill_opacity=0.3
        )
        input_box.to_edge(LEFT, buff=2.0).shift(UP * 0.5)

        input_text = Text(
            "输入序列",
            font_size=20,
            font="SimHei",
            color=WHITE
        )
        input_text.move_to(input_box)

        # Encoder stack: 6 thin rectangles stacked vertically at x = 2.5.
        encoder_stack = VGroup()
        for i in range(6):
            encoder = Rectangle(
                width=3.0, height=0.5,
                color=BLUE_C,
                fill_opacity=0.2
            )
            # 0.55 vertical step > 0.5 box height leaves a small gap between layers.
            encoder.shift(RIGHT * 2.5 + UP * (1.5 - i * 0.55))
            encoder_label = Text(
                f"编码器层 {i+1}",
                font_size=14,
                font="SimHei"
            )
            encoder_label.move_to(encoder)
            encoder_group = VGroup(encoder, encoder_label)
            encoder_stack.add(encoder_group)

        # Decoder stack: mirrors the encoder stack at x = 6.5.
        # NOTE(review): with a default frame half-width of ~7.1, a 3.0-wide box
        # centered at x = 6.5 spans to x = 8.0 and likely falls off-frame —
        # verify against the rendered output / config.frame_width.
        decoder_stack = VGroup()
        for i in range(6):
            decoder = Rectangle(
                width=3.0, height=0.5,
                color=ORANGE,
                fill_opacity=0.2
            )
            decoder.shift(RIGHT * 6.5 + UP * (1.5 - i * 0.55))
            decoder_label = Text(
                f"解码器层 {i+1}",
                font_size=14,
                font="SimHei"
            )
            decoder_label.move_to(decoder)
            decoder_group = VGroup(decoder, decoder_label)
            decoder_stack.add(decoder_group)

        # Output box, pinned near the right edge.
        # NOTE(review): this may overlap the decoder stack given the x = 6.5
        # placement above — confirm visually.
        output_box = Rectangle(
            width=3.0, height=1.0,
            color=RED,
            fill_opacity=0.3
        )
        output_box.to_edge(RIGHT, buff=2.0).shift(UP * 0.5)

        output_text = Text(
            "输出序列",
            font_size=20,
            font="SimHei",
            color=WHITE
        )
        output_text.move_to(output_box)

        # Flow arrows: input -> encoders -> decoders -> output.
        arrow1 = Arrow(
            input_box.get_right(),
            encoder_stack.get_left(),
            buff=0.2,
            color=YELLOW
        )

        arrow2 = Arrow(
            encoder_stack.get_right(),
            decoder_stack.get_left(),
            buff=0.2,
            color=YELLOW
        )

        arrow3 = Arrow(
            decoder_stack.get_right(),
            output_box.get_left(),
            buff=0.2,
            color=YELLOW
        )

        # --- Animate everything in reading order ---
        self.play(
            Create(input_box),
            Write(input_text),
            run_time=0.8
        )
        self.wait(0.3)

        # Encoder layers appear one by one; the input arrow appears with the first.
        for i, encoder in enumerate(encoder_stack):
            self.play(
                Create(encoder),
                run_time=0.2
            )
            if i == 0:
                self.play(Create(arrow1), run_time=0.5)

        self.wait(0.3)

        # Decoder layers likewise; the encoder->decoder arrow appears with the first.
        for i, decoder in enumerate(decoder_stack):
            self.play(
                Create(decoder),
                run_time=0.2
            )
            if i == 0:
                self.play(Create(arrow2), run_time=0.5)

        self.wait(0.3)

        # Output box and final arrow.
        self.play(
            Create(output_box),
            Write(output_text),
            Create(arrow3),
            run_time=0.8
        )

        # Summary caption at the bottom.
        explanation = Text(
            "Transformer = 编码器(N层) + 解码器(N层)",
            font_size=22,
            font="SimHei",
            color=YELLOW
        )
        explanation.to_edge(DOWN, buff=1.0)

        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Kept on self so show_encoder_details() can fade them out.
        self.arch_title = arch_title
        self.arch_components = [input_box, input_text, encoder_stack, decoder_stack,
                                output_box, output_text, arrow1, arrow2, arrow3,
                                explanation]

    def show_encoder_details(self) -> None:
        """Section 3: internals of one encoder layer (LayerNorm -> MHA -> LayerNorm -> FFN,
        with residual connections drawn as red side arrows)."""
        # Clear the overall-architecture section.
        self.play(
            FadeOut(self.arch_title),
            *[FadeOut(c) for c in self.arch_components],
            run_time=0.8
        )

        # Section heading.
        encoder_title = Text(
            "编码器层内部结构",
            font_size=36,
            font="SimHei",
            color=BLUE
        )
        encoder_title.to_edge(UP, buff=0.5)
        self.play(Write(encoder_title), run_time=0.8)

        # Outer frame representing one encoder layer.
        encoder_layer = Rectangle(
            width=5.0, height=6.0,
            color=BLUE,
            fill_opacity=0.1
        )
        encoder_layer.center()

        # Input arrow entering the layer from above.
        input_arrow = Arrow(
            encoder_layer.get_top() + UP * 0.5,
            encoder_layer.get_top(),
            color=GREEN
        )
        input_label = Text(
            "输入",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        input_label.next_to(input_arrow, UP, buff=0.1)

        # First LayerNorm (pre-norm position, before attention).
        ln1_box = Rectangle(
            width=4.0, height=0.8,
            color=PURPLE,
            fill_opacity=0.3
        )
        ln1_box.move_to(encoder_layer.get_top() + DOWN * 1.0)
        ln1_text = Text(
            "层归一化 (LayerNorm)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ln1_text.move_to(ln1_box)

        # Multi-head self-attention block.
        mha_box = Rectangle(
            width=4.0, height=1.2,
            color=YELLOW,
            fill_opacity=0.3
        )
        mha_box.move_to(ln1_box.get_bottom() + DOWN * 1.0)
        mha_text = Text(
            "多头自注意力\n(Multi-Head Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        mha_text.move_to(mha_box)

        # Residual connection 1: vertical red arrow on the right side, spanning
        # from above LN1 down past the MHA output.  A single shared x coordinate
        # keeps the arrow exactly vertical.
        right_side_x = ln1_box.get_right()[0] + 1.5
        residual1_start = np.array([right_side_x, ln1_box.get_top()[1] + 0.1, 0])
        residual1_end = np.array([right_side_x, mha_box.get_bottom()[1] - 0.1, 0])
        residual1 = Arrow(
            residual1_start,
            residual1_end,
            color=RED,
            buff=0.1
        )
        residual1_label = Text(
            "残差连接",
            font_size=14,
            font="SimHei",
            color=RED
        )
        residual1_label.next_to(residual1, RIGHT, buff=0.1)

        # Second LayerNorm (before the feed-forward block).
        ln2_box = Rectangle(
            width=4.0, height=0.8,
            color=PURPLE,
            fill_opacity=0.3
        )
        ln2_box.move_to(mha_box.get_bottom() + DOWN * 1.5)
        ln2_text = Text(
            "层归一化 (LayerNorm)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ln2_text.move_to(ln2_box)

        # Feed-forward network block.
        ffn_box = Rectangle(
            width=4.0, height=1.2,
            color=GREEN,
            fill_opacity=0.3
        )
        ffn_box.move_to(ln2_box.get_bottom() + DOWN * 1.0)
        ffn_text = Text(
            "前馈神经网络\n(Feed Forward Network)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ffn_text.move_to(ffn_box)

        # Residual connection 2: same construction as residual 1, spanning
        # LN2 input down past the FFN output.
        right_side_x2 = ln2_box.get_right()[0] + 1.5
        residual2_start = np.array([right_side_x2, ln2_box.get_top()[1] + 0.1, 0])
        residual2_end = np.array([right_side_x2, ffn_box.get_bottom()[1] - 0.1, 0])
        residual2 = Arrow(
            residual2_start,
            residual2_end,
            color=RED,
            buff=0.1
        )
        residual2_label = Text(
            "残差连接",
            font_size=14,
            font="SimHei",
            color=RED
        )
        residual2_label.next_to(residual2, RIGHT, buff=0.1)

        # Output arrow leaving the layer at the bottom.
        output_arrow = Arrow(
            encoder_layer.get_bottom(),
            encoder_layer.get_bottom() + DOWN * 0.5,
            color=GREEN
        )
        output_label = Text(
            "输出",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        output_label.next_to(output_arrow, DOWN, buff=0.1)

        # Internal flow arrows between the sub-blocks.
        arrow_ln1_mha = Arrow(
            ln1_box.get_bottom(),
            mha_box.get_top(),
            buff=0.1,
            color=WHITE
        )

        arrow_mha_ln2 = Arrow(
            mha_box.get_bottom(),
            ln2_box.get_top(),
            buff=0.1,
            color=WHITE
        )

        arrow_ln2_ffn = Arrow(
            ln2_box.get_bottom(),
            ffn_box.get_top(),
            buff=0.1,
            color=WHITE
        )

        # Animate all components in top-to-bottom order.
        components = [
            encoder_layer, input_arrow, input_label,
            ln1_box, ln1_text, arrow_ln1_mha,
            mha_box, mha_text, residual1, residual1_label,
            arrow_mha_ln2, ln2_box, ln2_text,
            arrow_ln2_ffn, ffn_box, ffn_text,
            residual2, residual2_label, output_arrow, output_label
        ]

        # Text mobjects animate better with Write; everything else with Create.
        for comp in components:
            self.play(Create(comp) if not isinstance(comp, Text) else Write(comp),
                      run_time=0.3)

        # Summary caption at the bottom.
        explanation = Text(
            "编码器层 = 层归一化 + 多头注意力 + 前馈网络(均有残差连接)",
            font_size=20,
            font="SimHei",
            color=YELLOW
        )
        explanation.to_edge(DOWN, buff=1.0)

        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Kept on self so show_decoder_details() can fade them out.
        self.encoder_title = encoder_title
        self.encoder_components = components + [explanation]

    def show_decoder_details(self) -> None:
        """Section 4: internals of one decoder layer (masked MHA -> encoder-decoder
        attention -> FFN, each with a small "LN" box and a red residual arrow)."""
        # Clear the encoder section.
        self.play(
            FadeOut(self.encoder_title),
            *[FadeOut(c) for c in self.encoder_components],
            run_time=0.8
        )

        # Section heading.
        decoder_title = Text(
            "解码器层内部结构",
            font_size=36,
            font="SimHei",
            color=ORANGE
        )
        decoder_title.to_edge(UP, buff=0.5)
        self.play(Write(decoder_title), run_time=0.8)

        # Outer frame representing one decoder layer.
        decoder_layer = Rectangle(
            width=6.0, height=8.0,
            color=ORANGE,
            fill_opacity=0.1
        )
        decoder_layer.center().shift(UP * 0.5)

        # Input arrow entering the layer from above.
        input_arrow = Arrow(
            decoder_layer.get_top() + UP * 0.5,
            decoder_layer.get_top(),
            color=GREEN
        )
        input_label = Text(
            "输入",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        input_label.next_to(input_arrow, UP, buff=0.1)

        # Masked multi-head self-attention block.
        masked_mha_box = Rectangle(
            width=4.0, height=1.2,
            color=YELLOW,
            fill_opacity=0.3
        )
        masked_mha_box.move_to(decoder_layer.get_top() + DOWN * 1.5)
        masked_mha_text = Text(
            "掩码多头注意力\n(Masked Multi-Head Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        masked_mha_text.move_to(masked_mha_box)

        # Encoder-decoder (cross) attention block.
        enc_dec_box = Rectangle(
            width=4.0, height=1.2,
            color=PURPLE,
            fill_opacity=0.3
        )
        enc_dec_box.move_to(masked_mha_box.get_bottom() + DOWN * 1.5)
        enc_dec_text = Text(
            "编码器-解码器注意力\n(Encoder-Decoder Attention)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        enc_dec_text.move_to(enc_dec_box)

        # Feed-forward network block.
        ffn_box = Rectangle(
            width=4.0, height=1.2,
            color=GREEN,
            fill_opacity=0.3
        )
        ffn_box.move_to(enc_dec_box.get_bottom() + DOWN * 1.5)
        ffn_text = Text(
            "前馈神经网络\n(Feed Forward Network)",
            font_size=16,
            font="SimHei",
            color=WHITE
        )
        ffn_text.move_to(ffn_box)

        # One small "LN" box just above-left of each of the three sub-blocks.
        ln_positions = [
            masked_mha_box.get_top() + DOWN * 0.2,
            enc_dec_box.get_top() + DOWN * 0.2,
            ffn_box.get_top() + DOWN * 0.2
        ]

        ln_boxes = []
        for i, pos in enumerate(ln_positions):
            ln_box = Rectangle(
                width=1.0, height=0.5,
                color=BLUE,
                fill_opacity=0.3
            )
            ln_box.move_to(pos + LEFT * 2.5)
            ln_text = Text(
                "LN",
                font_size=12,
                font="SimHei",
                color=WHITE
            )
            ln_text.move_to(ln_box)
            # Each entry is VGroup(box, text) so index [0] below is the box.
            ln_boxes.append(VGroup(ln_box, ln_text))

        # Residual arrows: one per sub-block, drawn from the LN box's left side
        # to the corresponding block's right side.
        # NOTE(review): these arrows cross horizontally through the sub-blocks
        # rather than running alongside them — confirm this is the intended look.
        residual_arrows = []
        for i, ln_box in enumerate(ln_boxes):
            arrow = Arrow(
                ln_box[0].get_left() + LEFT * 0.3,
                [masked_mha_box, enc_dec_box, ffn_box][i].get_right() + RIGHT * 0.3,
                color=RED,
                buff=0.1
            )
            residual_arrows.append(arrow)

        # Output arrow leaving the layer at the bottom.
        output_arrow = Arrow(
            decoder_layer.get_bottom(),
            decoder_layer.get_bottom() + DOWN * 0.5,
            color=GREEN
        )
        output_label = Text(
            "输出",
            font_size=18,
            font="SimHei",
            color=GREEN
        )
        output_label.next_to(output_arrow, DOWN, buff=0.1)

        # Internal flow arrows.
        # NOTE(review): `arrows` is assigned but never used — arrow2..arrow4 and
        # residual_arrows are appended to `components` directly below.
        arrows = []
        arrow1 = Arrow(
            input_arrow.get_end(),
            ln_boxes[0][0].get_top(),
            buff=0.1,
            color=WHITE
        )

        arrow2 = Arrow(
            masked_mha_box.get_bottom(),
            ln_boxes[1][0].get_top(),
            buff=0.1,
            color=WHITE
        )

        arrow3 = Arrow(
            enc_dec_box.get_bottom(),
            ln_boxes[2][0].get_top(),
            buff=0.1,
            color=WHITE
        )

        arrow4 = Arrow(
            ffn_box.get_bottom(),
            output_arrow.get_start(),
            buff=0.1,
            color=WHITE
        )

        # Animate all components.
        components = [
            decoder_layer, input_arrow, input_label, arrow1,
            masked_mha_box, masked_mha_text,
            enc_dec_box, enc_dec_text, ffn_box, ffn_text,
            output_arrow, output_label
        ]

        for ln in ln_boxes:
            components.append(ln)

        for arrow in [arrow2, arrow3, arrow4] + residual_arrows:
            components.append(arrow)

        # Text and VGroup (the LN box+label pairs) animate with Write; the rest
        # with Create.
        for comp in components:
            self.play(Create(comp) if not isinstance(comp, Text) and not isinstance(comp, VGroup) else Write(comp),
                      run_time=0.2)

        # Summary caption at the bottom.
        explanation = Text(
            "解码器层 = 掩码注意力 + 编码器-解码器注意力 + 前馈网络",
            font_size=20,
            font="SimHei",
            color=YELLOW
        )
        explanation.to_edge(DOWN, buff=1.0)

        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Kept on self so show_attention_mechanism() can fade them out.
        self.decoder_title = decoder_title
        self.decoder_components = components + [explanation]

    def show_attention_mechanism(self) -> None:
        """Section 5: scaled dot-product and multi-head attention formulas."""
        # Clear the decoder section.
        self.play(
            FadeOut(self.decoder_title),
            *[FadeOut(c) for c in self.decoder_components],
            run_time=0.8
        )

        # Section heading.
        attn_title = Text(
            "注意力机制 (Attention)",
            font_size=36,
            font="SimHei",
            color=YELLOW
        )
        attn_title.to_edge(UP, buff=0.5)
        self.play(Write(attn_title), run_time=0.8)

        # Scaled dot-product attention formula.
        formula = MathTex(
            r"\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V",
            font_size=32
        )
        formula.shift(UP * 1.0)

        self.play(Write(formula), run_time=1.5)
        self.wait(1.0)

        # Multi-head attention: concatenation of per-head attention outputs.
        multi_head_formula = MathTex(
            r"\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O",
            font_size=28
        )
        multi_head_formula.shift(DOWN * 0.5)

        head_formula = MathTex(
            r"\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)",
            font_size=24
        )
        head_formula.shift(DOWN * 1.5)

        self.play(Write(multi_head_formula), run_time=1.0)
        self.play(Write(head_formula), run_time=1.0)

        # Caption at the bottom.
        explanation = Text(
            "注意力允许模型关注输入序列中的所有位置",
            font_size=22,
            font="SimHei",
            color=BLUE
        )
        explanation.to_edge(DOWN, buff=1.0)

        self.play(Write(explanation), run_time=1.0)
        self.wait(2.0)

        # Kept on self so show_attention_formula() can fade them out.
        self.attention_components = [attn_title, formula, multi_head_formula, head_formula, explanation]

    def show_attention_formula(self) -> None:
        """Section 6: the scaled dot-product attention formula broken into four
        steps, followed by a bullet list of attention optimizations in AstrAI."""
        # Clear the previous attention section.
        self.play(
            *[FadeOut(c) for c in self.attention_components],
            run_time=0.8
        )

        # Section heading.
        formula_title = Text(
            "缩放点积注意力详解",
            font_size=36,
            font="SimHei",
            color=PURPLE
        )
        formula_title.to_edge(UP, buff=0.5)
        self.play(Write(formula_title), run_time=0.8)

        # Four numbered steps, laid out top to bottom.  Step labels use plain
        # math (no CJK) to avoid LaTeX compilation errors; the Chinese
        # explanations are rendered separately as Text below each formula.
        step1_formula = MathTex(
            r"1.\ QK^T",
            font_size=26
        )
        step1_formula.shift(UP * 2.5)

        step1_text = Text(
            "(计算相似度)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step1_text.next_to(step1_formula, DOWN, buff=0.1)

        step2_formula = MathTex(
            r"2.\ \frac{QK^T}{\sqrt{d_k}}",
            font_size=26
        )
        step2_formula.shift(UP * 1.0)

        step2_text = Text(
            "(缩放,稳定梯度)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step2_text.next_to(step2_formula, DOWN, buff=0.1)

        step3_formula = MathTex(
            r"3.\ \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)",
            font_size=26
        )
        step3_formula.shift(DOWN * 0.5)

        step3_text = Text(
            "(归一化为概率)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step3_text.next_to(step3_formula, DOWN, buff=0.1)

        step4_formula = MathTex(
            r"4.\ \text{softmax}(\cdots)V",
            font_size=26
        )
        step4_formula.shift(DOWN * 2.0)

        step4_text = Text(
            "(加权求和)",
            font_size=20,
            font="SimHei",
            color=GRAY
        )
        step4_text.next_to(step4_formula, DOWN, buff=0.1)

        formula_steps = [step1_formula, step2_formula, step3_formula, step4_formula]
        text_steps = [step1_text, step2_text, step3_text, step4_text]

        for i in range(4):
            self.play(Write(formula_steps[i]), run_time=0.8)
            self.play(Write(text_steps[i]), run_time=0.5)
            self.wait(0.3)

        # Attention optimizations used in AstrAI.
        # NOTE(review): placed at DOWN * 2.5 this heading sits on top of step 4
        # (at DOWN * 2.0), and the four bullets below extend past the default
        # bottom frame edge (~y = -4) — confirm against the rendered output.
        astrai_title = Text(
            "AstrAI中的注意力优化",
            font_size=28,
            font="SimHei",
            color=GREEN
        )
        astrai_title.shift(DOWN * 2.5)

        optimizations = [
            "• Flash Attention:内存高效实现",
            "• KV Cache:避免重复计算",
            "• 连续批处理:动态请求调度",
            "• 前缀缓存:Radix Tree管理",
        ]

        opt_texts = []
        for i, opt in enumerate(optimizations):
            text = Text(
                opt,
                font_size=20,
                font="SimHei",
                color=YELLOW
            )
            # Growing buff spaces successive bullets further below the heading.
            text.next_to(astrai_title, DOWN, buff=0.3 + i * 0.4)
            text.align_to(astrai_title, LEFT)
            opt_texts.append(text)

        self.play(Write(astrai_title), run_time=0.8)
        for text in opt_texts:
            self.play(Write(text), run_time=0.5)

        self.wait(3.0)


# Render command:
# python -m manim transformer_visualization.py TransformerArchitecture -pqh

if __name__ == "__main__":
    # Quick local test: render directly without the manim CLI.
    scene = TransformerArchitecture()
    scene.render()