Add Gemma3 code #36
Conversation
Hello @monatis, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello! Gemini here, providing a summary of this pull request to help everyone get oriented. This PR introduces support for the Gemma 3 model into the jaxgarden library. The core changes involve adding the necessary model components (like attention, MLP, RMS norm, and the main causal language model class), integrating these new components into the library's structure, and providing an example script to demonstrate inference. Basic unit tests for some of the core components are also included.
Highlights
- New Model Implementation: Adds the full implementation of the Gemma 3 model architecture in JAX/Flax NNX, including `Gemma3Config`, `Gemma3RMSNorm`, `Gemma3RotaryEmbedding`, `Gemma3Attention` (with GQA and sliding window support), `Gemma3MLP`, `Gemma3DecoderLayer`, and the main `Gemma3ForCausalLM` class.
- Library Integration: Integrates the new Gemma 3 model components by adding the necessary imports and exports in `jaxgarden/__init__.py` and `jaxgarden/models/__init__.py` to make them accessible within the library.
- Inference Example: Includes a new example script (`examples/gemma3_inference_example.py`) demonstrating how to initialize the Gemma 3 model with a specific configuration, load a tokenizer (using Hugging Face Transformers), prepare input, and perform text generation with KV caching. A rough usage sketch follows below.
- Unit Tests: Adds a new test file (`tests/models/test_gemma3.py`) with initial unit tests for key Gemma 3 components like `Gemma3Config`, `Gemma3RMSNorm`, and `Gemma3RotaryEmbedding` to verify basic functionality and configuration validation.
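For orientation, here is a rough sketch of what the inference example does. The constructor and call signatures below are assumptions inferred from this summary, not the PR's exact API; the example script in the PR is authoritative.

```python
import jax.numpy as jnp
from flax import nnx
from transformers import AutoTokenizer

from jaxgarden import Gemma3Config, Gemma3ForCausalLM

# Assumed constructor/call signatures; adjust to the actual jaxgarden API.
config = Gemma3Config()                              # default Gemma 3 configuration
model = Gemma3ForCausalLM(config, rngs=nnx.Rngs(0))  # random init for illustration
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")  # tokenizer used by the PR's example

inputs = tokenizer("JAX makes", return_tensors="np")
input_ids = jnp.asarray(inputs["input_ids"])
logits, cache = model(input_ids)   # forward pass; cache would seed subsequent decode steps (assumed)
print(logits.shape)                # (batch, seq_len, vocab_size)
```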
Changelog
- `examples/gemma3_inference_example.py`
  - Added a new file.
  - Implemented a script demonstrating Gemma3 model initialization and text generation using KV caching.
- `jaxgarden/__init__.py`
  - Added imports for Gemma3 model components (`Gemma3Attention`, `Gemma3Config`, `Gemma3ForCausalLM`, `Gemma3MLP`, `Gemma3RMSNorm`, `Gemma3RotaryEmbedding`).
  - Added Gemma3 components to the `__all__` list for public export.
- `jaxgarden/models/__init__.py`
  - Added imports for Gemma3 model components.
  - Added Gemma3 components to the `__all__` list.
- `jaxgarden/models/base.py`
  - Added a `# type: ignore` comment to the `huggingface_hub.snapshot_download` import to suppress type-checking errors.
- `jaxgarden/models/gemma3.py`
  - Added a new file containing the full Gemma3 model implementation.
  - Defined the `Gemma3Config` dataclass with detailed model parameters and validation.
  - Implemented the `Gemma3RMSNorm` module.
  - Implemented the `Gemma3RotaryEmbedding` module with support for various scaling types (linear, dynamic, yarn, longrope, llama3) and local attention.
  - Implemented the `Gemma3Attention` module supporting Grouped Query Attention (GQA), sliding window attention, and attention logits soft capping.
  - Implemented the `Gemma3MLP` module with GeGLU activation.
  - Implemented `Gemma3DecoderLayer`, combining attention and MLP with pre/post normalization.
  - Implemented `Gemma3ForCausalLM` as the main model class, handling embeddings, layer stacking, final normalization, and logits calculation with weight tying and soft capping.
  - Included logic for handling position IDs and attention masks, including extending masks for KV caching.
- `tests/models/test_gemma3.py`
  - Added a new file for Gemma3 unit tests.
  - Added `test_gemma3_config` to verify configuration initialization.
  - Added `test_gemma3_rms_norm` to test the RMSNorm layer.
  - Added `test_gemma3_rotary_embedding` to test the RoPE module.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Code Review
This pull request introduces support for the Gemma3 model, which is a significant addition to the jaxgarden library. The implementation includes the core model components, an inference example, and initial unit tests. The code is generally well-structured and the configuration options for Gemma3 are comprehensive.
I've identified a few areas that need attention, including a critical issue related to attention masking that affects model correctness, a high-severity issue in RoPE configuration validation, and some medium-severity points regarding tokenizer choice, attention type determination, and test coverage. Addressing these will ensure the Gemma3 implementation is robust and correct.
Summary of Findings
- Critical Attention Masking Issue: The model's causal and sliding window attention masks are not being applied correctly in `Gemma3ForCausalLM`, potentially leading to incorrect model behavior as future tokens might be attended to.
- RoPE Configuration Validation: The validation for the `"factor"` parameter in `Gemma3Config` for RoPE scaling is overly general and may incorrectly flag valid configurations for certain RoPE types.
- Tokenizer Choice in Example: The Gemma3 inference example uses a `gemma-2b` tokenizer. Clarification is needed on whether this is appropriate or a placeholder.
- Attention Type Determination Logic: There's a potential redundancy/conflict in how local vs. global attention is determined between `Gemma3Config`, `Gemma3DecoderLayer`, and `Gemma3Attention`.
- Unit Test Coverage: Key new components like `Gemma3Attention`, `Gemma3MLP`, `Gemma3DecoderLayer`, and `Gemma3ForCausalLM` lack unit tests. Adding these would significantly improve confidence in the implementation.
Merge Readiness
This pull request makes a valuable contribution by adding Gemma3 support. However, due to the critical issue identified with attention masking, along with other high- and medium-severity concerns, I recommend that these changes not be merged in their current state. Addressing the identified issues, particularly the masking problem, is essential for the correctness of the Gemma3 model. Once these are resolved, the PR will be in much better shape. As an AI reviewer, I am not authorized to approve pull requests; please ensure further review and approval from authorized maintainers after the suggested changes are made.
```python
# Prepare attention_mask (padding mask)
# This mask should cover the entire kv_seq_len for keys/values
if attention_mask is not None:
    # The input attention_mask corresponds to input_ids [B, S]
    # We need to extend it for the cached keys/values.
    # Assume cached tokens were valid.
    if cache is not None:
        # Create mask for cached part (all True)
        cache_mask = jnp.ones((batch_size, cache_len), dtype=jnp.bool_)
        # Concatenate with the input mask
        padding_mask_2d = jnp.concatenate([cache_mask, attention_mask], axis=1)
    else:
        padding_mask_2d = attention_mask  # No cache, use input mask directly

    # Final shape check for the 2D padding mask
    if padding_mask_2d.shape != (batch_size, kv_seq_len):
        raise ValueError(
            f"Constructed 2D padding mask shape {padding_mask_2d.shape} "
            f"does not match expected ({batch_size}, {kv_seq_len})"
        )
else:
    # If no mask provided, assume all tokens are valid
    padding_mask_2d = jnp.ones((batch_size, kv_seq_len), dtype=jnp.bool_)

# Reshape the 2D boolean padding mask to 4D log-mask for attention calculation
# Shape: [B, 1, 1, kv_len]. Log-mask: 0.0 for attend, -inf for ignore.
# Use the validated `padding_mask_2d` here
attn_mask_4d = jnp.where(
    padding_mask_2d[:, None, None, :], 0.0, jnp.finfo(self.config.dtype).min
)

# --- Pass through Decoder Layers ---
next_cache_list = []
for i, layer in enumerate(self.layers):
    layer_cache = cache[i] if cache is not None else None
    hidden_states, updated_layer_cache = layer(
        hidden_states,
        position_ids,
        attention_mask=attn_mask_4d,  # Pass the 4D log-mask [B, 1, 1, kv_len]
        cache=layer_cache,
        deterministic=deterministic,
    )
```
The current attention masking logic in Gemma3ForCausalLM.__call__ appears to inadvertently disable the causal and sliding window masks that are intended to be generated within Gemma3Attention.
Here's a breakdown:
- `Gemma3ForCausalLM` (lines 696-701) constructs `attn_mask_4d`. This mask is derived from the 2D `padding_mask_2d` and represents an additive padding mask (`[B, 1, 1, kv_len]`). It does not include causal or sliding window information.
- This `attn_mask_4d` is passed to `Gemma3DecoderLayer`, and then to `Gemma3Attention.__call__` as `attention_mask`.
- In `Gemma3Attention.__call__` (lines 468-470), the condition `if attention_mask is not None and attention_mask.ndim == 4:` is met. Consequently, the code executes `attn_weights = attn_weights + attention_mask.astype(self.config.dtype)`.
- This bypasses the `else` block (lines 472-491) in `Gemma3Attention`, which is responsible for calling `self._make_sliding_window_mask()` to generate the crucial causal/sliding window mask (`attn_internal_mask`) and combine it with any 2D padding mask.
As a result, the model currently operates without enforcing causality or the configured sliding window attention, which is a critical correctness issue.
To fix this, Gemma3ForCausalLM should likely pass the 2D boolean padding_mask_2d to the attention layer, allowing Gemma3Attention to correctly combine it with its internally generated causal/sliding window mask. The attention_mask parameter type hint in Gemma3Attention also suggests it expects a boolean mask if its internal logic is to be used.
A possible approach:
- Modify `Gemma3ForCausalLM.__call__` to pass `padding_mask_2d` (the 2D boolean mask) as the `attention_mask` to `layer()`.
- Adjust `Gemma3Attention.__call__` to expect this 2D boolean mask when the 4D path isn't taken, or ensure the 4D mask passed already incorporates all necessary masking aspects (causal, sliding, padding).
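As a minimal sketch of the second option, the snippet below combines a 2D boolean padding mask with a causal (and optional sliding-window) mask into a single additive 4D log-mask. The function name and exact shapes are illustrative, not the PR's actual helpers:

```python
import jax.numpy as jnp

def build_attention_log_mask(padding_mask_2d, q_len, kv_len, sliding_window=None,
                             dtype=jnp.float32):
    """Illustrative: merge a [batch, kv_len] boolean padding mask with a causal
    (and optionally sliding-window) mask into a [batch, 1, q_len, kv_len] log-mask."""
    # Query positions are offset so the last query aligns with the last key,
    # which covers the KV-cache case where q_len < kv_len.
    q_pos = jnp.arange(q_len)[:, None] + (kv_len - q_len)   # [q_len, 1]
    k_pos = jnp.arange(kv_len)[None, :]                     # [1, kv_len]
    allowed = k_pos <= q_pos                                 # causal: keys at or before the query
    if sliding_window is not None:
        allowed = allowed & ((q_pos - k_pos) < sliding_window)  # keep only the local window
    # A key must also correspond to a real (non-padded) token.
    combined = allowed[None, None, :, :] & padding_mask_2d[:, None, None, :]
    # 0.0 where attention is allowed, a large negative value where it is masked.
    return jnp.where(combined, 0.0, jnp.finfo(dtype).min)
```

Whether this combination lives in `Gemma3ForCausalLM.__call__` or inside `Gemma3Attention`, the key point is that the 4D mask actually reaching the attention weights must carry causal, sliding-window, and padding information together.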
```python
if rope_type != "default" and "factor" not in self.rope_scaling:
    raise ValueError(f"RoPE scaling type {rope_type} requires 'factor' parameter")
```
The validation check for the "factor" parameter in rope_scaling seems a bit too general. It currently checks if rope_type != "default" and "factor" not in self.rope_scaling:. However, RoPE types like "dynamic", "yarn", "longrope", and "llama3" have their own specific parameter requirements (e.g., original_max_position_embeddings, low_freq_factor) and might not use a generic "factor" key. This could lead to an incorrect ValueError for these types.
Could this check be refined to only enforce the presence of "factor" for RoPE types that explicitly require it, like "linear"? Other types seem to have their specific checks later in the __post_init__ method.
Suggested change:

```diff
-if rope_type != "default" and "factor" not in self.rope_scaling:
-    raise ValueError(f"RoPE scaling type {rope_type} requires 'factor' parameter")
+EXPECTS_FACTOR = ["linear"]  # Add other types if they also strictly need a generic 'factor'
+if rope_type in EXPECTS_FACTOR and "factor" not in self.rope_scaling:
+    raise ValueError(f"RoPE scaling type {rope_type} requires 'factor' parameter")
```
```python
# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
```
The example script for Gemma3 is using a tokenizer from "google/gemma-2b". Is this intentional? While Gemma3 might be compatible with Gemma2's tokenizer, using a tokenizer specifically intended or validated for Gemma3 (if available, or once available) would be preferable to ensure optimal performance and correctness. Could you clarify if this is a placeholder or the recommended tokenizer for Gemma3 currently?
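If a Gemma 3 tokenizer is the intended choice, a hedged alternative might look like the following. The `google/gemma-3-1b-it` model id is an assumption about what is published on the Hugging Face Hub and may require gated access:

```python
from transformers import AutoTokenizer

# Assumption: a Gemma 3 checkpoint such as "google/gemma-3-1b-it" exposes a
# matching tokenizer on the Hugging Face Hub (access may be gated).
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
```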
```python
        attention_type: str,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = attention_type
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim

        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.scaling = config.query_pre_attn_scalar
        self.attn_logit_soft_cap = config.attn_logit_soft_cap
        self.is_local_attn = (
            self.config.layer_types[layer_idx] == "sliding_attention"
            if self.config.layer_types
            else (layer_idx % 2 != 0) and (config.sliding_window is not None)
        )
```
There appears to be a slight redundancy or potential for conflicting logic in determining if a layer uses local/sliding attention versus full attention.
- `Gemma3Config.__post_init__` (lines 139-142) generates `self.layer_types` based on `self.sliding_window_pattern`.
- `Gemma3DecoderLayer.__init__` (line 570) defines an `attention_type` string (`"global"` or `"local"`) based on `layer_idx % 2` (even/odd alternation).
- `Gemma3Attention.__init__` receives this `attention_type` string as a parameter (line 306).
- However, `Gemma3Attention.is_local_attn` (lines 321-325) correctly determines its behavior primarily from `self.config.layer_types[layer_idx]`, which should be the single source of truth derived from `sliding_window_pattern`.
The attention_type parameter passed to Gemma3Attention isn't directly used if self.config.layer_types is present (which it always should be after config initialization). If sliding_window_pattern is not 2, the logic in Gemma3DecoderLayer for attention_type would differ from config.layer_types.
Consider removing the attention_type: str parameter from Gemma3Attention.__init__ and relying solely on config.layer_types[self.layer_idx] within Gemma3Attention to determine is_local_attn. This would simplify the logic and ensure consistency with Gemma3Config.
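A minimal sketch of that simplification, following the field names in the excerpt above (the surrounding constructor signature is assumed):

```python
# Inside Gemma3Attention.__init__, after self.config and self.layer_idx are set:
# rely on config.layer_types as the single source of truth for locality,
# instead of the separately computed attention_type string.
self.is_local_attn = (
    self.config.layer_types is not None
    and self.config.layer_types[self.layer_idx] == "sliding_attention"
)
```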
| """Tests for Gemma3 model.""" | ||
|
|
||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| from flax import nnx | ||
|
|
||
| from jaxgarden.models.gemma3 import ( | ||
| Gemma3Attention, | ||
| Gemma3Config, | ||
| Gemma3DecoderLayer, | ||
| Gemma3ForCausalLM, | ||
| Gemma3MLP, | ||
| Gemma3RMSNorm, | ||
| Gemma3RotaryEmbedding, | ||
| ) | ||
|
|
||
|
|
||
| def test_gemma3_config(): | ||
| """Test Gemma3Config initialization and validation.""" | ||
| config = Gemma3Config() | ||
| assert config.num_attention_heads % config.num_key_value_heads == 0 | ||
| assert config.head_dim == 256 # Default head_dim | ||
|
|
||
|
|
||
| def test_gemma3_rms_norm(): | ||
| """Test RMSNorm layer.""" | ||
| rng = nnx.Rngs(0) | ||
| dim = 32 | ||
| batch_size = 2 | ||
| seq_len = 4 | ||
|
|
||
| norm = Gemma3RMSNorm(dim, eps=1e-6, rngs=rng) | ||
| x = jnp.ones((batch_size, seq_len, dim), dtype=jnp.float32) | ||
| out = norm(x) | ||
|
|
||
| assert out.shape == x.shape | ||
| # Output should be normalized | ||
| variance = jnp.mean(jnp.square(out), axis=-1) | ||
| np.testing.assert_allclose(variance, 1.0, rtol=1e-5) | ||
|
|
||
|
|
||
| def test_gemma3_rotary_embedding(): | ||
| """Test RotaryEmbedding module.""" | ||
| dim = 32 | ||
| batch_size = 2 | ||
| seq_len = 4 | ||
| num_heads = 2 | ||
|
|
||
| rope = Gemma3RotaryEmbedding(dim=dim, max_position_embeddings=8192) | ||
| x = jnp.ones((batch_size, seq_len, num_heads, dim), dtype=jnp.float32) | ||
| position_ids = jnp.arange(seq_len, dtype=jnp.int32)[None, :] # [1, seq_len] | ||
| position_ids = jnp.broadcast_to(position_ids, (batch_size, seq_len)) | ||
|
|
||
| out = rope(x, position_ids) | ||
| assert out.shape == (batch_size, seq_len, num_heads, dim) |
The initial set of tests for Gemma3 components is a good start, covering Gemma3Config, Gemma3RMSNorm, and Gemma3RotaryEmbedding.
However, to ensure the robustness of the new Gemma3 implementation, it would be highly beneficial to add tests for the other core components introduced in this PR, such as:
- `Gemma3Attention`
- `Gemma3MLP`
- `Gemma3DecoderLayer`
- `Gemma3ForCausalLM` (especially an end-to-end forward pass)
These tests should ideally verify not just shapes but also basic numerical correctness, perhaps against simple, known inputs or reference values if possible. Expanding test coverage will help catch regressions and confirm the functionality of these key modules.
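For example, an end-to-end shape test might look like the sketch below, reusing the imports from the test file above. The `Gemma3Config` field names and the model's call signature are assumptions and would need adjusting to the actual API:

```python
def test_gemma3_forward_shape():
    """Sketch of an end-to-end forward-pass test (config fields and call
    signature are assumed, not taken from the PR)."""
    config = Gemma3Config(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        intermediate_size=128,
    )
    model = Gemma3ForCausalLM(config, rngs=nnx.Rngs(0))
    input_ids = jnp.ones((2, 8), dtype=jnp.int32)

    # Assumption: the model returns (logits, cache) like the decoder layers do.
    logits, _ = model(input_ids)
    assert logits.shape == (2, 8, config.vocab_size)
```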
This pull request introduces support for the Gemma3 model in the `jaxgarden` library, including its implementation, integration, and testing. The most significant changes involve adding the Gemma3 model components, creating an inference example script, and implementing unit tests for the model.

Gemma3 Model Implementation and Integration:
- Added the Gemma3 model components (`Gemma3Attention`, `Gemma3Config`, `Gemma3ForCausalLM`, `Gemma3MLP`, `Gemma3RMSNorm`, and `Gemma3RotaryEmbedding`) to the library's `models` module and updated the `__init__.py` files to include these components in the public API. [1] [2] [3] [4]

Example Script for Gemma3 Inference:
- Added a new example script, `examples/gemma3_inference_example.py`, which demonstrates how to use the Gemma3 model for causal language modeling, including model initialization, tokenization, and text generation.

Unit Tests for Gemma3 Model:
- Added a new test file, `tests/models/test_gemma3.py`, which includes tests for Gemma3 components such as `Gemma3Config`, `Gemma3RMSNorm`, and `Gemma3RotaryEmbedding` to ensure proper functionality and validation.

Minor Improvements:
- Updated the `jaxgarden/models/base.py` file to include a type ignore comment for the `huggingface_hub.snapshot_download` import to suppress type-checking errors.