"""
A TinyML detection technique using the EfficientDet model.
"""
from typing import Any

import cv2
import numpy
from tflite_runtime.interpreter import Interpreter

from .base_detector_strategy import BaseDetectorStrategy, DetectorResult


class EfficientdetStrategy(BaseDetectorStrategy):
    """
    The EfficientDet strategy for object detection.
    """

    MODEL_PATH: str = "models/efficientdet_1.tflite"
    LABEL_PATH: str = "models/efficientdet_1_labelmap.txt"
    DETECTION_THRES: float = 0.35
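
    # Note: both paths resolve relative to the current working directory;
    # an absolute path (e.g. built from pathlib.Path(__file__).parent)
    # would be sturdier if the process is launched from elsewhere.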

    @classmethod
    def detect_humans(cls, frame: numpy.ndarray) -> DetectorResult:
        """Detect whether any humans are present in the frame."""
        # Create a model interpreter; a long-lived service could cache
        # this instead of rebuilding it on every call.
        interpreter: Interpreter = Interpreter(model_path=cls.MODEL_PATH)
        interpreter.allocate_tensors()

        # Get model input and output details.
        input_details: list[dict[str, Any]] = interpreter.get_input_details()
        output_details: list[dict[str, Any]] = interpreter.get_output_details()
        _, input_height, input_width, _ = input_details[0]['shape']

        # Prepare the image for the input tensor: convert BGR to RGB and
        # resize to the model's expected input resolution.
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (input_width, input_height), interpolation=cv2.INTER_AREA)
        image_height, image_width = image.shape[:2]

        # Load the frame into the model's input tensor.
        input_data = numpy.expand_dims(image, axis=0)
        interpreter.set_tensor(input_details[0]['index'], input_data)
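        # Note (assumption): this passes the uint8 frame straight through,
        # which suits the common uint8-input EfficientDet-Lite exports; a
        # float32 export would also need scaling, and
        # input_details[0]['dtype'] shows which case applies.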

        # Run inference.
        interpreter.invoke()

        # Read the outputs.
        boxes = interpreter.get_tensor(output_details[0]['index'])[0]
        classes = interpreter.get_tensor(output_details[1]['index'])[0]
        scores = interpreter.get_tensor(output_details[2]['index'])[0]
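        # Caveat (assumption): this boxes/classes/scores ordering matches the
        # particular export used here; other EfficientDet-Lite exports order
        # their outputs differently, so output_details[i]['name'] is worth
        # checking when swapping in a new model file.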

        # Read the label map.
        with open(cls.LABEL_PATH, 'r', encoding="utf-8") as labelmap:
            labels = [line.strip() for line in labelmap]

        # Convert RGB back to BGR so OpenCV draws and displays correctly.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Traverse the detections, keeping only confident "person" hits.
        detection_regions: list[tuple[int, int, int, int]] = []
        for score, box, pred_class in zip(scores, boxes, classes):
            if score < cls.DETECTION_THRES:
                continue

            class_name = labels[int(pred_class)]
            if class_name == "person":
                # Boxes arrive as normalized [ymin, xmin, ymax, xmax];
                # scale them to pixel coordinates.
                min_y = round(box[0] * image_height)
                min_x = round(box[1] * image_width)
                max_y = round(box[2] * image_height)
                max_x = round(box[3] * image_width)
                detection_regions.append((min_x, max_x, min_y, max_y))

                cv2.rectangle(image, (min_x, min_y), (max_x, max_y), (0, 255, 0), 2)
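
        # Note: the returned regions live in the resized model-input
        # coordinate space, not the original frame's; a caller needing
        # frame-space boxes would rescale by frame-width/input_width and
        # frame-height/input_height.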

        result = DetectorResult(
            image=image,
            human_found=len(detection_regions) > 0,
            regions=detection_regions,
            num_detections=len(detection_regions),
        )
        return result
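

# A minimal usage sketch, not part of the strategy itself. It assumes this
# module lives in a package next to base_detector_strategy (so it must be
# run as ``python -m <package>.<module>`` for the relative import to
# resolve) and that a hypothetical test image "sample.jpg" exists.
if __name__ == "__main__":
    sample_frame = cv2.imread("sample.jpg")  # hypothetical test image
    if sample_frame is None:
        raise SystemExit("sample.jpg not found")
    detection = EfficientdetStrategy.detect_humans(sample_frame)
    print(f"human_found={detection.human_found}, "
          f"num_detections={detection.num_detections}")
    cv2.imwrite("detections.jpg", detection.image)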