"""
AstrAI project introduction video - Part 2: AstrAI's core design
"""

from manim import *


class Part2_CoreDesign(Scene):
    """Part 2 of the AstrAI introduction video: AstrAI's core design.

    ``construct`` plays five segments back to back: overall architecture,
    inference flow, scheduling strategy, memory optimization and distributed
    support. Each segment writes its own title, animates its content and
    fades everything it created back out before returning.
    """

    # Font family handed to every Text mobject in this scene. Kept in one
    # place so a subclass can override it.
    # NOTE(review): presumably this font must be installed on the rendering
    # machine for the Chinese captions to look as intended - confirm.
    FONT = "SimHei"

    def construct(self):
        """Play the five segments of part 2 in order."""
        segments = (
            self.play_architecture,         # 2.1 overall architecture
            self.play_inference_flow,       # 2.2 inference flow
            self.play_scheduler,            # 2.3 scheduling strategy
            self.play_memory_optimization,  # 2.4 memory optimization
            self.play_distributed_support,  # 2.5 distributed support
        )
        for segment in segments:
            segment()

    # ------------------------------------------------------------------
    # Private helpers shared by the segments
    # ------------------------------------------------------------------

    def _make_text(self, content, size, color):
        """Return a Text mobject for *content* using the scene-wide font."""
        return Text(content, font_size=size, font=self.FONT, color=color)

    def _intro_title(self, content, size=40, pause=0.3):
        """Write a white segment title at the top edge and return it.

        Args:
            content: Title string.
            size: Font size of the title.
            pause: Seconds to wait after the title has been written.
        """
        title = self._make_text(content, size, WHITE)
        title.to_edge(UP, buff=0.8)
        self.play(Write(title), run_time=1.0)
        self.wait(pause)
        return title

    def _intro_heading(self, content, size, color, edge, edge_buff, lift,
                       pause=None):
        """Write a column heading pinned to a side edge and return it.

        Args:
            content: Heading string.
            size: Font size of the heading.
            color: Heading color.
            edge: Direction of the frame edge to pin to (``LEFT``/``RIGHT``).
            edge_buff: Distance kept from that edge.
            lift: How far the heading is shifted upwards afterwards.
            pause: Seconds to wait after writing; ``None`` skips the wait.
        """
        heading = self._make_text(content, size, color)
        heading.to_edge(edge, buff=edge_buff).shift(UP * lift)
        self.play(Write(heading), run_time=1.0)
        if pause is not None:
            self.wait(pause)
        return heading

    @staticmethod
    def _stack_below(mobject, anchor, buff):
        """Put *mobject* below *anchor*, left-aligned with it; return it."""
        mobject.next_to(anchor, DOWN, buff=buff)
        mobject.align_to(anchor, LEFT)
        return mobject

    def _intro_stage(self, label, body, color, edge):
        """Animate one labelled stage box of the inference-flow diagram.

        Creates a rounded box pinned to *edge*, a heading above it and a
        caption centered inside it, animates the three together and waits
        0.5 s.

        Returns:
            Tuple ``(box, heading, caption)``.
        """
        box = RoundedRectangle(
            width=3.5,
            height=1.5,
            corner_radius=0.2,
            color=color,
            fill_opacity=0.3,
        )
        box.to_edge(edge, buff=2.0)

        heading = self._make_text(label, 28, color)
        heading.next_to(box, UP, buff=0.2)

        caption = self._make_text(body, 16, GRAY)
        caption.move_to(box)

        self.play(Create(box), Write(heading), Write(caption), run_time=1.0)
        self.wait(0.5)
        return box, heading, caption

    # ------------------------------------------------------------------
    # Segments
    # ------------------------------------------------------------------

    def play_architecture(self):
        """Show the overall architecture as a list of modules."""
        title = self._intro_title("AstrAI 核心设计", size=48, pause=0.5)

        # (module name, short description) rows of the module table.
        modules = [
            ("astrai.config", "配置管理"),
            ("astrai.dataset", "数据集加载"),
            ("astrai.model", "神经网络模型"),
            ("astrai.tokenize", "分词器和聊天模板"),
            ("astrai.trainer", "训练工作流管理"),
            ("astrai.inference", "推理调度"),
            ("astrai.parallel", "分布式并行支持"),
        ]

        rows = []
        for index, (name, summary) in enumerate(modules):
            name_text = self._make_text(name, 28, BLUE)
            summary_text = self._make_text(summary, 20, GRAY)

            if not rows:
                # The first row fixes the column: shifted left, then pinned
                # below the title.
                name_text.shift(LEFT * 3)
                name_text.to_edge(UP, buff=2.5)
            else:
                # Later rows hang off the previous module name.
                name_text.next_to(rows[-1][0], DOWN, buff=0.4)
            summary_text.next_to(name_text, RIGHT, buff=0.5)

            self.play(Write(name_text), run_time=0.5)
            self.play(Write(summary_text), run_time=0.3)
            rows.append((name_text, summary_text))

            # Short beat between rows, but not after the last one.
            if index < len(modules) - 1:
                self.wait(0.2)

        self.wait(2.0)

        # Fade everything out together.
        self.play(
            FadeOut(title),
            *[FadeOut(name_text) for name_text, _ in rows],
            *[FadeOut(summary_text) for _, summary_text in rows],
            run_time=0.8,
        )
        self.wait(0.3)

    def play_inference_flow(self):
        """Show the inference flow: Prefill -> Decode."""
        title = self._intro_title("推理流程:Prefill → Decode")

        # Pre-fill stage on the left, decode stage on the right.
        prefill_box, prefill_title, prefill_desc = self._intro_stage(
            "Pre-fill",
            "一次性处理输入序列\n计算所有Token的K和V",
            BLUE,
            LEFT,
        )
        decode_box, decode_title, decode_desc = self._intro_stage(
            "Decode",
            "生成新Token\n从KV Cache读取K和V",
            GREEN,
            RIGHT,
        )

        # Arrow from the pre-fill box to the decode box.
        arrow = Arrow(
            prefill_box.get_right(),
            decode_box.get_left(),
            buff=0.5,
            color=YELLOW,
        )
        self.play(Create(arrow), run_time=0.5)
        self.wait(1.5)

        # Key takeaway at the bottom of the frame.
        key_point = self._make_text(
            "KV Cache:避免重复计算,大幅提升效率", 24, YELLOW
        )
        key_point.to_edge(DOWN, buff=1.5)

        self.play(Write(key_point), run_time=1.0)
        self.wait(2.0)

        # Fade everything out together.
        leaving = (
            title,
            prefill_box,
            prefill_title,
            prefill_desc,
            decode_box,
            decode_title,
            decode_desc,
            arrow,
            key_point,
        )
        self.play(*[FadeOut(mobject) for mobject in leaving], run_time=0.8)
        self.wait(0.3)

    def play_scheduler(self):
        """Show the scheduling strategy: continuous batching, prefix caching."""
        title = self._intro_title("调度策略:高效管理KV Cache")

        # Left column: continuous batching.
        batch_title = self._intro_heading(
            "连续批处理 (Continuous Batching)", 32, BLUE, LEFT, 1.5, 1.5,
            pause=0.3,
        )

        features = [
            "动态批处理:新的请求可以随时加入",
            "立即释放:完成的任务立即释放资源",
            "大幅提高GPU利用率",
        ]

        feature_texts = []
        for index, feature in enumerate(features):
            # Every line is placed relative to the heading; the growing
            # buff is what spaces the lines apart.
            feature_text = self._stack_below(
                self._make_text(feature, 22, GRAY),
                batch_title,
                0.4 + index * 0.5,
            )
            self.play(Write(feature_text), run_time=0.8)
            self.wait(0.3)
            feature_texts.append(feature_text)

        self.wait(1.0)

        # Right column: prefix caching.
        prefix_title = self._intro_heading(
            "前缀缓存 (Prefix Caching)", 32, GREEN, RIGHT, 1.5, 1.5,
            pause=0.3,
        )
        prefix_desc = self._stack_below(
            self._make_text("使用Radix Tree实现\n智能前缀提示加速", 20, GRAY),
            prefix_title,
            0.4,
        )

        self.play(Write(prefix_desc), run_time=0.8)
        self.wait(2.0)

        # Fade out: headings and description together first, then the
        # feature lines one at a time.
        self.play(
            FadeOut(title),
            FadeOut(batch_title),
            FadeOut(prefix_title),
            FadeOut(prefix_desc),
            run_time=0.8,
        )
        for feature_text in feature_texts:
            self.play(FadeOut(feature_text), run_time=0.3)

        self.wait(0.3)

    def play_memory_optimization(self):
        """Show memory optimization: Radix Tree compared with PagedAttention."""
        title = self._intro_title("内存优化:Radix Tree vs PagedAttention")

        # Left column: Radix Tree advantages.
        radix_title = self._intro_heading(
            "Radix Tree 优势", 28, BLUE, LEFT, 2.0, 1.0
        )

        radix_features = [
            "自动合并共享前缀的请求",
            "智能复用已计算的KV Cache",
            "支持前缀提示加速",
            "基于LRU的缓存淘汰策略",
        ]

        radix_texts = []
        for index, feature in enumerate(radix_features):
            bullet = self._stack_below(
                self._make_text(f"• {feature}", 20, GRAY),
                radix_title,
                0.3 + index * 0.4,
            )
            self.play(Write(bullet), run_time=0.5)
            radix_texts.append(bullet)

        self.wait(1.0)

        # Right column: comparison with vLLM.
        compare_title = self._intro_heading(
            "对比 vLLM", 28, GREEN, RIGHT, 2.0, 1.0
        )
        compare_text = self._stack_below(
            self._make_text(
                "vLLM使用PagedAttention\nAstrAI使用Radix Tree\n\n两者都能减少内存碎片\n实现方式不同",
                18,
                GRAY,
            ),
            compare_title,
            0.3,
        )

        self.play(Write(compare_text), run_time=1.0)
        self.wait(2.0)

        # Fade out: headings and comparison together first, then the
        # bullet lines one at a time.
        self.play(
            FadeOut(title),
            FadeOut(radix_title),
            FadeOut(compare_title),
            FadeOut(compare_text),
            run_time=0.8,
        )
        for bullet in radix_texts:
            self.play(FadeOut(bullet), run_time=0.3)

        self.wait(0.3)

    def play_distributed_support(self):
        """Show distributed support: multi-GPU inference."""
        title = self._intro_title("分布式支持:多卡推理")

        # Left column: supported parallelism modes.
        parallel_title = self._intro_heading(
            "支持的并行方式", 32, BLUE, LEFT, 1.5, 1.5
        )

        parallel_methods = [
            ("数据并行", "多卡处理不同数据"),
            ("模型并行", "将模型分片到不同卡"),
            ("流水并行", "按层划分Pipeline"),
        ]

        method_texts = []
        for index, (method, desc) in enumerate(parallel_methods):
            method_text = self._stack_below(
                self._make_text(method, 24, GREEN),
                parallel_title,
                0.4 + index * 0.5,
            )

            desc_text = self._make_text(f" - {desc}", 18, GRAY)
            desc_text.next_to(method_text, RIGHT, buff=0.3)

            self.play(Write(method_text), Write(desc_text), run_time=0.8)
            method_texts.append((method_text, desc_text))

        self.wait(1.5)

        # Right column: code example.
        code_title = self._intro_heading(
            "分布式初始化", 28, ORANGE, RIGHT, 1.5, 1.5
        )
        code_text = self._stack_below(
            self._make_text(
                "setup_parallel()\nspawn_parallel_fn()\n\n支持NCCL后端",
                16,
                GRAY,
            ),
            code_title,
            0.3,
        )

        self.play(Write(code_text), run_time=1.0)
        self.wait(2.0)

        # Fade out: headings and code text together first, then each
        # method/description pair in turn.
        self.play(
            FadeOut(title),
            FadeOut(parallel_title),
            FadeOut(code_title),
            FadeOut(code_text),
            run_time=0.8,
        )
        for method_text, desc_text in method_texts:
            self.play(FadeOut(method_text), FadeOut(desc_text), run_time=0.3)

        self.wait(0.5)