
Commit 98673c2

fix: add model types
1 parent c9b562e commit 98673c2

11 files changed: +209 additions, -429 deletions

src/AudioModule.ts

Lines changed: 4 additions & 23 deletions

@@ -31,13 +31,10 @@ export class AudioModule extends BaseModule {
 
     const { ttsModels = [], lazyLoad = false } = options;
 
-    // Check hardware capabilities
-    const capabilities = await this.webgpuChecker.check();
-
     this.progressTracker.update({
       status: 'init',
       type: 'audio_module',
-      message: `Hardware check: WebGPU ${capabilities.isWebGPUSupported ? 'available' : 'not available'}`
+      // message: `Hardware check: WebGPU ${capabilities.isWebGPUSupported ? 'available' : 'not available'}` FIX ME:
     });
 
     // Load first model if specified and not using lazy loading
@@ -92,26 +89,13 @@ export class AudioModule extends BaseModule {
     this.activeModel = model;
 
     try {
-      // Check hardware capabilities
-      const capabilities = await this.webgpuChecker.check();
-
-      // Get optimal config
-      const config = await this.getOptimalDeviceConfig();
-
-      // For TTS models, always use CPU in Node.js environment or proper fallback
-      const isNode = this.isNodeEnvironment();
-
-      // Determine device explicitly
-      const device = isNode ? "cpu" :
-        !capabilities.isWebGPUSupported ? "wasm" : "webgpu";
-
       // Initial progress message
       this.progressTracker.update({
         status: 'loading',
         type: 'tts_model',
         progress: 0,
         percentComplete: 0,
-        message: `Loading TTS model ${model} (device="${device}", dtype="${config.dtype}")`
+        // message: `Loading TTS model ${model} (device="${device}", dtype="${config.dtype}")` FIX ME:
       });
 
       // Load the model directly by name
@@ -126,8 +110,8 @@ export class AudioModule extends BaseModule {
           ...progress
         });
       },
-      device: device,
-      dtype: config.dtype
+      device: 'webgpu', // device, FIX ME:
+      dtype: 'fp16', // config.dtype FIX ME:
     });
 
     // Register the model as loaded
@@ -281,9 +265,6 @@ export class AudioModule extends BaseModule {
       this.activeModel = null;
     }
 
-    // Try to trigger garbage collection
-    this.triggerGC();
-
     this.progressTracker.update({
       status: 'offloaded',
       type: 'tts_model',
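
The hardcoded device: 'webgpu' and dtype: 'fp16' above are stopgaps flagged FIX ME. A minimal sketch of how they could instead be derived from the new catalog in src/Models.ts; the resolveLoadOptions helper and the 'webgpu' default are assumptions, not part of this commit:

import { modelCatalog, DeviceConfig, Model } from './Models';

// Hypothetical helper (not in this commit): prefer the catalog's quantization
// for known models, otherwise fall back to the values currently hardcoded.
function resolveLoadOptions(model: string): DeviceConfig {
  const entry: Model | undefined = (modelCatalog as Record<string, Model>)[model];
  return {
    // Assumed default until the hardware check is restored.
    device: 'webgpu' as DeviceConfig['device'],
    // Catalog quantization when the model is known, else the current stopgap.
    dtype: entry?.quantization ?? 'fp16',
  };
}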

src/BaseModule.ts

Lines changed: 1 addition & 44 deletions

@@ -3,10 +3,8 @@
  * Provides common structure for all modules like generation, embeddings, etc.
  */
 
-import { WebGPUChecker } from './WebGPUChecker';
 import { ProgressTracker } from './ProgressTracker';
-import { tryGarbageCollection, EnvironmentInfo } from './utils';
-import { DeviceConfig } from './types';
+import { DeviceConfig } from './Models';
 
 /**
  * Base interface for TinyLM modules
@@ -26,8 +24,6 @@ export interface TinyLMModule {
 export abstract class BaseModule implements TinyLMModule {
   protected tinyLM: any;
   protected progressTracker: ProgressTracker;
-  protected webgpuChecker: WebGPUChecker;
-  protected environment: EnvironmentInfo;
 
   /**
    * Create a new module
@@ -36,8 +32,6 @@ export abstract class BaseModule implements TinyLMModule {
   constructor(tinyLM: any) {
     this.tinyLM = tinyLM;
     this.progressTracker = tinyLM.getProgressTracker();
-    this.webgpuChecker = tinyLM.getWebGPUChecker();
-    this.environment = tinyLM.getEnvironment();
   }
 
   /**
@@ -48,41 +42,4 @@ export abstract class BaseModule implements TinyLMModule {
   async init(options: Record<string, any> = {}): Promise<void> {
     // Base initialization code
   }
-
-  /**
-   * Get optimal device configuration based on capabilities
-   * @returns {Promise<DeviceConfig>} Device configuration
-   */
-  protected async getOptimalDeviceConfig(): Promise<DeviceConfig> {
-    const capabilities = await this.webgpuChecker.check();
-    return this.webgpuChecker.getOptimalConfig();
-  }
-
-  /**
-   * Trigger garbage collection if available
-   */
-  protected triggerGC(): void {
-    try {
-      // Use the utility function
-      tryGarbageCollection();
-    } catch (error) {
-      // Ignore errors
-    }
-  }
-
-  /**
-   * Check if running in a Node.js environment
-   * @returns {boolean} True if in Node.js environment
-   */
-  protected isNodeEnvironment(): boolean {
-    return this.environment.isNode;
-  }
-
-  /**
-   * Check if running in a browser environment
-   * @returns {boolean} True if in browser environment
-   */
-  protected isBrowserEnvironment(): boolean {
-    return this.environment.isBrowser;
-  }
 }
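
With the WebGPU checker and environment helpers gone from BaseModule, a subclass that still needs them can fetch them from the injected TinyLM instance, as the removed constructor lines did. A sketch under that assumption (getWebGPUChecker and getEnvironment are the accessors the old constructor called; their continued existence on TinyLM is assumed, not confirmed by this commit):

import { BaseModule } from './BaseModule';

class HardwareAwareModule extends BaseModule {
  // Sketch only: pull the checker on demand instead of caching it in the base class.
  protected async canUseWebGPU(): Promise<boolean> {
    const capabilities = await this.tinyLM.getWebGPUChecker().check();
    return capabilities.isWebGPUSupported;
  }

  // Sketch only: stand-in for the removed isNodeEnvironment() helper.
  protected isNode(): boolean {
    return this.tinyLM.getEnvironment().isNode;
  }
}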

src/EmbeddingsModule.ts

Lines changed: 11 additions & 17 deletions

@@ -152,25 +152,22 @@ export class EmbeddingsModule extends BaseModule {
     });
 
     try {
-      // Check WebGPU capabilities and get optimal config
-      const capabilities = await this.webgpuChecker.check();
-      const config = await this.getOptimalDeviceConfig();
-
       let modelInfo: EmbeddingModelInfo;
 
       // Strategy based on WebGPU availability
-      if (capabilities.isWebGPUSupported) {
+      // if (capabilities.isWebGPUSupported) { FIX ME:
+      if (true) {
         // WebGPU is available - try direct tokenizer+model approach first
         try {
-          modelInfo = await this._loadWithTokenizerAndModel(model, config);
+          modelInfo = await this._loadWithTokenizerAndModel(model);
         } catch (directError) {
           // Direct approach failed, fall back to pipeline
           console.warn('Direct model loading failed, falling back to pipeline:', directError);
-          modelInfo = await this._loadWithPipeline(model, config);
+          modelInfo = await this._loadWithPipeline(model);
         }
       } else {
         // WebGPU not available - use pipeline approach for simplicity
-        modelInfo = await this._loadWithPipeline(model, config);
+        modelInfo = await this._loadWithPipeline(model);
       }
 
       // Validate dimensions if specified
@@ -207,7 +204,8 @@ export class EmbeddingsModule extends BaseModule {
    * @param {any} config - Device configuration
    * @returns {Promise<EmbeddingModelInfo>} Model information
    */
-  private async _loadWithTokenizerAndModel(model: string, config: any): Promise<EmbeddingModelInfo> {
+  // private async _loadWithTokenizerAndModel(model: string, config: DeviceConfig): Promise<EmbeddingModelInfo> { FIX ME:
+  private async _loadWithTokenizerAndModel(model: string): Promise<EmbeddingModelInfo> {
     // Progress callback for loading
     const progressCallback = (component: string) => (progress: any) => {
       this.progressTracker.update({
@@ -228,8 +226,6 @@ export class EmbeddingsModule extends BaseModule {
 
     // Load model with optimal configuration
     const embeddingModel = await AutoModel.from_pretrained(model, {
-      ...(config.device ? { device: config.device } : {}),
-      ...(config.dtype ? { dtype: config.dtype } : {}),
      progress_callback: progressCallback('model')
     });
 
@@ -250,11 +246,12 @@ export class EmbeddingsModule extends BaseModule {
    * @param {any} config - Device configuration
    * @returns {Promise<EmbeddingModelInfo>} Model information
    */
-  private async _loadWithPipeline(model: string, config: any): Promise<EmbeddingModelInfo> {
+  // private async _loadWithPipeline(model: string, config: any): Promise<EmbeddingModelInfo> {
+  private async _loadWithPipeline(model: string): Promise<EmbeddingModelInfo> {
     // Load using feature-extraction pipeline
     const embeddingPipeline = await pipeline('feature-extraction', model, {
-      ...(config.device ? { device: config.device } : {}),
-      ...(config.dtype ? { dtype: config.dtype } : {}),
+      // ...(config.device ? { device: config.device } : {}),
+      // ...(config.dtype ? { dtype: config.dtype } : {}),
       progress_callback: (progress: any) => {
         this.progressTracker.update({
           status: 'loading',
@@ -459,9 +456,6 @@ export class EmbeddingsModule extends BaseModule {
     // Remove from registry
     this.embeddingModels.delete(model);
 
-    // Try to trigger garbage collection
-    this.triggerGC();
-
     this.progressTracker.update({
       status: 'offloaded',
       type: 'embedding_model',
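
The if (true) placeholder above always takes the WebGPU path. A browser-only sketch of restoring the branch without the removed WebGPUChecker, using the standard navigator.gpu feature test; whether that detection is sufficient here is an assumption:

// Sketch: feature-test WebGPU directly instead of the removed webgpuChecker.
// navigator.gpu is the standard WebGPU entry point; it is undefined in Node.js
// and in browsers without WebGPU.
const isWebGPUSupported: boolean =
  typeof navigator !== 'undefined' && 'gpu' in navigator;

if (isWebGPUSupported) {
  // ...try the direct tokenizer+model approach first
} else {
  // ...use the pipeline approach for simplicity
}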

src/GenerationModule.ts

Lines changed: 5 additions & 11 deletions

@@ -57,13 +57,10 @@ export class GenerationModule extends BaseModule {
 
     const { models = [], lazyLoad = false } = options;
 
-    // Check hardware capabilities
-    const capabilities = await this.webgpuChecker.check();
-
     this.progressTracker.update({
       status: 'init',
       type: 'generation_module',
-      message: `Hardware check: WebGPU ${capabilities.isWebGPUSupported ? 'available' : 'not available'}`
+      // message: `Hardware check: WebGPU ${capabilities.isWebGPUSupported ? 'available' : 'not available'}` FIX ME:
     });
 
     // Load first model if specified and not using lazy loading
@@ -122,14 +119,14 @@ export class GenerationModule extends BaseModule {
 
     try {
       // Check hardware capabilities
-      const capabilities = await this.webgpuChecker.check();
+      // const capabilities = await this.webgpuChecker.check();
 
       // Get optimal config (or use user-provided quantization)
-      const config = await this.getOptimalDeviceConfig();
+      // const config = await this.getOptimalDeviceConfig();
       const modelConfig: Record<string, any> = {
         // Only specify device and dtype if we have definitive information
-        ...(config.device ? { device: config.device } : {}),
-        ...(config.dtype || quantization ? { dtype: quantization || config.dtype } : {})
+        // ...(config.device ? { device: config.device } : {}),
+        // ...(config.dtype || quantization ? { dtype: quantization || config.dtype } : {}) FIX ME:
       };
 
       // Initialize file progress tracker for this model load
@@ -326,9 +323,6 @@ export class GenerationModule extends BaseModule {
       this.activeModel = null;
     }
 
-    // Try to trigger garbage collection
-    this.triggerGC();
-
    this.progressTracker.update({
      status: 'offloaded',
      type: 'model',
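
With both spreads commented out, modelConfig is now always empty, so a caller-supplied quantization is silently ignored. A minimal sketch of keeping the user override alive while the FIX ME stands; the buildModelConfig helper is illustrative, not part of the commit:

import { DType } from './Models';

// Hypothetical helper: honor a user-supplied quantization even while
// getOptimalDeviceConfig() is disabled; device is intentionally left unset.
function buildModelConfig(quantization?: DType): Record<string, any> {
  return {
    ...(quantization ? { dtype: quantization } : {}),
  };
}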

src/Models.ts

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+import { DeviceType } from "./Runtime";
+
+// Precision data types distributed as a union
+export type _FullPrecisionDType = 'fp32';
+export type _HalfPrecisionDType = 'fp16';
+export type _8BitPrecisionDType = 'q8' | 'int8' | 'uint8';
+export type _4BitPrecisionDType = 'q4' | 'bnb4' | 'q4f16';
+
+// Combine into a full union
+export type DType = _FullPrecisionDType | _HalfPrecisionDType | _8BitPrecisionDType | _4BitPrecisionDType;
+
+/**
+ * Device configuration
+ */
+export interface DeviceConfig {
+  device?: DeviceType;
+  dtype?: DType;
+}
+
+/**
+ * Task type for the models
+ */
+type TaskType =
+  'text-generation'
+  | 'embedding-generation'
+  | 'text-to-speech'
+  | 'speech-to-text'
+
+/**
+ * Interface for the model
+ */
+export interface Model {
+  repository: string,
+  quantization: DType,
+  taskType: TaskType,
+  [key: string]: string,
+}
+
+/**
+ * Unions of supported model identifiers
+ */
+type LanguageModelsType =
+  'onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX'
+  | 'onnx-community/Llama-3.2-1B-Instruct-q4f16'
+  | 'HuggingFaceTB/SmolLM2-1.7B-Instruct'
+
+type VisionLanguageModelsType =
+  'HuggingFaceTB/SmolVLM-256M-Instruct'
+
+type TextEmbeddingModelsType =
+  'nomic-ai/nomic-embed-text-v1.5'
+
+type MultiModalEmbeddingModelsType =
+  'jinaai/jina-clip-v1'
+
+type TextToSpeechModelsType =
+  'onnx-community/Kokoro-82M-v1.0-ONNX'
+
+type SpeechToTextModelsType =
+  'onnx-community/moonshine-base-ONNX'
+  | 'onnx-community/whisper-tiny.en'
+
+/**
+ * Models natively supported by TinyLM
+ */
+type SupportedModelsType = LanguageModelsType | VisionLanguageModelsType | TextEmbeddingModelsType | MultiModalEmbeddingModelsType | TextToSpeechModelsType | SpeechToTextModelsType
+
+type ModelCatalogType = Record<SupportedModelsType, Model>
+
+/**
+ * Model Catalog - Maps all supported models to their configurations
+ */
+export const modelCatalog: ModelCatalogType = {
+  // Language Models
+  "onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX": {
+    repository: "onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX",
+    quantization: "fp16",
+    taskType: "text-generation",
+  },
+  "onnx-community/Llama-3.2-1B-Instruct-q4f16": {
+    repository: "onnx-community/Llama-3.2-1B-Instruct-q4f16",
+    quantization: "q4f16",
+    taskType: "text-generation",
+  },
+  "HuggingFaceTB/SmolLM2-1.7B-Instruct": {
+    repository: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+    quantization: "q8",
+    taskType: "text-generation",
+  },
+
+  // Vision Language Models
+  "HuggingFaceTB/SmolVLM-256M-Instruct": {
+    repository: "HuggingFaceTB/SmolVLM-256M-Instruct",
+    quantization: "fp16",
+    taskType: "text-generation",
+  },
+
+  // Text Embedding Models
+  "nomic-ai/nomic-embed-text-v1.5": {
+    repository: "nomic-ai/nomic-embed-text-v1.5",
+    quantization: "fp32",
+    taskType: "embedding-generation",
+  },
+
+  // Multi-Modal Embedding Models
+  "jinaai/jina-clip-v1": {
+    repository: "jinaai/jina-clip-v1",
+    quantization: "fp16",
+    taskType: "embedding-generation",
+  },
+
+  // Text-to-Speech Models
+  "onnx-community/Kokoro-82M-v1.0-ONNX": {
+    repository: "onnx-community/Kokoro-82M-v1.0-ONNX",
+    quantization: "fp32",
+    taskType: "text-to-speech",
+  },
+
+  // Speech-to-Text Models
+  "onnx-community/moonshine-base-ONNX": {
+    repository: "onnx-community/moonshine-base-ONNX",
+    quantization: "fp16",
+    taskType: "speech-to-text",
+  },
+  "onnx-community/whisper-tiny.en": {
+    repository: "onnx-community/whisper-tiny.en",
+    quantization: "q8",
+    taskType: "speech-to-text",
+  }
+}
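
Because modelCatalog is keyed by the SupportedModelsType union, lookups on known model IDs are checked at compile time. A short usage sketch (the modelsForTask helper is illustrative, not part of the commit):

import { modelCatalog, Model } from './Models';

// Typed lookup: a typo in the model ID is a compile-time error.
const kokoro: Model = modelCatalog['onnx-community/Kokoro-82M-v1.0-ONNX'];
console.log(kokoro.taskType);     // "text-to-speech"
console.log(kokoro.quantization); // "fp32"

// Illustrative helper: list every catalog entry for a given task.
function modelsForTask(task: Model['taskType']): Model[] {
  return Object.values(modelCatalog).filter((m) => m.taskType === task);
}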
