添加新模型#

本指南演示如何将新颖或自定义的模型集成到 vllm-ascend 中。对于基础概念,强烈建议先参考 vllm 官方文档:添加新模型

步骤 1:使用 torchtorch_npu 实现模型#

本节提供了实现与 vllm 和 vllm-ascend 兼容的新模型的相关说明。

开始之前:

  • 请确认你的模型是否已经存在于 vllm 的 models 目录中。

  • 使用已有模型的实现作为模板以加速您的开发。

方法一:从零开始实现新模型#

请参考 vllm 的 OPT 模型适配 示例进行操作。

关键实现要求:

  1. 请将模型文件放在 vllm_ascend/models/ 目录下。

  2. 解码器-only LLMs 的标准模块结构(请参考 vllm 对其他类型模型的实现):

  • *ModelForCausalLM(顶层包装器)

  • *Model(主架构)

  • *DecoderLayer (transformer 块)

  • *Attention*MLP(特定计算单元)

备注

* 表示你的模型的唯一标识符。

  1. 关键实现细节:

所有模块在 __init__() 方法中都必须包含一个 prefix 参数。

必需的接口:

模块类型

必需的方法

*ModelForCausalLM

get_input_embeddingscompute_logitsload_weights

*模型

get_input_embeddingsload_weights

  1. 注意后端集成:

通过 from vllm.attention import Attention 导入 attention 可以自动利用 vllm-ascend 的注意力后端路由(详见:vllm_ascend/platform.py 中的 get_attn_backend_cls())。

  1. 张量并行:

使用 vllm 的并行层(如 ColumnParallelLinearVocabParallelEmbedding 等)来实现支持张量并行的模型。需要注意的是,Ascend 特有的自定义实现(如 RMSNorm、VocabParallelEmbedding 等)位于 vllm_ascend/ops/ 目录下。

参考实现模板(假定路径:vllm_ascend/models/custom_model.py):

from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors
from vllm.model_executor.sampling_metadata import SamplingMetadata

class CustomAttention(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.attn = Attention(prefix=f"{prefix}.attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement attention logic
        ...

class CustomDecoderLayer(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.self_attn = CustomAttention(vllm_config, prefix=f"{prefix}.self_attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement decoder layer
        ...

class CustomModel(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.layers = nn.ModuleList([
            CustomDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") 
            for i in range(vllm_config.model_config.hf_config.num_hidden_layers)
        ])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def load_weights(self, 
                    weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...

class CustomModelForCausalLM(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.model = CustomModel(vllm_config, prefix=f"{prefix}.model")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def compute_logits(self,
                      hidden_states: torch.Tensor,
                      sampling_metadata: SamplingMetadata) -> torch.Tensor:
        ...

    def load_weights(self, 
                    weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...

方法二:自定义已有的 vLLM 模型#

对于大多数使用场景,建议扩展已有的实现。我们在下面演示了一个示例,通过继承基类并实现一个自定义的 deepseek 模型(假定路径:vllm_ascend/models/deepseek_v2.py)。

from typing import List, Optional
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
from vllm.sequence import IntermediateTensors

class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
    # Define merged weights for quantization/efficiency
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Custom forward logic
        hidden_states = self.model(
            input_ids, 
            positions, 
            kv_caches,
            attn_metadata, 
            intermediate_tensors,
            inputs_embeds
        )
        return hidden_states

备注

完整的实现参考请见:vllm_ascend/models/deepseek_v2.py

第2步:使用 vLLM 中的 ModelRegistry 插件注册自定义模型#

vllm 提供了一种插件机制,可用于注册外部实现的模型,而无需修改其代码库。

要集成你在 vllm_ascend/models/ 目录下实现的模型:

  1. 使用相对导入在 vllm_ascend/models/__init__.py 中导入你的模型实现。

  2. 通过 vllm.ModelRegistry.register_model() 函数注册模型包装类。

参考注册模板(在 vllm_ascend/models/__init__.py 注册新模型的示例):

from vllm import ModelRegistry

def register_model():
    from .custom_model import CustomModelForCausalLM        # New custom model
    from .deepseek_v2 import ModifiedDeepseekV2ForCausalLM  # Customized Deepseek

    # For NEW architectures: Register with unique name
    ModelRegistry.register_model(
        "CustomModelForCausalLM",  # Must match config.json's 'architectures'
        "vllm_ascend.models.custom_model:CustomModelForCausalLM"
    )

    # For MODIFIED architectures: Use original name
    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",   # Original architecture identifier in vLLM
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM  "
    )

备注

vllm.ModelRegistry.register_model() 的第一个参数表示唯一的架构标识符,这个标识符必须与模型的 config.json 文件中的 architectures 匹配。

{
  "architectures": [
    "CustomModelForCausalLM"
  ],
}

第 3 步:验证#

案例 1:重载已有的 vLLM 模型架构#

如果你基于 vllm 的现有实现注册了一个自定义的模型架构(覆盖了 vllm 的原始类),在执行 vllm 的离线/在线推理(无论使用哪个模型)时,你会看到类似于 vllm/models_executor/models/registry.py 输出的警告日志。

Model architecture DeepseekV2ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend/models/deepseek_v2:CustomDeepseekV2ForCausalLM.

案例2:注册新模型架构#

如果你注册了 vllm 中不存在的新模型架构(创建一个全新的类),当前日志默认不会提供明确的确认信息。建议在 vllm/models_executor/models/registry.py 文件中的 register_model 方法末尾添加如下日志语句。

logger.info(f"model_arch: {model_arch} has been registered here!")

添加这一行之后,当你运行 vllm 离线/在线推理(使用任何模型)时,将会看到如下确认日志。

model_arch: CustomModelForCausalLM has been registered here!

该日志输出确认了你的新模型架构已成功在 vllm 中注册。

第4步:测试#

在添加新模型后,我们应对该模型进行基本功能测试(离线/在线推理)、准确率测试和性能基准测试。

更多详情请见:

第5步:更新支持的模型文档#

最后,如果以上所有步骤都已完成,你应该将新模型添加到我们的支持的模型文档中。