Skip to content

Commit 075267c

Browse files
author
xhlulu
committed
Update readme, update app title, simplify app
Former-commit-id: 4a407d2
1 parent 9ec77df commit 075267c

3 files changed

Lines changed: 61 additions & 34 deletions

File tree

apps/dash-dino/README.md

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,48 @@
11
# Dash DINO
22

3-
This is a demo using [Facebook AI's DINO](https://github.com/facebookresearch/dino) model and [Dash Labs](https://github.com/plotly/dash-labs).
3+
This is a demo of [Facebook AI's DINO](https://github.com/facebookresearch/dino) model, built using [Dash Labs](https://github.com/plotly/dash-labs).
4+
5+
![](./demo.gif)
6+
7+
Using Dash Labs, you can build apps without specifying a layout. This app was built using this single function:
8+
9+
```python
10+
@app.callback(
11+
args=dict(
12+
url=tpl.textbox_input(default_url, label="Image URL", kind=dl.State),
13+
run=tpl.button_input("Run", label=""),
14+
head=tpl.dropdown_input(list(range(6)), value="0", label="Attention Head"),
15+
options=tpl.checklist_input(["use threshold", "overlay"], []),
16+
threshold=tpl.slider_input(0, 1, 0.6, 0.01),
17+
),
18+
output=tpl.graph_output(),
19+
template=tpl,
20+
)
21+
def callback(url, run, threshold, head, options):
22+
try:
23+
im = download_img(url)
24+
except:
25+
return go.Figure().update_layout(title="Incorrect URL")
26+
27+
ix = int(head)
28+
# Run model
29+
img = transform(im).to(device)
30+
attentions, w_featmap, h_featmap = predict(img)
31+
th_attn, scalar_attn = apply_threshold(attentions, w_featmap, h_featmap, threshold)
32+
33+
attns = th_attn if "use threshold" in options else scalar_attn
34+
35+
if "overlay" in options:
36+
fig = px.imshow(im)
37+
fig.add_trace(go.Heatmap(z=attns[ix], opacity=0.55))
38+
else:
39+
fig = make_subplots(1, 2)
40+
fig.add_trace(go.Image(z=im), 1, 1)
41+
fig.add_trace(go.Heatmap(z=attns[ix]), 1, 2)
42+
fig.update_xaxes(matches="x")
43+
fig.update_yaxes(matches="y")
44+
45+
return fig
46+
```
47+
48+
The entire layout was built from the args specified in the `app.callback` thanks to [templates](https://community.plotly.com/t/introducing-dash-labs-dash-2-0-preview/52087).

apps/dash-dino/app.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import time
2-
from functools import partial
3-
41
import dash
52
import dash_bootstrap_components as dbc
63
import dash_labs as dl
@@ -12,21 +9,18 @@
129
import torch.nn as nn
1310
from torchvision import transforms as pth_transforms
1411
from PIL import Image
15-
import numpy as np
1612
from flask_caching import Cache
1713

1814

19-
def download_img(url, size=(500, 500)):
15+
def download_img(url, size=(600, 600)):
2016
im = Image.open(requests.get(url, stream=True).raw).convert("RGB")
2117
im.thumbnail(size)
2218
return im
2319

2420

21+
# Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
2522
def compute_attentions(model, patch_size=16):
2623
def aux(img):
27-
"""
28-
Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
29-
"""
3024
# make the image divisible by the patch size
3125
w, h = (
3226
img.shape[1] - img.shape[1] % patch_size,
@@ -42,6 +36,7 @@ def aux(img):
4236
return aux
4337

4438

39+
# Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
4540
def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
4641
nh = attentions.shape[1] # number of head
4742
# we keep only the output patch attention
@@ -55,15 +50,11 @@ def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
5550
for head in range(nh):
5651
th_attn[head] = th_attn[head][idx2[head]]
5752
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
58-
# interpolate
59-
th_attn = (
60-
nn.functional.interpolate(
61-
th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest"
62-
)[0]
63-
.detach()
64-
.cpu()
65-
.numpy()
53+
th_attn = nn.functional.interpolate(
54+
th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest"
6655
)
56+
th_attn = th_attn[0].detach().cpu().numpy()
57+
6758
attentions = attentions.reshape(nh, w_featmap, h_featmap)
6859
attentions = nn.functional.interpolate(
6960
attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
@@ -73,7 +64,7 @@ def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
7364
return th_attn, attentions
7465

7566

76-
# vars
67+
# VARS
7768
default_url = "https://dl.fbaipublicfiles.com/dino/img.png"
7869

7970
# Load model
@@ -88,17 +79,11 @@ def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
8879
)
8980

9081
# Initialize dash app and dash-labs template
91-
app = dash.Dash(__name__, plugins=[dl.plugins.FlexibleCallbacks()])
82+
title = "Zero-shot segmentation with DINO and Dash Labs"
83+
app = dash.Dash(__name__, title=title, plugins=[dl.plugins.FlexibleCallbacks()])
84+
tpl = dl.templates.DbcSidebar(title=title, theme=dbc.themes.DARKLY)
9285
cache = Cache(
93-
app.server,
94-
config={
95-
# try 'filesystem' if you don't want to setup redis
96-
"CACHE_TYPE": "filesystem",
97-
"CACHE_DIR": "flask_cache",
98-
},
99-
)
100-
tpl = dl.templates.DbcSidebar(
101-
title="DINO Demo (using Dash Labs)", theme=dbc.themes.DARKLY,
86+
app.server, config={"CACHE_TYPE": "filesystem", "CACHE_DIR": "flask_cache"},
10287
)
10388

10489
# memoize functions
@@ -124,25 +109,22 @@ def callback(url, run, threshold, head, options):
124109
return go.Figure().update_layout(title="Incorrect URL")
125110

126111
ix = int(head)
127-
128112
# Run model
129113
img = transform(im).to(device)
130114
attentions, w_featmap, h_featmap = predict(img)
131115
th_attn, scalar_attn = apply_threshold(attentions, w_featmap, h_featmap, threshold)
132116

133-
if "use threshold" in options:
134-
attns = th_attn
135-
else:
136-
attns = scalar_attn
117+
attns = th_attn if "use threshold" in options else scalar_attn
137118

138119
if "overlay" in options:
139120
fig = px.imshow(im)
140121
fig.add_trace(go.Heatmap(z=attns[ix], opacity=0.55))
141122
else:
142123
fig = make_subplots(1, 2)
143124
fig.add_trace(go.Image(z=im), 1, 1)
144-
fig.add_trace(go.Heatmap(z=attns[ix], y=np.arange(attns.shape[1], 0, -1)), 1, 2)
125+
fig.add_trace(go.Heatmap(z=attns[ix]), 1, 2)
145126
fig.update_xaxes(matches="x")
127+
fig.update_yaxes(matches="y")
146128

147129
return fig
148130

apps/dash-dino/demo.gif

4.61 MB
Loading

0 commit comments

Comments
 (0)