Skip to content

Commit 82545f5

Browse files
committed
feat: decouple model lifecycle management
1 parent b7cd9d4 commit 82545f5

File tree

7 files changed

+520
-1019
lines changed

7 files changed

+520
-1019
lines changed

src/AudioModule.ts

Lines changed: 31 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
import { BaseModule } from './BaseModule';
22
import { TTSEngine, AudioChunk } from './TTSEngine';
3-
import { SpeechCreateOptions, SpeechResult, SpeechStreamResult } from './types';
3+
import { ModelManager } from './ModelManager';
4+
import {
5+
SpeechCreateOptions,
6+
SpeechResult,
7+
SpeechStreamResult,
8+
ModelType
9+
} from './types';
410
// import { splitTextIntoSentences, ensureSafeTokenLength } from './TextSplitter';
511

612
/**
713
* Audio module for TinyLM
814
*/
915
export class AudioModule extends BaseModule {
10-
private ttsEngine: TTSEngine;
16+
private modelManager: ModelManager;
1117
private activeModel: string | null = null;
12-
private modelRegistry: Map<string, boolean> = new Map();
13-
private modelIsLoading: boolean = false;
1418

1519
/**
1620
* Create a new audio module
17-
* @param {any} tinyLM - Parent TinyLM instance
21+
* @param {Object} options - Module options
22+
* @param {ModelManager} options.modelManager - Model manager instance
1823
*/
19-
constructor(tinyLM: any) {
20-
super(tinyLM);
21-
this.ttsEngine = new TTSEngine();
24+
constructor(options: { modelManager: ModelManager }) {
25+
super(options);
26+
this.modelManager = options.modelManager;
2227
}
2328

2429
/**
@@ -67,92 +72,18 @@ export class AudioModule extends BaseModule {
6772
throw new Error('Model identifier is required');
6873
}
6974

70-
// Return if already loading
71-
if (this.modelIsLoading) {
72-
throw new Error('Another model is currently loading');
73-
}
74-
75-
// Set as active and return if already loaded
76-
if (this.modelRegistry.get(model) === true) {
77-
this.activeModel = model;
78-
79-
this.progressTracker.update({
80-
status: 'ready',
81-
type: 'tts_model',
82-
progress: 1,
83-
percentComplete: 100,
84-
message: `Model ${model} is already loaded`
85-
});
86-
87-
return true;
88-
}
89-
90-
// Set loading state
91-
this.modelIsLoading = true;
92-
this.activeModel = model;
93-
9475
try {
95-
// Check hardware capabilities
96-
const capabilities = await this.webgpuChecker.check();
97-
98-
// Get optimal config
99-
const config = await this.getOptimalDeviceConfig();
100-
101-
// For TTS models, always use CPU in Node.js environment or proper fallback
102-
const isNode = this.isNodeEnvironment();
103-
104-
// Determine device explicitly
105-
const device = isNode ? "cpu" :
106-
!capabilities.isWebGPUSupported ? "wasm" : "webgpu";
107-
108-
// Initial progress message
109-
this.progressTracker.update({
110-
status: 'loading',
111-
type: 'tts_model',
112-
progress: 0,
113-
percentComplete: 0,
114-
message: `Loading TTS model ${model} (device="${device}", dtype="${"fp32"}")`
76+
const success = await this.modelManager.loadAudioModel({
77+
model,
78+
type: ModelType.Audio
11579
});
116-
117-
// Load the model directly by name
118-
await this.ttsEngine.loadModel(model, {
119-
onProgress: (progress: any) => {
120-
this.progressTracker.update({
121-
status: 'loading',
122-
type: 'tts_model',
123-
message: `Loading TTS model: ${model}`,
124-
progress: progress.progress,
125-
percentComplete: progress.progress ? Math.round(progress.progress * 100) : undefined,
126-
...progress
127-
});
128-
},
129-
device: device,
130-
dtype: "fp32", // TEMPFIX: config.dtype
131-
});
132-
133-
// Register the model as loaded
134-
this.modelRegistry.set(model, true);
135-
136-
this.progressTracker.update({
137-
status: 'ready',
138-
type: 'tts_model',
139-
progress: 1,
140-
percentComplete: 100,
141-
message: `TTS model ${model} loaded successfully`
142-
});
143-
144-
return true;
80+
if (success) {
81+
this.activeModel = model;
82+
}
83+
return success;
14584
} catch (error) {
14685
const errorMessage = error instanceof Error ? error.message : String(error);
147-
this.progressTracker.update({
148-
status: 'error',
149-
type: 'tts_model',
150-
message: `Error loading TTS model ${model}: ${errorMessage}`
151-
});
152-
15386
throw new Error(`Failed to load TTS model ${model}: ${errorMessage}`);
154-
} finally {
155-
this.modelIsLoading = false;
15687
}
15788
}
15889

@@ -168,7 +99,7 @@ export class AudioModule extends BaseModule {
16899
voice = 'af',
169100
response_format = 'mp3',
170101
speed = 1.0,
171-
stream = false // New streaming parameter
102+
stream = false
172103
} = options;
173104

174105
// Load model if specified and different from current
@@ -177,7 +108,7 @@ export class AudioModule extends BaseModule {
177108
}
178109

179110
// Check if a model is loaded
180-
if (!this.activeModel || !this.modelRegistry.get(this.activeModel)) {
111+
if (!this.activeModel || !this.modelManager.isAudioModelLoaded(this.activeModel)) {
181112
throw new Error('No TTS model loaded. Specify a model or call loadModel() first.');
182113
}
183114

@@ -189,9 +120,10 @@ export class AudioModule extends BaseModule {
189120

190121
try {
191122
const startTime = Date.now();
123+
const ttsEngine = this.modelManager.getTTSEngine();
192124

193125
// Generate speech with or without streaming
194-
const result = await this.ttsEngine.generateSpeech(input, {
126+
const result = await ttsEngine.generateSpeech(input, {
195127
voice,
196128
speed,
197129
stream
@@ -252,72 +184,22 @@ export class AudioModule extends BaseModule {
252184

253185
/**
254186
* Offload a TTS model from memory
255-
* @param {Object} options - Offload options
187+
* @param {string} model - Model identifier
256188
* @returns {Promise<boolean>} Success status
257189
*/
258-
async offloadModel(options: { model: string }): Promise<boolean> {
259-
const { model } = options;
260-
261-
if (!model) {
262-
throw new Error('Model identifier is required');
263-
}
264-
265-
if (!this.modelRegistry.has(model)) {
266-
return false;
190+
async offloadModel(model: string): Promise<boolean> {
191+
const success = await this.modelManager.offloadAudioModel(model);
192+
if (success && this.activeModel === model) {
193+
this.activeModel = null;
267194
}
268-
269-
this.progressTracker.update({
270-
status: 'offloading',
271-
type: 'tts_model',
272-
message: `Offloading TTS model ${model}`
273-
});
274-
275-
try {
276-
// Remove from registry
277-
this.modelRegistry.delete(model);
278-
279-
// Clear current model if it's the active one
280-
if (this.activeModel === model) {
281-
this.activeModel = null;
282-
}
283-
284-
// Try to trigger garbage collection
285-
this.triggerGC();
286-
287-
this.progressTracker.update({
288-
status: 'offloaded',
289-
type: 'tts_model',
290-
message: `TTS model ${model} removed from memory`
291-
});
292-
293-
return true;
294-
} catch (error) {
295-
const errorMessage = error instanceof Error ? error.message : String(error);
296-
this.progressTracker.update({
297-
status: 'error',
298-
type: 'tts_model',
299-
message: `Error offloading TTS model ${model}: ${errorMessage}`
300-
});
301-
302-
return false;
303-
}
304-
}
305-
306-
/**
307-
* Get active model identifier
308-
* @returns {string|null} Active model identifier
309-
*/
310-
getActiveModel(): string | null {
311-
return this.activeModel;
195+
return success;
312196
}
313197

314198
/**
315199
* Get list of loaded TTS models
316200
* @returns {string[]} Array of model identifiers
317201
*/
318202
getLoadedModels(): string[] {
319-
return Array.from(this.modelRegistry.entries())
320-
.filter(([_, loaded]) => loaded)
321-
.map(([model, _]) => model);
203+
return this.modelManager.getLoadedModels(ModelType.Audio);
322204
}
323205
}

src/BaseModule.ts

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
import { WebGPUChecker } from './WebGPUChecker';
77
import { ProgressTracker } from './ProgressTracker';
8-
import { tryGarbageCollection, EnvironmentInfo } from './utils';
8+
import { tryGarbageCollection, detectEnvironment, EnvironmentInfo } from './utils';
99
import { DeviceConfig } from './types';
10+
import { ModelManager } from './ModelManager';
1011

1112
/**
1213
* Base interface for TinyLM modules
@@ -24,20 +25,19 @@ export interface TinyLMModule {
2425
* Base class for TinyLM modules with common functionality
2526
*/
2627
export abstract class BaseModule implements TinyLMModule {
27-
protected tinyLM: any;
2828
protected progressTracker: ProgressTracker;
2929
protected webgpuChecker: WebGPUChecker;
3030
protected environment: EnvironmentInfo;
3131

3232
/**
3333
* Create a new module
34-
* @param {any} tinyLM - Parent TinyLM instance
34+
* @param {Object} options - Module options
35+
* @param {ModelManager} options.modelManager - Model manager instance
3536
*/
36-
constructor(tinyLM: any) {
37-
this.tinyLM = tinyLM;
38-
this.progressTracker = tinyLM.getProgressTracker();
39-
this.webgpuChecker = tinyLM.getWebGPUChecker();
40-
this.environment = tinyLM.getEnvironment();
37+
constructor(options: { modelManager: ModelManager }) {
38+
this.progressTracker = options.modelManager['progressTracker'];
39+
this.webgpuChecker = options.modelManager['webgpuChecker'];
40+
this.environment = detectEnvironment();
4141
}
4242

4343
/**

0 commit comments

Comments
 (0)