const crypto = require('crypto'); const path = require('path'); let pipelinePromise = null; let _transformers = null; // LRU cache: sha1(text) -> Float32Array, capped at 256 const lru = new Map(); const LRU_MAX = 256; function _lruKey(text) { return crypto.createHash('sha1').update(text).digest('hex'); } function _lruGet(key) { const val = lru.get(key); if (val !== undefined) { // move to back (most recently used) lru.delete(key); lru.set(key, val); } return val; } function _lruSet(key, vec) { if (lru.has(key)) { lru.delete(key); } else if (lru.size >= LRU_MAX) { const firstKey = lru.keys().next().value; lru.delete(firstKey); } lru.set(key, vec); } async function _getPipeline() { if (pipelinePromise) return pipelinePromise; pipelinePromise = (async () => { try { const mod = await import('@xenova/transformers'); _transformers = mod; mod.env.cacheDir = path.join(__dirname, '..', '..', 'node_modules', '.cache', 'transformers'); // Try webgpu first (DirectML on Windows/AMD), fallback to wasm let pipe; try { pipe = await mod.pipeline('feature-extraction', 'Xenova/multilingual-e5-small', { device: 'webgpu', }); console.log('[embeddings] pipeline loaded with device=webgpu'); } catch (gpuErr) { console.warn('[embeddings] webgpu failed, falling back to wasm:', gpuErr.message); pipe = await mod.pipeline('feature-extraction', 'Xenova/multilingual-e5-small', { device: 'wasm', }); console.log('[embeddings] pipeline loaded with device=wasm'); } return pipe; } catch (err) { console.error('[embeddings] failed to load pipeline:', err.message); throw err; } })(); return pipelinePromise; } async function warmup() { try { await _getPipeline(); } catch (err) { console.warn('[embeddings] warmup failed (model will retry on first use):', err.message); } } async function embed(text) { if (!text || typeof text !== 'string') { throw new Error('embed() requires a non-empty string'); } const key = _lruKey(text); const cached = _lruGet(key); if (cached) return cached; const pipe = await _getPipeline(); const result = await pipe(text, { pooling: 'mean', normalize: true }); const vec = result.data instanceof Float32Array ? result.data : new Float32Array(result.data); _lruSet(key, vec); return vec; } async function embedBatch(texts) { if (!Array.isArray(texts) || texts.length === 0) { return []; } // Check cache first const uncached = []; const indices = []; const results = new Array(texts.length); for (let i = 0; i < texts.length; i++) { const key = _lruKey(texts[i]); const cached = _lruGet(key); if (cached) { results[i] = cached; } else { uncached.push(texts[i]); indices.push(i); } } if (uncached.length === 0) { return results; } const pipe = await _getPipeline(); const BATCH_SIZE = 32; for (let start = 0; start < uncached.length; start += BATCH_SIZE) { const batch = uncached.slice(start, start + BATCH_SIZE); const batchResult = await pipe(batch, { pooling: 'mean', normalize: true }); // batchResult.data is a flat array for all batches; shape depends on library version // For Transformers.js v2, when batching, result.data is flat and we need to slice const dim = batch.length > 0 ? Math.floor(batchResult.data.length / batch.length) : 384; for (let b = 0; b < batch.length; b++) { const offset = b * dim; const vec = new Float32Array(batchResult.data.slice(offset, offset + dim)); const originalIdx = indices[start + b]; results[originalIdx] = vec; _lruSet(_lruKey(batch[b]), vec); } } return results; } module.exports = { warmup, embed, embedBatch, };