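get_rope User Guide
Overview
get_rope provides a flexible registration mechanism for creating and managing different types of Rotary Position Embedding (RoPE) instances. Through this mechanism, model-specific RoPE implementations can live in their own model-specific files instead of being centralized in a factory class.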
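Key features
- Registration: register custom RoPE types with the @register_rope_type decorator.
- Automatic caching: RoPE instances with identical configuration are cached automatically, so the same instance is not built twice.
- Model-specific support: model-specific extrapolation schemes (such as DeepseekV3YarnRotaryEmbedding) can be registered in their own model-specific files.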
Usage
1. Using the default or an already registered RoPE
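```python
from mindie_llm.runtime.layers.embedding.rotary_embedding import get_rope

# In the attention layer's __init__: create (or reuse a cached) RoPE instance
self.rope_emb = get_rope(
    self.head_dim,
    self.head_dim,
    self.config.rope_scaling.max_position_embeddings,
    is_neox_style=True,
    rope_config=config.rope_scaling,
)
...
# At run time, in the model's forward:
# build cos_indexed_cache / sin_indexed_cache for the current positions.
# Identically configured layers share one cached instance, so calling this
# once through layer 0 updates the cache seen by every such layer.
self.layers[0].self_attn.rope_emb.set_cos_sin_indexed_cache(positions)
...
# Option 1: call forward to apply RoPE to query and key directly
query, key = self.rope_emb(positions, query, key)
...
# Option 2: hand the indexed cos/sin caches to the attention backend
return self.attn(hidden_states,
                 cos=self.rope_emb.cos_indexed_cache,
                 sin=self.rope_emb.sin_indexed_cache)
```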
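To show what the cos/sin caches are used for, here is a pure-Python sketch of NeoX-style (rotate-half) RoPE applied to one head vector at one position. It assumes the standard RoPE frequencies inv_freq[i] = base^(-2i/dim) and leaves out YaRN and mscale adjustments; `rotate_half` and `apply_neox_rope` are illustrative names, not mindie_llm APIs. The `[cos, cos]` doubling mirrors DeepseekV3YarnRotaryEmbedding.set_cos_sin_indexed_cache: in the NeoX layout, dimension i is paired with dimension i + dim/2 and both use the same angle.

```python
import math
from typing import List


def rotate_half(x: List[float]) -> List[float]:
    """NeoX-style rotate-half: [x1, x2] -> [-x2, x1]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_neox_rope(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    """Rotate one head vector `x` for a single token position (NeoX layout)."""
    dim = len(x)
    half = dim // 2
    # One frequency per (i, i + half) pair: inv_freq[i] = base^(-2i/dim)
    inv_freq = [base ** (-2 * i / dim) for i in range(half)]
    angles = [position * f for f in inv_freq]
    cos_half = [math.cos(a) for a in angles]
    sin_half = [math.sin(a) for a in angles]
    # Duplicate to full width, like the [..., rotary_dim] indexed caches
    cos = cos_half + cos_half
    sin = sin_half + sin_half
    rot = rotate_half(x)
    # out = x * cos + rotate_half(x) * sin, i.e. a 2D rotation of each pair
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]


if __name__ == "__main__":
    print(apply_neox_rope([1.0, 2.0, 3.0, 4.0], position=3))
```

With option 1, forward presumably performs this computation on query and key itself; with option 2, the attention backend receives the same cos/sin tensors and applies the rotation on its side. Because each pair is rotated by an angle proportional to the position, the dot product of a rotated query and key depends only on their position difference.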
2. Implementing a model-specific RoPE type (DeepSeek-V3 as an example)
2.1 Implementing the RoPE module
Define your RoPE implementation in a dedicated file (for example, mindie_llm/runtime/layers/embedding/rotary_embedding/deepseek_v3_yarn_scaling_rope.py).
The class can inherit from RotaryEmbedding in mindie_llm/runtime/layers/embedding/rotary_embedding/base.py, or, for context extrapolation, from YarnScalingRotaryEmbedding in mindie_llm/runtime/layers/embedding/rotary_embedding/yarn_scaling_rope.py:
```python
import torch

from mindie_llm.runtime.layers.embedding.rotary_embedding.yarn_scaling_rope import (
    YarnScalingRotaryEmbedding,
    yarn_get_mscale,
)


class DeepseekV3YarnRotaryEmbedding(YarnScalingRotaryEmbedding):
    """DeepSeek-V3 specialized YaRN rotary embedding with mscale_all_dim scaling.

    Extends standard YaRN scaling with DeepSeek-V3's additional magnitude scaling
    parameter (mscale_all_dim) for fine-grained attention magnitude control.
    """

    def __init__(
        self,
        dim,
        original_max_position_embeddings=4096,
        base=10000,
        factor=1.0,
        beta_fast=32,
        beta_slow=1,
        is_neox_style=True,
        dtype=None,
        mscale=1.0,
        mscale_all_dim=1.0,
    ) -> None:
        """Initialize DeepSeek-V3 YaRN rotary embedding.

        Args:
            dim: Rotary embedding dimension (applied to both head and rotary dims).
            original_max_position_embeddings: Original context length before scaling.
            base: Base frequency for rotary embedding (theta).
            factor: Context extension scaling factor (>1.0 for extrapolation).
            beta_fast: YaRN correction-range boundary, in rotations over the
                original context; dimensions rotating faster are not interpolated.
            beta_slow: YaRN correction-range boundary, in rotations over the
                original context; dimensions rotating slower are fully interpolated.
            is_neox_style: Use NeoX-style (rotate-half) rotation; False selects
                GPT-J-style interleaved rotation (default: True).
            dtype: Data type for embedding tensors (e.g., torch.float16).
            mscale: Base magnitude scaling factor for attention preservation.
            mscale_all_dim: DeepSeek-V3 specific scaling factor applied across all dimensions.
        """
        # Set before super().__init__(), which is expected to build the caches
        # through _compute_cos_sin_cache() and therefore reads this attribute.
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, dim, original_max_position_embeddings, base,
                         dtype=dtype,
                         is_neox_style=is_neox_style,
                         factor=factor,
                         beta_fast=beta_fast,
                         beta_slow=beta_slow,
                         mscale=mscale)

    def set_cos_sin_indexed_cache(self, positions) -> None:
        """Create position-indexed cosine/sine caches with dimension doubling.

        Extracts position-specific rotary values from precomputed caches and
        duplicates them across the last dimension to match attention head layout.

        Args:
            positions: 1D tensor of position indices to index into the cache.
        """
        # Gather rows for the requested positions, then add two singleton dims:
        # [seq_len, rotary_dim // 2] -> [seq_len, 1, 1, rotary_dim // 2]
        cos_indexed_cache = torch.index_select(
            self.cos_cache, dim=0, index=positions.view(-1)).unsqueeze(1).unsqueeze(1)
        sin_indexed_cache = torch.index_select(
            self.sin_cache, dim=0, index=positions.view(-1)).unsqueeze(1).unsqueeze(1)
        # Duplicate along the last dim ([cos, cos]) to cover the full rotary_dim
        cos_indexed_cache = torch.cat((cos_indexed_cache, cos_indexed_cache), dim=-1)
        sin_indexed_cache = torch.cat((sin_indexed_cache, sin_indexed_cache), dim=-1)
        self.register_buffer("cos_indexed_cache", cos_indexed_cache, persistent=False)  # [seq_len, 1, 1, rotary_dim]
        self.register_buffer("sin_indexed_cache", sin_indexed_cache, persistent=False)

    def _compute_cos_sin_cache(self) -> None:
        """Precompute cosine/sine caches with DeepSeek-V3 specific magnitude scaling.

        Applies dual scaling factors (mscale and mscale_all_dim) to preserve attention
        magnitude during context extrapolation. The effective scale is
        yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim).
        """
        t = torch.arange(self.max_position_embeddings).to(torch.float32)
        # Outer product of positions and inverse frequencies: one angle per (pos, freq)
        freqs = torch.einsum("i,j -> ij", t, self.inv_freq)
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )
        cos = freqs.cos().to(self.dtype) * _mscale
        sin = freqs.sin().to(self.dtype) * _mscale
        self.register_buffer("cos_cache", cos, persistent=False)  # [max_position_embeddings, rotary_dim // 2]
        self.register_buffer("sin_cache", sin, persistent=False)  # [max_position_embeddings, rotary_dim // 2]
```
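The magnitude factor that _compute_cos_sin_cache multiplies into the caches can be checked numerically. This sketch assumes yarn_get_mscale follows the commonly used DeepSeek/YaRN definition, 0.1 * mscale * ln(scale) + 1 for scale > 1 and 1 otherwise; confirm it against yarn_scaling_rope.py before relying on the numbers. `deepseek_cos_sin_scale` is a hypothetical helper, not part of mindie_llm.

```python
import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    """YaRN attention-magnitude correction (assumed formula, see lead-in)."""
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def deepseek_cos_sin_scale(factor: float, mscale: float, mscale_all_dim: float) -> float:
    """Ratio applied to cos/sin in _compute_cos_sin_cache (hypothetical helper)."""
    return yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim)


if __name__ == "__main__":
    # Equal mscale and mscale_all_dim cancel out: the cache is left unscaled.
    print(deepseek_cos_sin_scale(40.0, 1.0, 1.0))  # 1.0
    # No extension (factor <= 1) means no magnitude correction.
    print(deepseek_cos_sin_scale(1.0, 1.0, 0.5))   # 1.0
    # mscale_all_dim = 0 keeps the full YaRN correction on the cache.
    print(deepseek_cos_sin_scale(40.0, 1.0, 0.0))  # ≈ 1.369
```

One consequence worth noting: whenever mscale equals mscale_all_dim, the two factors cancel and the cos/sin caches carry no extra scaling.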
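2.2 Implementing and registering a custom RoPE factory
The factory function must carry both decorators, @register_rope_type("xxxx") on top and @cached_rope_factory below it. Because decorators are applied bottom-up, register_rope_type receives the function already wrapped by cached_rope_factory: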
```python
import torch

# register_rope_type, cached_rope_factory, RopeScaling and RotaryEmbedding must be
# imported as well; their module paths are not given in this guide.


@register_rope_type("deepseek_yarn")
@cached_rope_factory
def _create_deepseek_scaling_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    rope_config: RopeScaling,
) -> RotaryEmbedding:
    """Factory function for creating DeepSeek-V3 YaRN-scaled RotaryEmbedding.

    Specialized implementation for DeepSeek-V3 architecture with YaRN scaling
    and DeepSeek-specific parameters like mscale_all_dim.

    Args:
        head_size: Dimension of each attention head.
        rotary_dim: Dimensionality of the rotary embedding subspace.
        max_position: Target maximum sequence length after scaling.
        base: Base value for frequency computation (theta).
        is_neox_style: Whether to use NeoX-style (rotate-half) rotation
            instead of GPT-J-style interleaved rotation.
        dtype: Data type for embedding tensors.
        rope_config: Configuration object containing DeepSeek-specific parameters:
            - original_max_position_embeddings: Original context length before scaling
            - factor: Scaling factor for context extension
            - beta_fast/beta_slow: YaRN correction-range boundaries (in rotations)
            - mscale: Magnitude scaling factor
            - mscale_all_dim: DeepSeek-specific magnitude scaling dimension parameter

    Returns:
        Initialized DeepseekV3YarnRotaryEmbedding instance.
    """
    # DeepSeek-specific fields are read from rope_config and forwarded as kwargs
    ds_yarn_extra_keys = (
        "factor",
        "beta_fast",
        "beta_slow",
        "mscale",
        "mscale_all_dim",
    )
    extra_kwargs = {k: getattr(rope_config, k) for k in ds_yarn_extra_keys}
    # head_size and max_position belong to the common factory signature but are
    # not forwarded; the class gets rotary_dim and the original context length.
    return DeepseekV3YarnRotaryEmbedding(
        rotary_dim,
        rope_config.original_max_position_embeddings,
        base,
        is_neox_style=is_neox_style,
        dtype=dtype,
        **extra_kwargs,
    )
```
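Caching mechanism
get_rope caches RoPE instances automatically. The cache key is built from:
- head_size
- rotary_dim (computed as head_size * partial_rotary_factor)
- max_position
- is_neox_style
- base
- rope_config (lists are converted to tuples so the key stays hashable and stable)
- dtype
Repeated calls with the same configuration therefore return the same instance, saving memory and computation.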
Registered types
- default: standard RotaryEmbedding (the default)
- yarn: YarnScalingRotaryEmbedding
- deepseek_yarn: DeepseekV3YarnRotaryEmbedding
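Notes
- Location of model-specific RoPE implementations: each one belongs in its own file under mindie_llm/runtime/layers/embedding/rotary_embedding. For example, DeepseekV3YarnRotaryEmbedding should be registered in a dedicated file in that directory, not in rotary_embedding/__init__.py.
- Registration timing: make sure registration has completed before get_rope is called. This normally happens when the module is imported.
- Parameter extraction: factory functions should read the parameters they need from rope_config instead of expecting every parameter to be passed positionally.
- Backward compatibility: existing code keeps working without changes. If a new model needs a new RoPE, implement and register it in its own module as an incremental change. Do not modify existing code; if you have to, test the affected scenarios.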