AstrAI-video-repo/part2_core_design.py

470 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AstrAI项目介绍视频 - Part 2: AstrAI的核心设计
"""
from manim import *
class Part2_CoreDesign(Scene):
    """Part 2 of the AstrAI introduction video: AstrAI's core design.

    ``construct`` plays five sections in order: overall architecture,
    inference flow, scheduling strategy, memory optimization and distributed
    support. Each section writes its own mobjects and then fades all of them
    out in a single animation before the next section starts, so no text from
    one section lingers into the next.
    """

    # Font used for every Text in this scene; the captions contain CJK text.
    TEXT_FONT = "SimHei"

    def construct(self):
        """Play the five Part 2 sections back to back."""
        self.play_architecture()  # 2.1 overall architecture
        self.play_inference_flow()  # 2.2 inference flow
        self.play_scheduler()  # 2.3 scheduling strategy
        self.play_memory_optimization()  # 2.4 memory optimization
        self.play_distributed_support()  # 2.5 distributed support

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _make_text(self, content, font_size, color):
        """Return a Text mobject in the scene font (not yet added to the scene)."""
        return Text(content, font_size=font_size, font=self.TEXT_FONT, color=color)

    def _write_title(self, content, font_size=40):
        """Write a white section title near the top edge and return it."""
        title = self._make_text(content, font_size, WHITE)
        title.to_edge(UP, buff=0.8)
        self.play(Write(title), run_time=1.0)
        return title

    def _write_lines(self, anchor, lines, font_size, first_buff, step,
                     run_time, pause=0.0):
        """Write *lines* one at a time as a gray column under *anchor*.

        Line ``i`` is placed ``first_buff + i * step`` below *anchor* and
        left-aligned with it. Each line gets its own ``Write`` of *run_time*
        seconds, followed by a wait of *pause* seconds when *pause* is
        non-zero. Returns the created Text mobjects in order.
        """
        texts = []
        for i, line in enumerate(lines):
            text = self._make_text(line, font_size, GRAY)
            text.next_to(anchor, DOWN, buff=first_buff + i * step)
            text.align_to(anchor, LEFT)
            self.play(Write(text), run_time=run_time)
            if pause:
                self.wait(pause)
            texts.append(text)
        return texts

    def _fade_out_together(self, *mobjects, run_time=0.8):
        """Fade out every given mobject in one simultaneous animation."""
        self.play(*[FadeOut(m) for m in mobjects], run_time=run_time)

    # ------------------------------------------------------------------
    # 2.1 Overall architecture
    # ------------------------------------------------------------------

    def play_architecture(self):
        """Section 2.1: list AstrAI's top-level modules as a two-column table."""
        title = self._write_title("AstrAI 核心设计", font_size=48)
        self.wait(0.5)

        # (module path, short description) rows of the table.
        modules = [
            ("astrai.config", "配置管理"),
            ("astrai.dataset", "数据集加载"),
            ("astrai.model", "神经网络模型"),
            ("astrai.tokenize", "分词器和聊天模板"),
            ("astrai.trainer", "训练工作流管理"),
            ("astrai.inference", "推理调度"),
            ("astrai.parallel", "分布式并行支持"),
        ]
        names = [self._make_text(module, 28, BLUE) for module, _ in modules]
        descs = [self._make_text(desc, 20, GRAY) for _, desc in modules]

        # Name column: the first row sits left of center near the top; every
        # later row goes below the previous one and is left-aligned with the
        # first. (next_to alone would center each name on the row above, which
        # leaves the column ragged because the module names differ in width.)
        names[0].shift(LEFT * 3).to_edge(UP, buff=2.5)
        for prev, name in zip(names, names[1:]):
            name.next_to(prev, DOWN, buff=0.4).align_to(names[0], LEFT)

        # Description column: vertically centered on its row, with all left
        # edges on one shared x just right of the widest module name.
        desc_x = max(name.get_right()[0] for name in names) + 0.5
        for name, desc in zip(names, descs):
            desc.next_to(name, RIGHT, buff=0.5)
            desc.set_x(desc_x, direction=LEFT)

        for i, (name, desc) in enumerate(zip(names, descs)):
            self.play(Write(name), run_time=0.5)
            self.play(Write(desc), run_time=0.3)
            if i < len(names) - 1:
                self.wait(0.2)
        self.wait(2.0)

        self._fade_out_together(title, *names, *descs)
        self.wait(0.3)

    # ------------------------------------------------------------------
    # 2.2 Inference flow
    # ------------------------------------------------------------------

    def play_inference_flow(self):
        """Section 2.2: the Pre-fill phase, an arrow, then the Decode phase."""
        title = self._write_title("推理流程:Prefill → Decode")
        self.wait(0.3)

        # Pre-fill card on the left, Decode card on the right.
        prefill_box, prefill_label, prefill_desc = self._write_phase_card(
            "Pre-fill", "一次性处理输入序列\n计算所有Token的K和V", BLUE, LEFT
        )
        decode_box, decode_label, decode_desc = self._write_phase_card(
            "Decode", "生成新Token\n从KV Cache读取K和V", GREEN, RIGHT
        )

        arrow = Arrow(
            prefill_box.get_right(),
            decode_box.get_left(),
            buff=0.5,
            color=YELLOW,
        )
        self.play(Create(arrow), run_time=0.5)
        self.wait(1.5)

        # Takeaway line near the bottom of the frame.
        key_point = self._make_text("KV Cache避免重复计算,大幅提升效率", 24, YELLOW)
        key_point.to_edge(DOWN, buff=1.5)
        self.play(Write(key_point), run_time=1.0)
        self.wait(2.0)

        self._fade_out_together(
            title,
            prefill_box, prefill_label, prefill_desc,
            decode_box, decode_label, decode_desc,
            arrow, key_point,
        )
        self.wait(0.3)

    def _write_phase_card(self, label, description, color, edge):
        """Draw one inference-phase card against *edge* and return its parts.

        The card is a translucent rounded box placed 2.0 units from *edge*,
        with *label* above it and *description* centered inside. After the
        card is drawn the scene waits 0.5 s.
        Returns ``(box, label_text, description_text)``.
        """
        box = RoundedRectangle(
            width=3.5,
            height=1.5,
            corner_radius=0.2,
            color=color,
            fill_opacity=0.3,
        )
        box.to_edge(edge, buff=2.0)
        label_text = self._make_text(label, 28, color)
        label_text.next_to(box, UP, buff=0.2)
        description_text = self._make_text(description, 16, GRAY)
        description_text.move_to(box)
        self.play(
            Create(box),
            Write(label_text),
            Write(description_text),
            run_time=1.0,
        )
        self.wait(0.5)
        return box, label_text, description_text

    # ------------------------------------------------------------------
    # 2.3 Scheduling strategy
    # ------------------------------------------------------------------

    def play_scheduler(self):
        """Section 2.3: continuous batching (left) and prefix caching (right)."""
        title = self._write_title("调度策略:高效管理KV Cache")
        self.wait(0.3)

        batch_title = self._make_text("连续批处理 (Continuous Batching)", 32, BLUE)
        batch_title.to_edge(LEFT, buff=1.5).shift(UP * 1.5)
        self.play(Write(batch_title), run_time=1.0)
        self.wait(0.3)

        feature_texts = self._write_lines(
            batch_title,
            [
                "动态批处理:新的请求可以随时加入",
                "立即释放:完成的任务立即释放资源",
                "大幅提高GPU利用率",
            ],
            font_size=22,
            first_buff=0.4,
            step=0.5,
            run_time=0.8,
            pause=0.3,
        )
        self.wait(1.0)

        prefix_title = self._make_text("前缀缓存 (Prefix Caching)", 32, GREEN)
        prefix_title.to_edge(RIGHT, buff=1.5).shift(UP * 1.5)
        self.play(Write(prefix_title), run_time=1.0)
        self.wait(0.3)

        prefix_desc = self._make_text("使用Radix Tree实现\n智能前缀提示加速", 20, GRAY)
        prefix_desc.next_to(prefix_title, DOWN, buff=0.4)
        prefix_desc.align_to(prefix_title, LEFT)
        self.play(Write(prefix_desc), run_time=0.8)
        self.wait(2.0)

        # Fade every element together. Previously the feature lines outlived
        # the titles and were then removed one by one.
        self._fade_out_together(
            title, batch_title, prefix_title, prefix_desc, *feature_texts
        )
        self.wait(0.3)

    # ------------------------------------------------------------------
    # 2.4 Memory optimization
    # ------------------------------------------------------------------

    def play_memory_optimization(self):
        """Section 2.4: Radix Tree benefits (left) vs. vLLM's PagedAttention (right)."""
        title = self._write_title("内存优化:Radix Tree vs PagedAttention")
        self.wait(0.3)

        radix_title = self._make_text("Radix Tree 优势", 28, BLUE)
        radix_title.to_edge(LEFT, buff=2.0).shift(UP * 1.0)
        self.play(Write(radix_title), run_time=1.0)

        radix_texts = self._write_lines(
            radix_title,
            [
                "自动合并共享前缀的请求",
                "智能复用已计算的KV Cache",
                "支持前缀提示加速",
                "基于LRU的缓存淘汰策略",
            ],
            font_size=20,
            first_buff=0.3,
            step=0.4,
            run_time=0.5,
        )
        self.wait(1.0)

        compare_title = self._make_text("对比 vLLM", 28, GREEN)
        compare_title.to_edge(RIGHT, buff=2.0).shift(UP * 1.0)
        self.play(Write(compare_title), run_time=1.0)

        compare_text = self._make_text(
            "vLLM使用PagedAttention\nAstrAI使用Radix Tree\n\n两者都能减少内存碎片\n实现方式不同",
            18,
            GRAY,
        )
        compare_text.next_to(compare_title, DOWN, buff=0.3)
        compare_text.align_to(compare_title, LEFT)
        self.play(Write(compare_text), run_time=1.0)
        self.wait(2.0)

        # Fade every element together instead of leaving the list lines behind.
        self._fade_out_together(
            title, radix_title, compare_title, compare_text, *radix_texts
        )
        self.wait(0.3)

    # ------------------------------------------------------------------
    # 2.5 Distributed support
    # ------------------------------------------------------------------

    def play_distributed_support(self):
        """Section 2.5: supported parallelism modes (left) and init API (right)."""
        title = self._write_title("分布式支持:多卡推理")
        self.wait(0.3)

        parallel_title = self._make_text("支持的并行方式", 32, BLUE)
        parallel_title.to_edge(LEFT, buff=1.5).shift(UP * 1.5)
        self.play(Write(parallel_title), run_time=1.0)

        # (mode name, explanation) pairs; each row is the name plus a gray
        # " - explanation" placed to its right.
        parallel_methods = [
            ("数据并行", "多卡处理不同数据"),
            ("模型并行", "将模型分片到不同卡"),
            ("流水并行", "按层划分Pipeline"),
        ]
        method_texts = []
        for i, (method, desc) in enumerate(parallel_methods):
            method_text = self._make_text(method, 24, GREEN)
            method_text.next_to(parallel_title, DOWN, buff=0.4 + i * 0.5)
            method_text.align_to(parallel_title, LEFT)
            desc_text = self._make_text(f" - {desc}", 18, GRAY)
            desc_text.next_to(method_text, RIGHT, buff=0.3)
            self.play(Write(method_text), Write(desc_text), run_time=0.8)
            method_texts.extend((method_text, desc_text))
        self.wait(1.5)

        code_title = self._make_text("分布式初始化", 28, ORANGE)
        code_title.to_edge(RIGHT, buff=1.5).shift(UP * 1.5)
        self.play(Write(code_title), run_time=1.0)

        code_text = self._make_text(
            "setup_parallel()\nspawn_parallel_fn()\n\n支持NCCL后端", 16, GRAY
        )
        code_text.next_to(code_title, DOWN, buff=0.3)
        code_text.align_to(code_title, LEFT)
        self.play(Write(code_text), run_time=1.0)
        self.wait(2.0)

        # Fade every element together instead of removing the rows one by one.
        self._fade_out_together(
            title, parallel_title, code_title, code_text, *method_texts
        )
        self.wait(0.5)