Skip to content

Commit 9ec77df

Browse files
author
xhlulu
committed
Add dash dino app
Former-commit-id: a9bde48
1 parent ead0151 commit 9ec77df

6 files changed

Lines changed: 311 additions & 1 deletion

File tree

apps/dash-deit/Procfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
web: gunicorn app:server --workers 2
1+
web: gunicorn app:server --workers 4

apps/dash-dino/.gitignore

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
flask_cache
2+
.vscode
3+
4+
# Byte-compiled / optimized / DLL files
5+
__pycache__/
6+
*.py[cod]
7+
*$py.class
8+
9+
# C extensions
10+
*.so
11+
12+
# Distribution / packaging
13+
.Python
14+
build/
15+
develop-eggs/
16+
dist/
17+
downloads/
18+
eggs/
19+
.eggs/
20+
lib/
21+
lib64/
22+
parts/
23+
sdist/
24+
var/
25+
wheels/
26+
share/python-wheels/
27+
*.egg-info/
28+
.installed.cfg
29+
*.egg
30+
MANIFEST
31+
32+
# PyInstaller
33+
# Usually these files are written by a python script from a template
34+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
35+
*.manifest
36+
*.spec
37+
38+
# Installer logs
39+
pip-log.txt
40+
pip-delete-this-directory.txt
41+
42+
# Unit test / coverage reports
43+
htmlcov/
44+
.tox/
45+
.nox/
46+
.coverage
47+
.coverage.*
48+
.cache
49+
nosetests.xml
50+
coverage.xml
51+
*.cover
52+
*.py,cover
53+
.hypothesis/
54+
.pytest_cache/
55+
cover/
56+
57+
# Translations
58+
*.mo
59+
*.pot
60+
61+
# Django stuff:
62+
*.log
63+
local_settings.py
64+
db.sqlite3
65+
db.sqlite3-journal
66+
67+
# Flask stuff:
68+
instance/
69+
.webassets-cache
70+
71+
# Scrapy stuff:
72+
.scrapy
73+
74+
# Sphinx documentation
75+
docs/_build/
76+
77+
# PyBuilder
78+
.pybuilder/
79+
target/
80+
81+
# Jupyter Notebook
82+
.ipynb_checkpoints
83+
84+
# IPython
85+
profile_default/
86+
ipython_config.py
87+
88+
# pyenv
89+
# For a library or package, you might want to ignore these files since the code is
90+
# intended to run in multiple environments; otherwise, check them in:
91+
# .python-version
92+
93+
# pipenv
94+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
96+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
97+
# install all needed dependencies.
98+
#Pipfile.lock
99+
100+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
101+
__pypackages__/
102+
103+
# Celery stuff
104+
celerybeat-schedule
105+
celerybeat.pid
106+
107+
# SageMath parsed files
108+
*.sage.py
109+
110+
# Environments
111+
.env
112+
.venv
113+
env/
114+
venv/
115+
ENV/
116+
env.bak/
117+
venv.bak/
118+
119+
# Spyder project settings
120+
.spyderproject
121+
.spyproject
122+
123+
# Rope project settings
124+
.ropeproject
125+
126+
# mkdocs documentation
127+
/site
128+
129+
# mypy
130+
.mypy_cache/
131+
.dmypy.json
132+
dmypy.json
133+
134+
# Pyre type checker
135+
.pyre/
136+
137+
# pytype static type analyzer
138+
.pytype/
139+
140+
# Cython debug symbols
141+
cython_debug/

apps/dash-dino/Procfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
web: gunicorn app:server --workers 2

