Skip to content

Commit 159e1f6

Browse files
refactor: extract embedding preparation and support mixed embeddings
- Extract embedding generation into a dedicated prepareEmbeddings method.
- Support mixed embeddings arrays in which some items have embeddings and others are null; the missing ones are generated in a single batch while the original order is maintained.
- Separate concerns: prepareEmbeddings handles generation only, while validate handles all validation logic.
1 parent fa3101d commit 159e1f6

1 file changed

Lines changed: 79 additions & 40 deletions

File tree

src/Models/Collection.php

Lines changed: 79 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,11 @@ public function add(
9090
}
9191
}
9292

93+
$preparedEmbeddings = $this->prepareEmbeddings($embeddings, $documents);
94+
9395
$validated = $this->validate(
9496
ids: $ids,
95-
embeddings: $embeddings,
97+
embeddings: $preparedEmbeddings,
9698
metadatas: $metadatas,
9799
documents: $documents,
98100
requireEmbeddingsOrDocuments: true,
@@ -138,9 +140,11 @@ public function update(
138140
}
139141
}
140142

143+
$preparedEmbeddings = $this->prepareEmbeddings($embeddings, $documents);
144+
141145
$validated = $this->validate(
142146
ids: $ids,
143-
embeddings: $embeddings,
147+
embeddings: $preparedEmbeddings,
144148
metadatas: $metadatas,
145149
documents: $documents,
146150
requireEmbeddingsOrDocuments: false,
@@ -186,9 +190,11 @@ public function upsert(
186190
}
187191
}
188192

193+
$preparedEmbeddings = $this->prepareEmbeddings($embeddings, $documents);
194+
189195
$validated = $this->validate(
190196
ids: $ids,
191-
embeddings: $embeddings,
197+
embeddings: $preparedEmbeddings,
192198
metadatas: $metadatas,
193199
documents: $documents,
194200
requireEmbeddingsOrDocuments: true,
@@ -318,22 +324,10 @@ public function query(
318324
);
319325
}
320326

321-
$finalEmbeddings = [];
327+
$finalEmbeddings = $this->prepareEmbeddings($queryEmbeddings, $queryTexts);
322328

