Skip to content

Commit ed79a16

Browse files
Add proper types for InferenceSession
1 parent 587b8c2 commit ed79a16

1 file changed

Lines changed: 37 additions & 38 deletions

File tree

src/OnnxRuntime/InferenceSession.php

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,6 @@ class InferenceSession
1515
private array $inputs;
1616
private array $outputs;
1717

18-
private const ONNX_TENSOR_TYPE_TO_PHP_TENSOR_MAP = [
19-
1 => Tensor::float32,
20-
2 => Tensor::uint8,
21-
3 => Tensor::int8,
22-
4 => Tensor::uint16,
23-
5 => Tensor::int16,
24-
6 => Tensor::int32,
25-
7 => Tensor::int64,
26-
8 => 19, // string, but it has no mapping in tensor
27-
9 => Tensor::bool,
28-
10 => Tensor::float16,
29-
11 => Tensor::float64,
30-
12 => Tensor::uint32,
31-
13 => Tensor::uint64,
32-
14 => Tensor::complex64,
33-
15 => Tensor::complex128,
34-
16 => Tensor::float16 // bfloat16 though
35-
];
36-
3718
public function __construct(
3819
$path,
3920
$enableCpuMemArena = true,
@@ -201,17 +182,17 @@ public function run($outputNames, $inputFeed, $logSeverityLevel = null, $logVerb
201182
return $output;
202183
}
203184

204-
public function inputs()
185+
public function inputs(): array
205186
{
206187
return $this->inputs;
207188
}
208189

209-
public function outputs()
190+
public function outputs(): array
210191
{
211192
return $this->outputs;
212193
}
213194

214-
public function modelmeta()
195+
public function modelmeta(): array
215196
{
216197
$keys = $this->ffi->new('char**');
217198
$numKeys = $this->ffi->new('int64_t');
@@ -268,7 +249,7 @@ public function modelmeta()
268249
}
269250

270251
// return value has double underscore like Python
271-
public function endProfiling()
252+
public function endProfiling(): string
272253
{
273254
$out = $this->ffi->new('char*');
274255
$this->checkStatus(($this->api->SessionEndProfiling)($this->session, $this->allocator, \FFI::addr($out)));
@@ -277,7 +258,7 @@ public function endProfiling()
277258

278259
// no way to set providers with C API yet
279260
// so we can return all available providers
280-
public function providers()
261+
public function providers(): array
281262
{
282263
$outPtr = $this->ffi->new('char**');
283264
$lengthPtr = $this->ffi->new('int');
@@ -291,7 +272,7 @@ public function providers()
291272
return $providers;
292273
}
293274

294-
private function loadSession($path, $sessionOptions)
275+
private function loadSession($path, $sessionOptions): ?CData
295276
{
296277
$session = $this->ffi->new('OrtSession*');
297278
if (is_resource($path) && get_resource_type($path) == 'stream') {
@@ -303,14 +284,14 @@ private function loadSession($path, $sessionOptions)
303284
return $session;
304285
}
305286

306-
private function loadAllocator()
287+
private function loadAllocator(): ?CData
307288
{
308289
$allocator = $this->ffi->new('OrtAllocator*');
309290
$this->checkStatus(($this->api->GetAllocatorWithDefaultOptions)(\FFI::addr($allocator)));
310291
return $allocator;
311292
}
312293

313-
private function loadInputs()
294+
private function loadInputs(): array
314295
{
315296
$inputs = [];
316297
$numInputNodes = $this->ffi->new('size_t');
@@ -327,7 +308,7 @@ private function loadInputs()
327308
return $inputs;
328309
}
329310

330-
private function loadOutputs()
311+
private function loadOutputs(): array
331312
{
332313
$outputs = [];
333314
$numOutputNodes = $this->ffi->new('size_t');
@@ -346,7 +327,7 @@ private function loadOutputs()
346327
return $outputs;
347328
}
348329

349-
private function convertInputTensorToOnnxTensor($inputFeed, &$refs)
330+
private function convertInputTensorToOnnxTensor($inputFeed, &$refs): ?CData
350331
{
351332
$allocatorInfo = $this->ffi->new('OrtMemoryInfo*');
352333
$this->checkStatus(($this->api->CreateCpuMemoryInfo)(1, 0, \FFI::addr($allocatorInfo)));
@@ -396,7 +377,7 @@ private function convertInputTensorToOnnxTensor($inputFeed, &$refs)
396377
if (isset($inputTypes[$inp['type']])) {
397378
$typeEnum = $inputTypes[$inp['type']];
398379
$castType = $this->castTypes()[$typeEnum];
399-
$phpTensorType = self::ONNX_TENSOR_TYPE_TO_PHP_TENSOR_MAP[$typeEnum];
380+
$phpTensorType = $this->phpTensorTypes()[$typeEnum];
400381
$input = $input->to($phpTensorType);
401382
} else {
402383
$this->unsupportedType('input', $inp['type']);
@@ -435,7 +416,7 @@ private function fillStringTensorValues(Tensor $input, $ptr, &$refs): void
435416
}
436417
}
437418

438-
private function createNodeNames($names, &$refs)
419+
private function createNodeNames($names, &$refs): CData
439420
{
440421
$namesSize = count($names);
441422
$ptr = $this->ffi->new("char*[$namesSize]");
@@ -447,7 +428,7 @@ private function createNodeNames($names, &$refs)
447428
return $ptr;
448429
}
449430

450-
private function cstring($str)
431+
private function cstring($str): CData
451432
{
452433
$bytes = strlen($str) + 1;
453434
// TODO fix?
@@ -488,7 +469,7 @@ private function createFromOnnxValue($outPtr)
488469
$this->unsupportedType('element', $type);
489470
}
490471

491-
$phpTensorType = self::ONNX_TENSOR_TYPE_TO_PHP_TENSOR_MAP[$type];
472+
$phpTensorType = $this->phpTensorTypes()[$type];
492473

493474
$buffer = Tensor::newBuffer($outputTensorSize, $phpTensorType);
494475

@@ -540,7 +521,7 @@ private function createFromOnnxValue($outPtr)
540521
}
541522
}
542523

543-
private function createStringsFromOnnxValue($outPtr, $outputTensorSize)
524+
private function createStringsFromOnnxValue($outPtr, $outputTensorSize): array
544525
{
545526
$len = $this->ffi->new('size_t');
546527
$this->checkStatus(($this->api->GetStringTensorDataLength)($outPtr, \FFI::addr($len)));
@@ -560,7 +541,7 @@ private function createStringsFromOnnxValue($outPtr, $outputTensorSize)
560541
return $result;
561542
}
562543

563-
private static function checkStatus($status)
544+
private static function checkStatus($status): void
564545
{
565546
if (!is_null($status)) {
566547
$message = (self::api()->GetErrorMessage)($status);
@@ -610,7 +591,7 @@ private function nodeInfo($typeinfo)
610591
}
611592
}
612593

613-
private function castTypes()
594+
private function castTypes(): array
614595
{
615596
return [
616597
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => 'float',
@@ -627,7 +608,7 @@ private function castTypes()
627608
];
628609
}
629610

630-
private function elementDataTypes()
611+
private function elementDataTypes(): array
631612
{
632613
return [
633614
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED => 'undefined',
@@ -650,7 +631,25 @@ private function elementDataTypes()
650631
];
651632
}
652633

653-
private function tensorTypeAndShape($tensorInfo)
634+
private function phpTensorTypes(): array
635+
{
636+
return [
637+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Tensor::float32,
638+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Tensor::uint8,
639+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Tensor::int8,
640+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Tensor::uint16,
641+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Tensor::int16,
642+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Tensor::int32,
643+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Tensor::int64,
644+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Tensor::bool,
645+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Tensor::float64,
646+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Tensor::uint32,
647+
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Tensor::uint64,
648+
];
649+
}
650+
651+
652+
private function tensorTypeAndShape($tensorInfo): array
654653
{
655654
$type = $this->ffi->new('ONNXTensorElementDataType');
656655
$this->checkStatus(($this->api->GetTensorElementType)($tensorInfo, \FFI::addr($type)));

0 commit comments

Comments
 (0)