Skip to content

Commit 069846d

Browse files
author
Xing Han Lu
authored
Merge pull request #625 from plotly/add-dino
Add dash dino app (#minor) Former-commit-id: ffb717b
2 parents ead0151 + eb18751 commit 069846d

6 files changed

Lines changed: 340 additions & 0 deletions

File tree

apps/dash-dino/.gitignore

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
flask_cache
2+
.vscode
3+
4+
# Byte-compiled / optimized / DLL files
5+
__pycache__/
6+
*.py[cod]
7+
*$py.class
8+
9+
# C extensions
10+
*.so
11+
12+
# Distribution / packaging
13+
.Python
14+
build/
15+
develop-eggs/
16+
dist/
17+
downloads/
18+
eggs/
19+
.eggs/
20+
lib/
21+
lib64/
22+
parts/
23+
sdist/
24+
var/
25+
wheels/
26+
share/python-wheels/
27+
*.egg-info/
28+
.installed.cfg
29+
*.egg
30+
MANIFEST
31+
32+
# PyInstaller
33+
# Usually these files are written by a python script from a template
34+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
35+
*.manifest
36+
*.spec
37+
38+
# Installer logs
39+
pip-log.txt
40+
pip-delete-this-directory.txt
41+
42+
# Unit test / coverage reports
43+
htmlcov/
44+
.tox/
45+
.nox/
46+
.coverage
47+
.coverage.*
48+
.cache
49+
nosetests.xml
50+
coverage.xml
51+
*.cover
52+
*.py,cover
53+
.hypothesis/
54+
.pytest_cache/
55+
cover/
56+
57+
# Translations
58+
*.mo
59+
*.pot
60+
61+
# Django stuff:
62+
*.log
63+
local_settings.py
64+
db.sqlite3
65+
db.sqlite3-journal
66+
67+
# Flask stuff:
68+
instance/
69+
.webassets-cache
70+
71+
# Scrapy stuff:
72+
.scrapy
73+
74+
# Sphinx documentation
75+
docs/_build/
76+
77+
# PyBuilder
78+
.pybuilder/
79+
target/
80+
81+
# Jupyter Notebook
82+
.ipynb_checkpoints
83+
84+
# IPython
85+
profile_default/
86+
ipython_config.py
87+
88+
# pyenv
89+
# For a library or package, you might want to ignore these files since the code is
90+
# intended to run in multiple environments; otherwise, check them in:
91+
# .python-version
92+
93+
# pipenv
94+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
96+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
97+
# install all needed dependencies.
98+
#Pipfile.lock
99+
100+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
101+
__pypackages__/
102+
103+
# Celery stuff
104+
celerybeat-schedule
105+
celerybeat.pid
106+
107+
# SageMath parsed files
108+
*.sage.py
109+
110+
# Environments
111+
.env
112+
.venv
113+
env/
114+
venv/
115+
ENV/
116+
env.bak/
117+
venv.bak/
118+
119+
# Spyder project settings
120+
.spyderproject
121+
.spyproject
122+
123+
# Rope project settings
124+
.ropeproject
125+
126+
# mkdocs documentation
127+
/site
128+
129+
# mypy
130+
.mypy_cache/
131+
.dmypy.json
132+
dmypy.json
133+
134+
# Pyre type checker
135+
.pyre/
136+
137+
# pytype static type analyzer
138+
.pytype/
139+
140+
# Cython debug symbols
141+
cython_debug/

apps/dash-dino/Procfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
web: gunicorn app:server --workers 4

apps/dash-dino/README.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Dash DINO
2+
3+
This is a demo of [Facebook AI's DINO](https://github.com/facebookresearch/dino) model, built using [Dash Labs](https://github.com/plotly/dash-labs).
4+
5+
![](./demo.gif)
6+
7+
Using Dash Labs, you can build apps without specifying a layout. This app was built using this single function:
8+
9+
```python
10+
@app.callback(
11+
args=dict(
12+
url=tpl.textbox_input(default_url, label="Image URL", kind=dl.State),
13+
run=tpl.button_input("Run", label=""),
14+
head=tpl.dropdown_input(list(range(6)), value="0", label="Attention Head"),
15+
options=tpl.checklist_input(["use threshold", "overlay"], []),
16+
threshold=tpl.slider_input(0, 1, 0.6, 0.01),
17+
),
18+
output=tpl.graph_output(),
19+
template=tpl,
20+
)
21+
def callback(url, run, threshold, head, options):
22+
try:
23+
im = download_img(url)
24+
except:
25+
return go.Figure().update_layout(title="Incorrect URL")
26+
27+
ix = int(head)
28+
# Run model
29+
img = transform(im).to(device)
30+
attentions, w_featmap, h_featmap = predict(img)
31+
th_attn, scalar_attn = apply_threshold(attentions, w_featmap, h_featmap, threshold)
32+
33+
attns = th_attn if "use threshold" in options else scalar_attn
34+
35+
if "overlay" in options:
36+
fig = px.imshow(im)
37+
fig.add_trace(go.Heatmap(z=attns[ix], opacity=0.55))
38+
else:
39+
fig = make_subplots(1, 2)
40+
fig.add_trace(go.Image(z=im), 1, 1)
41+
fig.add_trace(go.Heatmap(z=attns[ix]), 1, 2)
42+
fig.update_xaxes(matches="x")
43+
fig.update_yaxes(matches="y")
44+
45+
return fig
46+
```
47+
48+
The entire layout was built from the args specified in the `app.callback` thanks to [templates](https://community.plotly.com/t/introducing-dash-labs-dash-2-0-preview/52087).

apps/dash-dino/app.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import dash
2+
import dash_bootstrap_components as dbc
3+
import dash_labs as dl
4+
import plotly.express as px
5+
import plotly.graph_objs as go
6+
from plotly.subplots import make_subplots
7+
import requests
8+
import torch
9+
import torch.nn as nn
10+
from torchvision import transforms as pth_transforms
11+
from PIL import Image
12+
from flask_caching import Cache
13+
14+
15+
def download_img(url, size=(600, 600)):
16+
im = Image.open(requests.get(url, stream=True).raw).convert("RGB")
17+
im.thumbnail(size)
18+
return im
19+
20+
21+
# Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
22+
def compute_attentions(model, patch_size=16):
23+
def aux(img):
24+
# make the image divisible by the patch size
25+
w, h = (
26+
img.shape[1] - img.shape[1] % patch_size,
27+
img.shape[2] - img.shape[2] % patch_size,
28+
)
29+
img = img[:, :w, :h].unsqueeze(0)
30+
w_featmap = img.shape[-2] // patch_size
31+
h_featmap = img.shape[-1] // patch_size
32+
attentions = model.forward_selfattention(img)
33+
34+
return attentions, w_featmap, h_featmap
35+
36+
return aux
37+
38+
39+
# Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
40+
def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
41+
nh = attentions.shape[1] # number of head
42+
# we keep only the output patch attention
43+
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
44+
# we keep only a certain percentage of the mass
45+
val, idx = torch.sort(attentions)
46+
val /= torch.sum(val, dim=1, keepdim=True)
47+
cumval = torch.cumsum(val, dim=1)
48+
th_attn = cumval > (1 - threshold)
49+
idx2 = torch.argsort(idx)
50+
for head in range(nh):
51+
th_attn[head] = th_attn[head][idx2[head]]
52+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
53+
th_attn = nn.functional.interpolate(
54+
th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest"
55+
)
56+
th_attn = th_attn[0].detach().cpu().numpy()
57+
58+
attentions = attentions.reshape(nh, w_featmap, h_featmap)
59+
attentions = nn.functional.interpolate(
60+
attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
61+
)
62+
attentions = attentions[0].detach().cpu().numpy()
63+
64+
return th_attn, attentions
65+
66+
67+
# VARS
68+
default_url = "https://dl.fbaipublicfiles.com/dino/img.png"
69+
70+
# Load model
71+
torch.hub.set_dir("./")
72+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
73+
print("Running on", device)
74+
model = torch.hub.load("facebookresearch/dino:main", "dino_deits16").to(device)
75+
transform = pth_transforms.Compose(
76+
[
77+
pth_transforms.ToTensor(),
78+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
79+
]
80+
)
81+
82+
# Initialize dash app and dash-labs template
83+
title = "Zero-shot segmentation with DINO and Dash Labs"
84+
app = dash.Dash(__name__, title=title, plugins=[dl.plugins.FlexibleCallbacks()])
85+
server = app.server
86+
tpl = dl.templates.DbcSidebar(title=title, theme=dbc.themes.DARKLY)
87+
cache = Cache(
88+
app.server, config={"CACHE_TYPE": "filesystem", "CACHE_DIR": "flask_cache"},
89+
)
90+
91+
# memoize functions
92+
predict = cache.memoize(timeout=300)(compute_attentions(model))
93+
download_img = cache.memoize(timeout=300)(download_img)
94+
95+
# Define callback function
96+
@app.callback(
97+
args=dict(
98+
url=tpl.textbox_input(default_url, label="Image URL", kind=dl.State),
99+
run=tpl.button_input("Run", label=""),
100+
head=tpl.dropdown_input(list(range(6)), value="0", label="Attention Head"),
101+
options=tpl.checklist_input(["use threshold", "overlay"], []),
102+
threshold=tpl.slider_input(0, 1, 0.6, 0.01),
103+
),
104+
output=tpl.graph_output(),
105+
template=tpl,
106+
)
107+
def callback(url, run, threshold, head, options):
108+
try:
109+
im = download_img(url)
110+
except:
111+
return go.Figure().update_layout(title="Incorrect URL")
112+
113+
ix = int(head)
114+
# Run model
115+
img = transform(im).to(device)
116+
attentions, w_featmap, h_featmap = predict(img)
117+
th_attn, scalar_attn = apply_threshold(attentions, w_featmap, h_featmap, threshold)
118+
119+
attns = th_attn if "use threshold" in options else scalar_attn
120+
121+
if "overlay" in options:
122+
fig = px.imshow(im)
123+
fig.add_trace(go.Heatmap(z=attns[ix], opacity=0.55))
124+
else:
125+
fig = make_subplots(1, 2)
126+
fig.add_trace(go.Image(z=im), 1, 1)
127+
fig.add_trace(go.Heatmap(z=attns[ix]), 1, 2)
128+
fig.update_xaxes(matches="x")
129+
fig.update_yaxes(matches="y")
130+
131+
return fig
132+
133+
134+
app.layout = tpl.layout(app)
135+
136+
137+
if __name__ == "__main__":
138+
app.run_server(debug=True)

apps/dash-dino/demo.gif

4.61 MB
Loading

apps/dash-dino/requirements.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
torch
2+
dash-labs==0.1.0
3+
dash-bootstrap-components
4+
spectra
5+
colormath
6+
requests
7+
tinycss2
8+
pandas
9+
torchvision
10+
Pillow
11+
Flask-Caching
12+
gunicorn

0 commit comments

Comments
 (0)