Skip to content

Commit 1c7e30d

Browse files
committed
[UPDATE]: add configuration model
1 parent 65e4649 commit 1c7e30d

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

src/config/settings.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,35 +28,35 @@ class Settings(BaseSettings):
2828
HOST: str = "0.0.0.0"
2929
PORT: int = 7860
3030
WORKERS: int = 1
31-
RELOAD: bool = False # Auto-reload on code changes (dev only)
31+
RELOAD: bool = False
3232

3333
# Model Configuration
3434
MODEL_CONFIG_PATH: str = "src/config/models.yaml"
3535
MODEL_CACHE_DIR: str = "./model_cache"
36-
PRELOAD_MODELS: bool = True # Load all models at startup
36+
PRELOAD_MODELS: bool = True
3737

3838
# Request Limits
39-
MAX_TEXT_LENGTH: int = 32000 # Maximum characters per text
40-
MAX_BATCH_SIZE: int = 100 # Maximum texts per batch request
41-
REQUEST_TIMEOUT: int = 30 # Request timeout in seconds
39+
MAX_TEXT_LENGTH: int = 32000
40+
MAX_BATCH_SIZE: int = 100
41+
REQUEST_TIMEOUT: int = 30
4242

4343
# Logging
4444
LOG_LEVEL: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
45-
LOG_FILE: bool = False # Write logs to file
45+
LOG_FILE: bool = True # Write logs to file
4646
LOG_DIR: str = "logs"
4747

48-
# CORS (if needed for web frontends)
4948
CORS_ENABLED: bool = False
5049
CORS_ORIGINS: list[str] = ["*"]
5150

5251
# Model Settings
52+
DEVICE: str = "cpu" # "cpu" or "cuda"
5353
TRUST_REMOTE_CODE: bool = True # For models requiring remote code
5454

5555
model_config = SettingsConfigDict(
5656
env_file=".env",
5757
env_file_encoding="utf-8",
5858
case_sensitive=True,
59-
extra="ignore", # Ignore extra fields in .env
59+
extra="ignore",
6060
)
6161

6262
@property
@@ -86,10 +86,8 @@ def validate_paths(self) -> None:
8686
f"Model configuration file not found: {self.MODEL_CONFIG_PATH}"
8787
)
8888

89-
# Create cache directory if it doesn't exist
9089
Path(self.MODEL_CACHE_DIR).mkdir(parents=True, exist_ok=True)
9190

92-
# Create log directory if logging to file
9391
if self.LOG_FILE:
9492
Path(self.LOG_DIR).mkdir(parents=True, exist_ok=True)
9593

src/models/embeddings/dense.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sentence_transformers import SentenceTransformer
1010
from loguru import logger
1111

12+
from src.config.settings import get_settings
1213
from src.core.base import BaseEmbeddingModel
1314
from src.core.config import ModelConfig
1415
from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
@@ -36,6 +37,7 @@ def __init__(self, config: ModelConfig):
3637
"""
3738
super().__init__(config)
3839
self.model: Optional[SentenceTransformer] = None
40+
self.settings = get_settings()
3941

4042
def load(self) -> None:
4143
"""
@@ -52,7 +54,9 @@ def load(self) -> None:
5254

5355
try:
5456
self.model = SentenceTransformer(
55-
self.config.name, device="cpu", trust_remote_code=True
57+
self.config.name,
58+
device=self.settings.DEVICE,
59+
trust_remote_code=self.settings.TRUST_REMOTE_CODE
5660
)
5761
self._loaded = True
5862
logger.success(f"✓ Loaded dense model: {self.model_id}")

src/models/embeddings/sparse.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sentence_transformers import SparseEncoder
1010
from loguru import logger
1111

12+
from src.config.settings import get_settings
1213
from src.core.base import BaseEmbeddingModel
1314
from src.core.config import ModelConfig
1415
from src.core.exceptions import ModelLoadError, EmbeddingGenerationError
@@ -36,6 +37,7 @@ def __init__(self, config: ModelConfig):
3637
"""
3738
super().__init__(config)
3839
self.model: Optional[SparseEncoder] = None
40+
self.settings = get_settings()
3941

4042
def load(self) -> None:
4143
"""
@@ -51,7 +53,11 @@ def load(self) -> None:
5153
logger.info(f"Loading sparse embedding model: {self.config.name}")
5254

5355
try:
54-
self.model = SparseEncoder(self.config.name)
56+
self.model = SparseEncoder(
57+
self.config.name,
58+
device=self.settings.DEVICE,
59+
trust_remote_code=self.settings.TRUST_REMOTE_CODE,
60+
)
5561
self._loaded = True
5662
logger.success(f"✓ Loaded sparse model: {self.model_id}")
5763

0 commit comments

Comments
 (0)