""" AstrAI项目介绍视频 - Part 2: AstrAI的核心设计 """ from manim import * class Part2_CoreDesign(Scene): """第二部分:AstrAI的核心设计""" def construct(self): # 2.1 整体架构 self.play_architecture() # 2.2 推理流程 self.play_inference_flow() # 2.3 调度策略 self.play_scheduler() # 2.4 内存优化 self.play_memory_optimization() # 2.5 分布式支持 self.play_distributed_support() def play_architecture(self): """整体架构展示""" # 标题 title = Text( "AstrAI 核心设计", font_size=48, font="SimHei", color=WHITE ) title.to_edge(UP, buff=0.8) self.play(Write(title), run_time=1.0) self.wait(0.5) # 模块表格 modules = [ ("astrai.config", "配置管理"), ("astrai.dataset", "数据集加载"), ("astrai.model", "神经网络模型"), ("astrai.tokenize", "分词器和聊天模板"), ("astrai.trainer", "训练工作流管理"), ("astrai.inference", "推理调度"), ("astrai.parallel", "分布式并行支持"), ] # 创建模块列表 module_objects = [] for i, (module, desc) in enumerate(modules): module_text = Text( module, font_size=28, font="SimHei", color=BLUE ) desc_text = Text( desc, font_size=20, font="SimHei", color=GRAY ) module_text.shift(LEFT * 3) desc_text.next_to(module_text, RIGHT, buff=0.5) if i == 0: module_text.to_edge(UP, buff=2.5) else: module_text.next_to(module_objects[-1][0], DOWN, buff=0.4) desc_text.next_to(module_text, RIGHT, buff=0.5) self.play(Write(module_text), run_time=0.5) self.play(Write(desc_text), run_time=0.3) module_objects.append((module_text, desc_text)) if i < len(modules) - 1: self.wait(0.2) self.wait(2.0) # 淡出 self.play( FadeOut(title), *[FadeOut(m) for m, d in module_objects], *[FadeOut(d) for m, d in module_objects], run_time=0.8 ) self.wait(0.3) def play_inference_flow(self): """推理流程:Prefill → Decode""" title = Text( "推理流程:Prefill → Decode", font_size=40, font="SimHei", color=WHITE ) title.to_edge(UP, buff=0.8) self.play(Write(title), run_time=1.0) self.wait(0.3) # Pre-fill阶段 prefill_box = RoundedRectangle( width=3.5, height=1.5, corner_radius=0.2, color=BLUE, fill_opacity=0.3 ) prefill_box.to_edge(LEFT, buff=2.0) prefill_title = Text( "Pre-fill", font_size=28, font="SimHei", color=BLUE ) prefill_title.next_to(prefill_box, UP, buff=0.2) prefill_desc = Text( "一次性处理输入序列\n计算所有Token的K和V", font_size=16, font="SimHei", color=GRAY ) prefill_desc.move_to(prefill_box) self.play( Create(prefill_box), Write(prefill_title), Write(prefill_desc), run_time=1.0 ) self.wait(0.5) # Decode阶段 decode_box = RoundedRectangle( width=3.5, height=1.5, corner_radius=0.2, color=GREEN, fill_opacity=0.3 ) decode_box.to_edge(RIGHT, buff=2.0) decode_title = Text( "Decode", font_size=28, font="SimHei", color=GREEN ) decode_title.next_to(decode_box, UP, buff=0.2) decode_desc = Text( "生成新Token\n从KV Cache读取K和V", font_size=16, font="SimHei", color=GRAY ) decode_desc.move_to(decode_box) self.play( Create(decode_box), Write(decode_title), Write(decode_desc), run_time=1.0 ) self.wait(0.5) # 箭头 arrow = Arrow( prefill_box.get_right(), decode_box.get_left(), buff=0.5, color=YELLOW ) self.play(Create(arrow), run_time=0.5) self.wait(1.5) # 关键点说明 key_point = Text( "KV Cache:避免重复计算,大幅提升效率", font_size=24, font="SimHei", color=YELLOW ) key_point.to_edge(DOWN, buff=1.5) self.play(Write(key_point), run_time=1.0) self.wait(2.0) # 淡出 self.play( FadeOut(title), FadeOut(prefill_box), FadeOut(prefill_title), FadeOut(prefill_desc), FadeOut(decode_box), FadeOut(decode_title), FadeOut(decode_desc), FadeOut(arrow), FadeOut(key_point), run_time=0.8 ) self.wait(0.3) def play_scheduler(self): """调度策略""" title = Text( "调度策略:高效管理KV Cache", font_size=40, font="SimHei", color=WHITE ) title.to_edge(UP, buff=0.8) self.play(Write(title), run_time=1.0) self.wait(0.3) # 连续批处理说明 batch_title = Text( "连续批处理 (Continuous Batching)", font_size=32, font="SimHei", color=BLUE ) batch_title.to_edge(LEFT, buff=1.5).shift(UP * 1.5) self.play(Write(batch_title), run_time=1.0) self.wait(0.3) features = [ "动态批处理:新的请求可以随时加入", "立即释放:完成的任务立即释放资源", "大幅提高GPU利用率", ] feature_texts = [] for i, feature in enumerate(features): feature_text = Text( feature, font_size=22, font="SimHei", color=GRAY ) feature_text.next_to(batch_title, DOWN, buff=0.4 + i * 0.5) feature_text.align_to(batch_title, LEFT) self.play(Write(feature_text), run_time=0.8) self.wait(0.3) feature_texts.append(feature_text) self.wait(1.0) # 前缀缓存 prefix_title = Text( "前缀缓存 (Prefix Caching)", font_size=32, font="SimHei", color=GREEN ) prefix_title.to_edge(RIGHT, buff=1.5).shift(UP * 1.5) self.play(Write(prefix_title), run_time=1.0) self.wait(0.3) prefix_desc = Text( "使用Radix Tree实现\n智能前缀提示加速", font_size=20, font="SimHei", color=GRAY ) prefix_desc.next_to(prefix_title, DOWN, buff=0.4) prefix_desc.align_to(prefix_title, LEFT) self.play(Write(prefix_desc), run_time=0.8) self.wait(2.0) # 淡出 - 简化版本 self.play( FadeOut(title), FadeOut(batch_title), FadeOut(prefix_title), FadeOut(prefix_desc), run_time=0.8 ) for ft in feature_texts: self.play(FadeOut(ft), run_time=0.3) self.wait(0.3) def play_memory_optimization(self): """内存优化""" title = Text( "内存优化:Radix Tree vs PagedAttention", font_size=40, font="SimHei", color=WHITE ) title.to_edge(UP, buff=0.8) self.play(Write(title), run_time=1.0) self.wait(0.3) # Radix Tree 优势 radix_title = Text( "Radix Tree 优势", font_size=28, font="SimHei", color=BLUE ) radix_title.to_edge(LEFT, buff=2.0).shift(UP * 1.0) self.play(Write(radix_title), run_time=1.0) radix_features = [ "自动合并共享前缀的请求", "智能复用已计算的KV Cache", "支持前缀提示加速", "基于LRU的缓存淘汰策略", ] radix_texts = [] for i, feature in enumerate(radix_features): feature_text = Text( f"• {feature}", font_size=20, font="SimHei", color=GRAY ) feature_text.next_to(radix_title, DOWN, buff=0.3 + i * 0.4) feature_text.align_to(radix_title, LEFT) self.play(Write(feature_text), run_time=0.5) radix_texts.append(feature_text) self.wait(1.0) # 对比vLLM compare_title = Text( "对比 vLLM", font_size=28, font="SimHei", color=GREEN ) compare_title.to_edge(RIGHT, buff=2.0).shift(UP * 1.0) self.play(Write(compare_title), run_time=1.0) compare_text = Text( "vLLM使用PagedAttention\nAstrAI使用Radix Tree\n\n两者都能减少内存碎片\n实现方式不同", font_size=18, font="SimHei", color=GRAY ) compare_text.next_to(compare_title, DOWN, buff=0.3) compare_text.align_to(compare_title, LEFT) self.play(Write(compare_text), run_time=1.0) self.wait(2.0) # 淡出 self.play( FadeOut(title), FadeOut(radix_title), FadeOut(compare_title), FadeOut(compare_text), run_time=0.8 ) for rt in radix_texts: self.play(FadeOut(rt), run_time=0.3) self.wait(0.3) def play_distributed_support(self): """分布式支持""" title = Text( "分布式支持:多卡推理", font_size=40, font="SimHei", color=WHITE ) title.to_edge(UP, buff=0.8) self.play(Write(title), run_time=1.0) self.wait(0.3) # 并行方式 parallel_title = Text( "支持的并行方式", font_size=32, font="SimHei", color=BLUE ) parallel_title.to_edge(LEFT, buff=1.5).shift(UP * 1.5) self.play(Write(parallel_title), run_time=1.0) parallel_methods = [ ("数据并行", "多卡处理不同数据"), ("模型并行", "将模型分片到不同卡"), ("流水并行", "按层划分Pipeline"), ] method_texts = [] for i, (method, desc) in enumerate(parallel_methods): method_text = Text( method, font_size=24, font="SimHei", color=GREEN ) method_text.next_to(parallel_title, DOWN, buff=0.4 + i * 0.5) method_text.align_to(parallel_title, LEFT) desc_text = Text( f" - {desc}", font_size=18, font="SimHei", color=GRAY ) desc_text.next_to(method_text, RIGHT, buff=0.3) self.play(Write(method_text), Write(desc_text), run_time=0.8) method_texts.append((method_text, desc_text)) self.wait(1.5) # 代码示例 code_title = Text( "分布式初始化", font_size=28, font="SimHei", color=ORANGE ) code_title.to_edge(RIGHT, buff=1.5).shift(UP * 1.5) self.play(Write(code_title), run_time=1.0) code_text = Text( "setup_parallel()\nspawn_parallel_fn()\n\n支持NCCL后端", font_size=16, font="SimHei", color=GRAY ) code_text.next_to(code_title, DOWN, buff=0.3) code_text.align_to(code_title, LEFT) self.play(Write(code_text), run_time=1.0) self.wait(2.0) # 淡出 self.play( FadeOut(title), FadeOut(parallel_title), FadeOut(code_title), FadeOut(code_text), run_time=0.8 ) for mt, dt in method_texts: self.play(FadeOut(mt), FadeOut(dt), run_time=0.3) self.wait(0.5)