@@ -90,9 +90,11 @@ public function add(
9090 }
9191 }
9292
93+ $ preparedEmbeddings = $ this ->prepareEmbeddings ($ embeddings , $ documents );
94+
9395 $ validated = $ this ->validate (
9496 ids: $ ids ,
95- embeddings: $ embeddings ,
97+ embeddings: $ preparedEmbeddings ,
9698 metadatas: $ metadatas ,
9799 documents: $ documents ,
98100 requireEmbeddingsOrDocuments: true ,
@@ -138,9 +140,11 @@ public function update(
138140 }
139141 }
140142
143+ $ preparedEmbeddings = $ this ->prepareEmbeddings ($ embeddings , $ documents );
144+
141145 $ validated = $ this ->validate (
142146 ids: $ ids ,
143- embeddings: $ embeddings ,
147+ embeddings: $ preparedEmbeddings ,
144148 metadatas: $ metadatas ,
145149 documents: $ documents ,
146150 requireEmbeddingsOrDocuments: false ,
@@ -186,9 +190,11 @@ public function upsert(
186190 }
187191 }
188192
193+ $ preparedEmbeddings = $ this ->prepareEmbeddings ($ embeddings , $ documents );
194+
189195 $ validated = $ this ->validate (
190196 ids: $ ids ,
191- embeddings: $ embeddings ,
197+ embeddings: $ preparedEmbeddings ,
192198 metadatas: $ metadatas ,
193199 documents: $ documents ,
194200 requireEmbeddingsOrDocuments: true ,
@@ -318,22 +324,10 @@ public function query(
318324 );
319325 }
320326
321- $ finalEmbeddings = [] ;
327+ $ finalEmbeddings = $ this -> prepareEmbeddings ( $ queryEmbeddings , $ queryTexts ) ;
322328
323- if ($ queryEmbeddings == null ) {
324- if ($ this ->embeddingFunction == null ) {
325- throw new InvalidArgumentException (
326- 'You must provide an embedding function if you did not provide embeddings '
327- );
328- } elseif ($ queryTexts != null ) {
329- $ finalEmbeddings = $ this ->embeddingFunction ->generate ($ queryTexts );
330- } else {
331- throw new InvalidArgumentException (
332- 'If you did not provide queryEmbeddings, you must provide queryTexts '
333- );
334- }
335- } else {
336- foreach ($ queryEmbeddings as $ i => $ embedding ) {
329+ if ($ finalEmbeddings !== null ) {
330+ foreach ($ finalEmbeddings as $ i => $ embedding ) {
337331 if (!is_array ($ embedding )) {
338332 throw new InvalidArgumentException (sprintf (
339333 "Expected query embedding at index %d to be an array, got %s " ,
@@ -343,7 +337,7 @@ public function query(
343337 }
344338
345339 foreach ($ embedding as $ j => $ value ) {
346- if (!is_float ($ value )) {
340+ if (!is_float ($ value ) && ! is_int ( $ value ) ) {
347341 throw new InvalidArgumentException (sprintf (
348342 "Expected query embedding value at index %d.%d to be a float, got %s " ,
349343 $ i ,
@@ -353,10 +347,8 @@ public function query(
353347 }
354348 }
355349 }
356- $ finalEmbeddings = $ queryEmbeddings ;
357350 }
358351
359-
360352 $ request = new QueryItemsRequest (
361353 where: $ where ,
362354 whereDocument: $ whereDocument ,
@@ -383,10 +375,69 @@ public function setEmbeddingFunction(EmbeddingFunction $embeddingFunction): void
383375 $ this ->embeddingFunction = $ embeddingFunction ;
384376 }
385377
378+ /**
379+ * Prepares embeddings by generating missing ones in batch.
380+ *
381+ * @param array|null $embeddings Existing embeddings (may contain nulls for missing ones)
382+ * @param array|null $texts Texts to generate embeddings from (documents or queryTexts)
383+ * @return array|null Prepared embeddings array with all nulls filled in, or null if texts is null
384+ */
385+ protected function prepareEmbeddings (?array $ embeddings , ?array $ texts ): ?array
386+ {
387+ if ($ texts === null ) {
388+ return $ embeddings ;
389+ }
390+
391+ if (empty ($ texts )) {
392+ return $ embeddings ;
393+ }
394+
395+ if ($ embeddings === null || empty ($ embeddings )) {
396+ return $ this ->embeddingFunction ->generate ($ texts );
397+ }
398+
399+ $ missingIndices = [];
400+ $ textsToEmbed = [];
401+
402+ foreach ($ embeddings as $ i => $ embedding ) {
403+ if ($ embedding === null ) {
404+ if (!isset ($ texts [$ i ]) || $ texts [$ i ] === null ) {
405+ throw new InvalidArgumentException (sprintf ('Cannot generate embedding at index %d: no text provided ' , $ i ));
406+ }
407+ $ missingIndices [] = $ i ;
408+ $ textsToEmbed [] = $ texts [$ i ];
409+ }
410+ }
411+
412+ if (empty ($ missingIndices )) {
413+ return $ embeddings ;
414+ }
415+
416+ $ generatedEmbeddings = $ this ->embeddingFunction ->generate ($ textsToEmbed );
417+
418+ $ finalEmbeddings = [];
419+ $ generatedIndex = 0 ;
420+
421+ foreach ($ embeddings as $ i => $ embedding ) {
422+ if ($ embedding === null ) {
423+ $ finalEmbeddings [] = $ generatedEmbeddings [$ generatedIndex ++];
424+ } else {
425+ $ finalEmbeddings [] = $ embedding ;
426+ }
427+ }
428+
429+ return $ finalEmbeddings ;
430+ }
431+
386432 /**
387433 * Validates the inputs to the add, upsert, and update methods.
388434 *
389- * @return array{ids: string[], embeddings: int[][], metadatas: array[], documents: string[]}
435+ * @return array{
436+ * ids: string[],
437+ * embeddings: int[][],
438+ * metadatas: array[],
439+ * documents: string[]
440+ * }
390441 */
391442 protected function validate (
392443 array $ ids ,
@@ -428,19 +479,7 @@ protected function validate(
428479 }
429480
430481 // Validate embeddings
431- if ($ embeddings == null ) {
432- if ($ this ->embeddingFunction == null ) {
433- throw new InvalidArgumentException (
434- 'You must provide an embedding function if you did not provide embeddings '
435- );
436- } elseif ($ documents != null ) {
437- $ finalEmbeddings = $ this ->embeddingFunction ->generate ($ documents );
438- } else {
439- throw new InvalidArgumentException (
440- 'If you did not provide embeddings, you must provide documents '
441- );
442- }
443- } else {
482+ if ($ embeddings !== null ) {
444483 foreach ($ embeddings as $ i => $ embedding ) {
445484 if (!is_array ($ embedding )) {
446485 throw new InvalidArgumentException (sprintf (
@@ -451,20 +490,19 @@ protected function validate(
451490 }
452491
453492 foreach ($ embedding as $ j => $ value ) {
454- if (!is_float ($ value )) {
493+ if (!is_float ($ value ) && ! is_int ( $ value ) ) {
455494 throw new InvalidArgumentException (sprintf (
456- "Expected embedding value at index %d.%d to be a float , got %s " ,
495+ "Expected embedding value at index %d.%d to be a number , got %s " ,
457496 $ i ,
458497 $ j ,
459498 gettype ($ value )
460499 ));
461500 }
462501 }
463502 }
464-
465- $ finalEmbeddings = $ embeddings ;
466503 }
467504
505+ // Validate ids
468506 $ ids = array_map (function ($ id ) {
469507 if (is_object ($ id ) && method_exists ($ id , '__toString ' )) {
470508 $ id = (string ) $ id ;
@@ -478,6 +516,7 @@ protected function validate(
478516 return $ id ;
479517 }, $ ids );
480518
519+ // Validate unique ids
481520 $ uniqueIds = array_unique ($ ids );
482521 if (count ($ uniqueIds ) !== count ($ ids )) {
483522 $ duplicateIds = array_filter ($ ids , function ($ id ) use ($ ids ) {
@@ -488,7 +527,7 @@ protected function validate(
488527
489528 return [
490529 'ids ' => $ ids ,
491- 'embeddings ' => $ finalEmbeddings ,
530+ 'embeddings ' => $ embeddings ,
492531 'metadatas ' => $ metadatas ,
493532 'documents ' => $ documents ,
494533 ];
0 commit comments