1- import time
2- from functools import partial
3-
41import dash
52import dash_bootstrap_components as dbc
63import dash_labs as dl
129import torch .nn as nn
1310from torchvision import transforms as pth_transforms
1411from PIL import Image
15- import numpy as np
1612from flask_caching import Cache
1713
1814
19- def download_img (url , size = (500 , 500 )):
15+ def download_img (url , size = (600 , 600 )):
2016 im = Image .open (requests .get (url , stream = True ).raw ).convert ("RGB" )
2117 im .thumbnail (size )
2218 return im
2319
2420
21+ # Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
2522def compute_attentions (model , patch_size = 16 ):
2623 def aux (img ):
27- """
28- Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
29- """
3024 # make the image divisible by the patch size
3125 w , h = (
3226 img .shape [1 ] - img .shape [1 ] % patch_size ,
@@ -42,6 +36,7 @@ def aux(img):
4236 return aux
4337
4438
39+ # Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
4540def apply_threshold (attentions , w_featmap , h_featmap , threshold , patch_size = 16 ):
4641 nh = attentions .shape [1 ] # number of head
4742 # we keep only the output patch attention
@@ -55,15 +50,11 @@ def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
5550 for head in range (nh ):
5651 th_attn [head ] = th_attn [head ][idx2 [head ]]
5752 th_attn = th_attn .reshape (nh , w_featmap , h_featmap ).float ()
58- # interpolate
59- th_attn = (
60- nn .functional .interpolate (
61- th_attn .unsqueeze (0 ), scale_factor = patch_size , mode = "nearest"
62- )[0 ]
63- .detach ()
64- .cpu ()
65- .numpy ()
53+ th_attn = nn .functional .interpolate (
54+ th_attn .unsqueeze (0 ), scale_factor = patch_size , mode = "nearest"
6655 )
56+ th_attn = th_attn [0 ].detach ().cpu ().numpy ()
57+
6758 attentions = attentions .reshape (nh , w_featmap , h_featmap )
6859 attentions = nn .functional .interpolate (
6960 attentions .unsqueeze (0 ), scale_factor = patch_size , mode = "nearest"
@@ -73,7 +64,7 @@ def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
7364 return th_attn , attentions
7465
7566
76- # vars
67+ # VARS
7768default_url = "https://dl.fbaipublicfiles.com/dino/img.png"
7869
7970# Load model
@@ -88,17 +79,11 @@ def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
8879)
8980
9081# Initialize dash app and dash-labs template
91- app = dash .Dash (__name__ , plugins = [dl .plugins .FlexibleCallbacks ()])
82+ title = "Zero-shot segmentation with DINO and Dash Labs"
83+ app = dash .Dash (__name__ , title = title , plugins = [dl .plugins .FlexibleCallbacks ()])
84+ tpl = dl .templates .DbcSidebar (title = title , theme = dbc .themes .DARKLY )
9285cache = Cache (
93- app .server ,
94- config = {
95- # try 'filesystem' if you don't want to setup redis
96- "CACHE_TYPE" : "filesystem" ,
97- "CACHE_DIR" : "flask_cache" ,
98- },
99- )
100- tpl = dl .templates .DbcSidebar (
101- title = "DINO Demo (using Dash Labs)" , theme = dbc .themes .DARKLY ,
86+ app .server , config = {"CACHE_TYPE" : "filesystem" , "CACHE_DIR" : "flask_cache" },
10287)
10388
10489# memoize functions
@@ -124,25 +109,22 @@ def callback(url, run, threshold, head, options):
124109 return go .Figure ().update_layout (title = "Incorrect URL" )
125110
126111 ix = int (head )
127-
128112 # Run model
129113 img = transform (im ).to (device )
130114 attentions , w_featmap , h_featmap = predict (img )
131115 th_attn , scalar_attn = apply_threshold (attentions , w_featmap , h_featmap , threshold )
132116
133- if "use threshold" in options :
134- attns = th_attn
135- else :
136- attns = scalar_attn
117+ attns = th_attn if "use threshold" in options else scalar_attn
137118
138119 if "overlay" in options :
139120 fig = px .imshow (im )
140121 fig .add_trace (go .Heatmap (z = attns [ix ], opacity = 0.55 ))
141122 else :
142123 fig = make_subplots (1 , 2 )
143124 fig .add_trace (go .Image (z = im ), 1 , 1 )
144- fig .add_trace (go .Heatmap (z = attns [ix ], y = np . arange ( attns . shape [ 1 ], 0 , - 1 ) ), 1 , 2 )
125+ fig .add_trace (go .Heatmap (z = attns [ix ]), 1 , 2 )
145126 fig .update_xaxes (matches = "x" )
127+ fig .update_yaxes (matches = "y" )
146128
147129 return fig
148130
0 commit comments