Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions google/genai/tests/tunings/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
from .. import pytest_helper
import pytest


VERTEX_HTTP_OPTIONS = {
'api_version': 'v1beta1',
'base_url': 'https://us-central1-autopush-aiplatform.sandbox.googleapis.com/',
}

evaluation_config=genai_types.EvaluationConfig(
metrics=[
genai_types.Metric(name="bleu", prompt_template="test prompt template")
Expand Down Expand Up @@ -158,6 +164,26 @@
),
exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
),
pytest_helper.TestTableItem(
name="test_tune_distillation",
parameters=genai_types.CreateTuningJobParameters(
base_model="meta/llama3_1@llama-3.1-8b-instruct",
training_dataset=genai_types.TuningDataset(
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
),
config=genai_types.CreateTuningJobConfig(
method="DISTILLATION",
base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
epoch_count=20,
validation_dataset=genai_types.TuningValidationDataset(
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
),
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder",
http_options=VERTEX_HTTP_OPTIONS,
),
),
exception_if_mldev="parameter is not supported in Gemini API.",
),
]

pytestmark = pytest_helper.setup(
Expand Down
117 changes: 117 additions & 0 deletions google/genai/tunings.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,24 @@ def _CreateTuningJobConfig_to_mldev(
if getv(from_object, ['beta']) is not None:
raise ValueError('beta parameter is not supported in Gemini API.')

if getv(from_object, ['base_teacher_model']) is not None:
raise ValueError(
'base_teacher_model parameter is not supported in Gemini API.'
)

if getv(from_object, ['tuned_teacher_model_source']) is not None:
raise ValueError(
'tuned_teacher_model_source parameter is not supported in Gemini API.'
)

if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
raise ValueError(
'sft_loss_weight_multiplier parameter is not supported in Gemini API.'
)

if getv(from_object, ['output_uri']) is not None:
raise ValueError('output_uri parameter is not supported in Gemini API.')

return to_object


Expand Down Expand Up @@ -246,6 +264,16 @@ def _CreateTuningJobConfig_to_vertex(
),
)

elif discriminator == 'DISTILLATION':
if getv(from_object, ['validation_dataset']) is not None:
setv(
parent_object,
['distillationSpec'],
_TuningValidationDataset_to_vertex(
getv(from_object, ['validation_dataset']), to_object, root_object
),
)

if getv(from_object, ['tuned_model_display_name']) is not None:
setv(
parent_object,
Expand Down Expand Up @@ -275,6 +303,14 @@ def _CreateTuningJobConfig_to_vertex(
getv(from_object, ['epoch_count']),
)

elif discriminator == 'DISTILLATION':
if getv(from_object, ['epoch_count']) is not None:
setv(
parent_object,
['distillationSpec', 'hyperParameters', 'epochCount'],
getv(from_object, ['epoch_count']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
discriminator = 'SUPERVISED_FINE_TUNING'
Expand All @@ -298,6 +334,14 @@ def _CreateTuningJobConfig_to_vertex(
getv(from_object, ['learning_rate_multiplier']),
)

elif discriminator == 'DISTILLATION':
if getv(from_object, ['learning_rate_multiplier']) is not None:
setv(
parent_object,
['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
getv(from_object, ['learning_rate_multiplier']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
discriminator = 'SUPERVISED_FINE_TUNING'
Expand All @@ -317,6 +361,14 @@ def _CreateTuningJobConfig_to_vertex(
getv(from_object, ['export_last_checkpoint_only']),
)

elif discriminator == 'DISTILLATION':
if getv(from_object, ['export_last_checkpoint_only']) is not None:
setv(
parent_object,
['distillationSpec', 'exportLastCheckpointOnly'],
getv(from_object, ['export_last_checkpoint_only']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
discriminator = 'SUPERVISED_FINE_TUNING'
Expand All @@ -336,6 +388,14 @@ def _CreateTuningJobConfig_to_vertex(
getv(from_object, ['adapter_size']),
)

elif discriminator == 'DISTILLATION':
if getv(from_object, ['adapter_size']) is not None:
setv(
parent_object,
['distillationSpec', 'hyperParameters', 'adapterSize'],
getv(from_object, ['adapter_size']),
)

if getv(from_object, ['batch_size']) is not None:
raise ValueError('batch_size parameter is not supported in Vertex AI.')

Expand Down Expand Up @@ -365,6 +425,16 @@ def _CreateTuningJobConfig_to_vertex(
),
)

elif discriminator == 'DISTILLATION':
if getv(from_object, ['evaluation_config']) is not None:
setv(
parent_object,
['distillationSpec', 'evaluationConfig'],
_EvaluationConfig_to_vertex(
getv(from_object, ['evaluation_config']), to_object, root_object
),
)

if getv(from_object, ['labels']) is not None:
setv(parent_object, ['labels'], getv(from_object, ['labels']))

Expand All @@ -375,6 +445,30 @@ def _CreateTuningJobConfig_to_vertex(
getv(from_object, ['beta']),
)

if getv(from_object, ['base_teacher_model']) is not None:
setv(
parent_object,
['distillationSpec', 'baseTeacherModel'],
getv(from_object, ['base_teacher_model']),
)

if getv(from_object, ['tuned_teacher_model_source']) is not None:
setv(
parent_object,
['distillationSpec', 'tunedTeacherModelSource'],
getv(from_object, ['tuned_teacher_model_source']),
)

if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
setv(
parent_object,
['distillationSpec', 'hyperParameters', 'sftLossWeightMultiplier'],
getv(from_object, ['sft_loss_weight_multiplier']),
)

if getv(from_object, ['output_uri']) is not None:
setv(parent_object, ['outputUri'], getv(from_object, ['output_uri']))

return to_object


Expand Down Expand Up @@ -920,6 +1014,14 @@ def _TuningDataset_to_vertex(
getv(from_object, ['gcs_uri']),
)

elif discriminator == 'DISTILLATION':
if getv(from_object, ['gcs_uri']) is not None:
setv(
parent_object,
['distillationSpec', 'promptDatasetUri'],
getv(from_object, ['gcs_uri']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
discriminator = 'SUPERVISED_FINE_TUNING'
Expand All @@ -939,6 +1041,14 @@ def _TuningDataset_to_vertex(
getv(from_object, ['vertex_dataset_resource']),
)

elif discriminator == 'DISTILLATION':
if getv(from_object, ['vertex_dataset_resource']) is not None:
setv(
parent_object,
['distillationSpec', 'promptDatasetUri'],
getv(from_object, ['vertex_dataset_resource']),
)

if getv(from_object, ['examples']) is not None:
raise ValueError('examples parameter is not supported in Vertex AI.')

Expand Down Expand Up @@ -1066,6 +1176,13 @@ def _TuningJob_from_vertex(
getv(from_object, ['preferenceOptimizationSpec']),
)

if getv(from_object, ['distillationSpec']) is not None:
setv(
to_object,
['distillation_spec'],
getv(from_object, ['distillationSpec']),
)

if getv(from_object, ['tuningDataStats']) is not None:
setv(
to_object, ['tuning_data_stats'], getv(from_object, ['tuningDataStats'])
Expand Down
Loading