Skip to content

Commit ae9d8fb

Browse files
No public description
PiperOrigin-RevId: 897871239
1 parent 8db072d commit ae9d8fb

1 file changed

Lines changed: 218 additions & 1 deletion

File tree

official/projects/waste_identification_ml/model_inference/cn_model_run.ipynb

Lines changed: 218 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,19 @@
77
"cell_type": "markdown",
88
"source": [
99
"# CircularNet - Waste identification with instance segmentation\n",
10-
"Welcome to the Instance Segmentation Notebook! This notebook will take you through the steps of running an Instance Segmentation model on Images."
10+
"Welcome to the Instance Segmentation Notebook! This notebook will take you through the steps of running an Instance Segmentation model on Images. \\\n",
11+
"There are two ways to run it :\n",
12+
"1. Pytorch Model : How to quickly run and test the CircularNet model's Pytorch version.\n",
13+
"2. ONNX Model : Running the converted ONNX version of the model."
14+
]
15+
},
16+
{
17+
"metadata": {
18+
"id": "XQVvI_GHHMuo"
19+
},
20+
"cell_type": "markdown",
21+
"source": [
22+
"# 1. Pytorch Model"
1123
]
1224
},
1325
{
@@ -191,6 +203,211 @@
191203
],
192204
"outputs": [],
193205
"execution_count": null
206+
},
207+
{
208+
"metadata": {
209+
"id": "lEK0yWmsHSKL"
210+
},
211+
"cell_type": "markdown",
212+
"source": [
213+
"# 2. How to run ONNX Version of the Model"
214+
]
215+
},
216+
{
217+
"metadata": {
218+
"id": "9z9j-zhcHglk"
219+
},
220+
"cell_type": "markdown",
221+
"source": [
222+
"## Download model and essential codebase"
223+
]
224+
},
225+
{
226+
"metadata": {
227+
"id": "QHXuzvaCHRpV"
228+
},
229+
"cell_type": "code",
230+
"source": [
231+
"!git clone --depth 1 https://github.com/tensorflow/models.git\n",
232+
"!wget https://storage.googleapis.com/tf_model_garden/vision/waste_identification_ml/CN-ModelCheckpoints/ModelRegistry_432x432_March26/inference_model.onnx\n",
233+
"!wget https://storage.googleapis.com/tf_model_garden/vision/waste_identification_ml/CN-ModelCheckpoints/sample_image.jpg\n",
234+
"!mv models/official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/labels50.csv ./"
235+
],
236+
"outputs": [],
237+
"execution_count": null
238+
},
239+
{
240+
"metadata": {
241+
"id": "ul1Zy1s5HjFO"
242+
},
243+
"cell_type": "markdown",
244+
"source": [
245+
"## Install all the required libraries\n"
246+
]
247+
},
248+
{
249+
"metadata": {
250+
"id": "DzY3Zg-lHeQt"
251+
},
252+
"cell_type": "code",
253+
"source": [
254+
"!pip install -q onnx==1.19.1 onnxruntime==1.23.2 supervision tritonclient[http]==2.58.0"
255+
],
256+
"outputs": [],
257+
"execution_count": null
258+
},
259+
{
260+
"metadata": {
261+
"id": "A2LFwEfBHvc5"
262+
},
263+
"cell_type": "markdown",
264+
"source": [
265+
"## Import Python Libraries"
266+
]
267+
},
268+
{
269+
"metadata": {
270+
"id": "vVwZyXA7HsdJ"
271+
},
272+
"cell_type": "code",
273+
"source": [
274+
"import onnx\n",
275+
"import onnxruntime as ort\n",
276+
"\n",
277+
"from PIL import Image\n",
278+
"import supervision as sv\n",
279+
"\n",
280+
"import sys\n",
281+
"from unittest.mock import MagicMock\n",
282+
"sys.modules[\"color_extraction\"] = MagicMock()\n",
283+
"\n",
284+
"from models.official.projects.waste_identification_ml.Deploy.detr_cloud_deployment.client.triton_server_inference import TritonObjectDetector\n",
285+
"from models.official.projects.waste_identification_ml.Deploy.detr_cloud_deployment.client.utils import draw_detections_and_save_image"
286+
],
287+
"outputs": [],
288+
"execution_count": null
289+
},
290+
{
291+
"metadata": {
292+
"id": "FV1DL5SDH9oP"
293+
},
294+
"cell_type": "markdown",
295+
"source": [
296+
"## Define essential variables"
297+
]
298+
},
299+
{
300+
"metadata": {
301+
"id": "OvBh1HCnHsgX"
302+
},
303+
"cell_type": "code",
304+
"source": [
305+
"img_path = \"./sample_image.jpg\"\n",
306+
"onnx_model_path = \"./inference_model.onnx\"\n",
307+
"output_image_dimensions = (1024, 1024)"
308+
],
309+
"outputs": [],
310+
"execution_count": null
311+
},
312+
{
313+
"metadata": {
314+
"id": "8xeC6b4xIPbL"
315+
},
316+
"cell_type": "markdown",
317+
"source": [
318+
"## Initialize ONNX version of the model"
319+
]
320+
},
321+
{
322+
"metadata": {
323+
"id": "sYcXoly3IL99"
324+
},
325+
"cell_type": "code",
326+
"source": [
327+
"model_utils = TritonObjectDetector()\n",
328+
"onnx_model = ort.InferenceSession(onnx_model_path)"
329+
],
330+
"outputs": [],
331+
"execution_count": null
332+
},
333+
{
334+
"metadata": {
335+
"id": "2L0oJEtaIUgd"
336+
},
337+
"cell_type": "markdown",
338+
"source": [
339+
"## Do model Inferencing"
340+
]
341+
},
342+
{
343+
"metadata": {
344+
"id": "lGHtqGMyISTW"
345+
},
346+
"cell_type": "code",
347+
"source": [
348+
"image_array = model_utils._get_input_batch_for_inference(image_path=img_path)\n",
349+
"outputs = onnx_model_session.run(None, {\"input\": image_array})"
350+
],
351+
"outputs": [],
352+
"execution_count": null
353+
},
354+
{
355+
"metadata": {
356+
"id": "GdnE_BJtIZVt"
357+
},
358+
"cell_type": "markdown",
359+
"source": [
360+
"## Format the model output as per output dimensions"
361+
]
362+
},
363+
{
364+
"metadata": {
365+
"id": "POQajqV7IeGZ"
366+
},
367+
"cell_type": "markdown",
368+
"source": [
369+
"results = model_utils._reformat_triton_output_to_dict(\n",
370+
" outputs, confidence_threshold=0.5, max_boxes=100\n",
371+
")\n",
372+
"\n",
373+
"results = model_utils._scale_bbox_and_masks(results, target_dims=(1920, 1080))\n",
374+
"results[\"class_names\"] = model_utils.get_class_names(results)"
375+
]
376+
},
377+
{
378+
"metadata": {
379+
"id": "tf1XzX-5Ihij"
380+
},
381+
"cell_type": "markdown",
382+
"source": [
383+
"## Save output and Visualize Results"
384+
]
385+
},
386+
{
387+
"metadata": {
388+
"id": "dwHzAAwjIjbu"
389+
},
390+
"cell_type": "markdown",
391+
"source": [
392+
"draw_detections_and_save_image(img=Image.open(img_path), results=results, save_path= \"./output.jpg\")\n",
393+
"Image.open(\"./output.jpg\")"
394+
]
395+
},
396+
{
397+
"metadata": {
398+
"id": "TA529zr_Imwg"
399+
},
400+
"cell_type": "markdown",
401+
"source": [
402+
"# END of Notebook"
403+
]
404+
},
405+
{
406+
"metadata": {
407+
"id": "1GHOytpUInpI"
408+
},
409+
"cell_type": "markdown",
410+
"source": []
194411
}
195412
],
196413
"metadata": {

0 commit comments

Comments
 (0)