apps/dash-dino/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Dash DINO
2+
3+
This is a demo using [Facebook AI's DINO](https://github.com/facebookresearch/dino) model and [Dash Labs](https://github.com/plotly/dash-labs).

apps/dash-dino/app.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import time
2+
from functools import partial
3+
4+
import dash
5+
import dash_bootstrap_components as dbc
6+
import dash_labs as dl
7+
import plotly.express as px
8+
import plotly.graph_objs as go
9+
from plotly.subplots import make_subplots
10+
import requests
11+
import torch
12+
import torch.nn as nn
13+
from torchvision import transforms as pth_transforms
14+
from PIL import Image
15+
import numpy as np
16+
from flask_caching import Cache
17+
18+
19+
def download_img(url, size=(500, 500)):
20+
im = Image.open(requests.get(url, stream=True).raw).convert("RGB")
21+
im.thumbnail(size)
22+
return im
23+
24+
25+
def compute_attentions(model, patch_size=16):
26+
def aux(img):
27+
"""
28+
Source: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
29+
"""
30+
# make the image divisible by the patch size
31+
w, h = (
32+
img.shape[1] - img.shape[1] % patch_size,
33+
img.shape[2] - img.shape[2] % patch_size,
34+
)
35+
img = img[:, :w, :h].unsqueeze(0)
36+
w_featmap = img.shape[-2] // patch_size
37+
h_featmap = img.shape[-1] // patch_size
38+
attentions = model.forward_selfattention(img)
39+
40+
return attentions, w_featmap, h_featmap
41+
42+
return aux
43+
44+
45+
def apply_threshold(attentions, w_featmap, h_featmap, threshold, patch_size=16):
46+
nh = attentions.shape[1] # number of head
47+
# we keep only the output patch attention
48+
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
49+
# we keep only a certain percentage of the mass
50+
val, idx = torch.sort(attentions)
51+
val /= torch.sum(val, dim=1, keepdim=True)
52+
cumval = torch.cumsum(val, dim=1)
53+
th_attn = cumval > (1 - threshold)
54+
idx2 = torch.argsort(idx)
55+
for head in range(nh):
56+
th_attn[head] = th_attn[head][idx2[head]]
57+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
58+
# interpolate
59+
th_attn = (
60+
nn.functional.interpolate(
61+
th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest"
62+
)[0]
63+
.detach()
64+
.cpu()
65+
.numpy()
66+
)
67+
attentions = attentions.reshape(nh, w_featmap, h_featmap)
68+
attentions = nn.functional.interpolate(
69+
attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
70+
)
71+
attentions = attentions[0].detach().cpu().numpy()
72+
73+
return th_attn, attentions
74+
75+
76+
# vars
77+
default_url = "https://dl.fbaipublicfiles.com/dino/img.png"
78+
79+
# Load model
80+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
81+
print("Running on", device)
82+
model = torch.hub.load("facebookresearch/dino:main", "dino_deits16").to(device)
83+
transform = pth_transforms.Compose(
84+
[
85+
pth_transforms.ToTensor(),
86+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
87+
]
88+
)
89+
90+
# Initialize dash app and dash-labs template
91+
app = dash.Dash(__name__, plugins=[dl.plugins.FlexibleCallbacks()])
92+
cache = Cache(
93+
app.server,
94+
config={
95+
# try 'filesystem' if you don't want to setup redis
96+
"CACHE_TYPE": "filesystem",
97+
"CACHE_DIR": "flask_cache",
98+
},
99+
)
100+
tpl = dl.templates.DbcSidebar(
101+
title="DINO Demo (using Dash Labs)", theme=dbc.themes.DARKLY,
102+
)
103+
104+
# memoize functions
105+
predict = cache.memoize(timeout=300)(compute_attentions(model))
106+
download_img = cache.memoize(timeout=300)(download_img)
107+
108+
# Define callback function
109+
@app.callback(
110+
args=dict(
111+
url=tpl.textbox_input(default_url, label="Image URL", kind=dl.State),
112+
run=tpl.button_input("Run", label=""),
113+
head=tpl.dropdown_input(list(range(6)), value="0", label="Attention Head"),
114+
options=tpl.checklist_input(["use threshold", "overlay"], []),
115+
threshold=tpl.slider_input(0, 1, 0.6, 0.01),
116+
),
117+
output=tpl.graph_output(),
118+
template=tpl,
119+
)
120+
def callback(url, run, threshold, head, options):
121+
try:
122+
im = download_img(url)
123+
except:
124+
return go.Figure().update_layout(title="Incorrect URL")
125+
126+
ix = int(head)
127+
128+
# Run model
129+
img = transform(im).to(device)
130+
attentions, w_featmap, h_featmap = predict(img)
131+
th_attn, scalar_attn = apply_threshold(attentions, w_featmap, h_featmap, threshold)
132+
133+
if "use threshold" in options:
134+
attns = th_attn
135+
else:
136+
attns = scalar_attn
137+
138+
if "overlay" in options:
139+
fig = px.imshow(im)
140+
fig.add_trace(go.Heatmap(z=attns[ix], opacity=0.55))
141+
else:
142+
fig = make_subplots(1, 2)
143+
fig.add_trace(go.Image(z=im), 1, 1)
144+
fig.add_trace(go.Heatmap(z=attns[ix], y=np.arange(attns.shape[1], 0, -1)), 1, 2)
145+
fig.update_xaxes(matches="x")
146+
147+
return fig
148+
149+
150+
app.layout = tpl.layout(app)
151+
152+
153+
if __name__ == "__main__":
154+
app.run_server(debug=True)

apps/dash-dino/requirements.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
torch
2+
dash-labs==0.1.0
3+
dash-bootstrap-components
4+
spectra
5+
colormath
6+
requests
7+
tinycss2
8+
pandas
9+
torchvision
10+
Pillow
11+
Flask-Caching

0 commit comments

Comments
 (0)