Proposal Summary
Feature Request: Simplified Model Architecture Integration via Class-Based API
Summary
Introduce a class-based API in Olive that mirrors the modularity and extensibility of MLX’s Model class (e.g., [gpt_oss.py](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/gpt_oss.py)), enabling developers to define and register new model architectures with minimal boilerplate. This would significantly reduce the complexity currently required in ONNX Runtime GenAI’s [builder.py](https://github.com/microsoft/onnxruntime-genai/blob/main/src/python/py/models/builder.py).
Problem Statement
The current ONNX Runtime GenAI workflow for integrating custom models involves:
- Deep coupling with ONNX IR graph construction
- Manual configuration of tensor shapes, types, and EP-specific attributes
- Extensive boilerplate for rotary embeddings, MoE routing, LayerNorm variants, and quantization
This makes experimentation and extension prohibitively complex for researchers and developers; the sketch below gives a flavor of the graph-level boilerplate involved.
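For context, the snippet below is a minimal, illustrative sketch of the kind of raw ONNX graph construction this style of workflow entails, here for a single linear projection. It uses the standard `onnx.helper` API directly and is not taken from builder.py.

```python
import numpy as np
from onnx import TensorProto, checker, helper, numpy_helper

# Hand-build one linear projection (MatMul + Add) at the ONNX IR level.
# Every tensor name, dtype, and shape must be declared explicitly.
hidden = 768
weight = numpy_helper.from_array(
    np.zeros((hidden, hidden), dtype=np.float32), name="proj.weight"
)
bias = numpy_helper.from_array(
    np.zeros((hidden,), dtype=np.float32), name="proj.bias"
)

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["batch", "seq", hidden])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, ["batch", "seq", hidden])

matmul = helper.make_node("MatMul", ["x", "proj.weight"], ["proj.mm"], name="proj_mm")
add = helper.make_node("Add", ["proj.mm", "proj.bias"], ["y"], name="proj_add")

graph = helper.make_graph(
    [matmul, add], "single_projection", [x], [y], initializer=[weight, bias]
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
checker.check_model(model)
```

Multiplying this across attention, rotary embeddings, MoE routing, and per-EP attributes is what makes the current approach hard to extend.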
MLX Reference Implementation
MLX’s architecture registry (see [mlx_lm/models](https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models)) allows:
- Plug-and-play model definitions via simple Python classes
- Clean separation of model logic (e.g., TransformerBlock, AttentionBlock, MLPBlock)
- Dynamic cache handling and mask generation
- Easy integration of MoE, SwiGLU, rotary embeddings, and sliding window attention
Example: GptOssMoeModel encapsulates all architectural components and can be instantiated with a single ModelArgs dataclass.
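For readers unfamiliar with the pattern, here is a heavily simplified sketch of the mlx_lm style being referenced. It is not the actual gpt_oss.py code; the layer contents and field defaults are illustrative only.

```python
from dataclasses import dataclass

import mlx.nn as nn


@dataclass
class ModelArgs:
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    vocab_size: int = 32000


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attn = nn.MultiHeadAttention(args.hidden_size, args.num_attention_heads)
        self.mlp = nn.Sequential(
            nn.Linear(args.hidden_size, 4 * args.hidden_size),
            nn.GELU(),
            nn.Linear(4 * args.hidden_size, args.hidden_size),
        )
        self.norm1 = nn.RMSNorm(args.hidden_size)
        self.norm2 = nn.RMSNorm(args.hidden_size)

    def __call__(self, x, mask=None):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, mask=mask)
        return x + self.mlp(self.norm2(x))


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, tokens, mask=None):
        x = self.embed(tokens)
        for layer in self.layers:
            x = layer(x, mask=mask)
        return self.lm_head(self.norm(x))
```

The point is that the entire architecture is described at this level of abstraction; graph construction, shapes, and dtypes are handled by the framework.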
Proposed Solution
Introduce a high-level Olive API with:
- A BaseModel class that abstracts ONNX IR graph creation
- Decorators or registration hooks for defining attention, MLP, and embedding blocks
- Automatic shape/type inference from HuggingFace configs (see the sketch after this list)
- Optional integration with Olive’s quantization and export pipelines
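As a rough illustration of the shape/type-inference bullet, a helper in such an API could derive structural parameters directly from a HuggingFace config rather than requiring them to be restated per architecture. The names `InferredSpec` and `spec_from_hf_config` below are invented for this sketch and are not an existing Olive API.

```python
from dataclasses import dataclass

from transformers import AutoConfig


@dataclass
class InferredSpec:
    hidden_size: int
    num_heads: int
    num_layers: int
    vocab_size: int


def spec_from_hf_config(model_id: str) -> InferredSpec:
    # Hypothetical helper: pull shape information from the HuggingFace config
    # so block definitions do not have to hard-code tensor shapes.
    cfg = AutoConfig.from_pretrained(model_id)
    return InferredSpec(
        hidden_size=cfg.hidden_size,
        num_heads=cfg.num_attention_heads,
        num_layers=cfg.num_hidden_layers,
        vocab_size=cfg.vocab_size,
    )
```

Block definitions registered with the proposed API could then consume such a spec directly.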
This would allow users to define models like:
```python
class MyTransformer(BaseModel):
    def __init__(self, config):
        self.attn = MultiHeadAttention(config)
        self.mlp = SwiGLU(config)
        ...
```

And register them with:
```python
olive.register_model("my_transformer", MyTransformer)
```

Impact
- Reduces onboarding friction for model developers
- Encourages experimentation with novel architectures
- Aligns Olive with modern modular frameworks like MLX and HuggingFace
- Enables community contributions of reusable model blocks
- Resolves recurring issues around onboarding new models and new architectures
- Creates parity with MLX in terms of model support
What component(s) does this request affect?
- OliveModels
- OliveSystems
- OliveEvaluator
- Metrics
- Engine
- Passes
- Other