323-
if ($queryEmbeddings == null) {
324-
if ($this->embeddingFunction == null) {
325-
throw new InvalidArgumentException(
326-
'You must provide an embedding function if you did not provide embeddings'
327-
);
328-
} elseif ($queryTexts != null) {
329-
$finalEmbeddings = $this->embeddingFunction->generate($queryTexts);
330-
} else {
331-
throw new InvalidArgumentException(
332-
'If you did not provide queryEmbeddings, you must provide queryTexts'
333-
);
334-
}
335-
} else {
336-
foreach ($queryEmbeddings as $i => $embedding) {
329+
if ($finalEmbeddings !== null) {
330+
foreach ($finalEmbeddings as $i => $embedding) {
337331
if (!is_array($embedding)) {
338332
throw new InvalidArgumentException(sprintf(
339333
"Expected query embedding at index %d to be an array, got %s",
@@ -343,7 +337,7 @@ public function query(
343337
}
344338

345339
foreach ($embedding as $j => $value) {
346-
if (!is_float($value)) {
340+
if (!is_float($value) && !is_int($value)) {
347341
throw new InvalidArgumentException(sprintf(
348342
"Expected query embedding value at index %d.%d to be a float, got %s",
349343
$i,
@@ -353,10 +347,8 @@ public function query(
353347
}
354348
}
355349
}
356-
$finalEmbeddings = $queryEmbeddings;
357350
}
358351

359-
360352
$request = new QueryItemsRequest(
361353
where: $where,
362354
whereDocument: $whereDocument,
@@ -383,10 +375,69 @@ public function setEmbeddingFunction(EmbeddingFunction $embeddingFunction): void
383375
$this->embeddingFunction = $embeddingFunction;
384376
}
385377

378+
/**
 * Fills in missing embeddings by generating them from the matching texts in one batch.
 *
 * Entries of $embeddings that are null are treated as "missing"; the text at the same
 * key in $texts is embedded in their place. No validation of the embedding values
 * themselves is performed here.
 *
 * @param array|null $embeddings Existing embeddings; individual entries may be null.
 * @param array|null $texts      Texts to generate embeddings from (documents or queryTexts).
 *
 * @return array|null $embeddings unchanged when $texts is null/empty or when no entry is
 *                    missing; the embedding function's output for all of $texts when
 *                    $embeddings is null/empty; otherwise a sequentially re-indexed copy of
 *                    $embeddings with every null entry replaced by a generated embedding.
 *
 * @throws InvalidArgumentException When an embedding is missing and there is no text at the
 *                                  same key, or when embeddings must be generated but no
 *                                  embedding function is set.
 */
protected function prepareEmbeddings(?array $embeddings, ?array $texts): ?array
{
    // Without texts there is nothing to generate from; hand the input back untouched.
    if ($texts === null || $texts === []) {
        return $embeddings;
    }

    // No embeddings supplied at all: every text needs one.
    $generateAll = $embeddings === null || $embeddings === [];
    $textsToEmbed = $generateAll ? $texts : [];

    if (!$generateAll) {
        foreach ($embeddings as $i => $embedding) {
            if ($embedding !== null) {
                continue;
            }

            // isset() is already false for a null value, so this covers both
            // "no such key" and "text is null".
            if (!isset($texts[$i])) {
                throw new InvalidArgumentException(
                    sprintf('Cannot generate embedding at index %d: no text provided', $i)
                );
            }

            $textsToEmbed[] = $texts[$i];
        }

        if ($textsToEmbed === []) {
            return $embeddings;
        }
    }

    // Only required once we know generation is actually needed. Without this guard a
    // collection that has no embedding function fails with a fatal
    // "Call to a member function generate() on null" instead of a catchable exception.
    if ($this->embeddingFunction === null) {
        throw new InvalidArgumentException(
            'You must provide an embedding function if you did not provide embeddings'
        );
    }

    $generatedEmbeddings = $this->embeddingFunction->generate($textsToEmbed);

    if ($generateAll) {
        return $generatedEmbeddings;
    }

    // Merge in the original order.
    // NOTE(review): assumes generate() returns one embedding per input text, in input order.
    $finalEmbeddings = [];
    $generatedIndex = 0;

    foreach ($embeddings as $embedding) {
        // The right-hand side (and its increment) is only evaluated for missing entries.
        $finalEmbeddings[] = $embedding ?? $generatedEmbeddings[$generatedIndex++];
    }

    return $finalEmbeddings;
}
431+
386432
/**
387433
* Validates the inputs to the add, upsert, and update methods.
388434
*
389-
* @return array{ids: string[], embeddings: int[][], metadatas: array[], documents: string[]}
435+
* @return array{
436+
* ids: string[],
437+
* embeddings: array<array<float|int>>|null,
438+
* metadatas: array[],
439+
* documents: string[]
440+
* }
390441
*/
391442
protected function validate(
392443
array $ids,
@@ -428,19 +479,7 @@ protected function validate(
428479
}
429480

430481
// Validate embeddings
431-
if ($embeddings == null) {
432-
if ($this->embeddingFunction == null) {
433-
throw new InvalidArgumentException(
434-
'You must provide an embedding function if you did not provide embeddings'
435-
);
436-
} elseif ($documents != null) {
437-
$finalEmbeddings = $this->embeddingFunction->generate($documents);
438-
} else {
439-
throw new InvalidArgumentException(
440-
'If you did not provide embeddings, you must provide documents'
441-
);
442-
}
443-
} else {
482+
if ($embeddings !== null) {
444483
foreach ($embeddings as $i => $embedding) {
445484
if (!is_array($embedding)) {
446485
throw new InvalidArgumentException(sprintf(
@@ -451,20 +490,19 @@ protected function validate(
451490
}
452491

453492
foreach ($embedding as $j => $value) {
454-
if (!is_float($value)) {
493+
if (!is_float($value) && !is_int($value)) {
455494
throw new InvalidArgumentException(sprintf(
456-
"Expected embedding value at index %d.%d to be a float, got %s",
495+
"Expected embedding value at index %d.%d to be a number, got %s",
457496
$i,
458497
$j,
459498
gettype($value)
460499
));
461500
}
462501
}
463502
}
464-
465-
$finalEmbeddings = $embeddings;
466503
}
467504

505+
// Validate ids
468506
$ids = array_map(function ($id) {
469507
if (is_object($id) && method_exists($id, '__toString')) {
470508
$id = (string) $id;
@@ -478,6 +516,7 @@ protected function validate(
478516
return $id;
479517
}, $ids);
480518

519+
// Validate unique ids
481520
$uniqueIds = array_unique($ids);
482521
if (count($uniqueIds) !== count($ids)) {
483522
$duplicateIds = array_filter($ids, function ($id) use ($ids) {
@@ -488,7 +527,7 @@ protected function validate(
488527

489528
return [
490529
'ids' => $ids,
491-
'embeddings' => $finalEmbeddings,
530+
'embeddings' => $embeddings,
492531
'metadatas' => $metadatas,
493532
'documents' => $documents,
494533
];

0 commit comments

Comments
 (0)