Skip to content

Commit a1c10f4

Browse files
committed
Use ruff for formatting + Fix minor issues
1 parent 334a970 commit a1c10f4

31 files changed

Lines changed: 239 additions & 666 deletions

.flake8

Lines changed: 0 additions & 5 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,14 @@ repos:
2020
- id: mixed-line-ending
2121
- id: trailing-whitespace
2222
files: \.(py|sh|rst|yml|yaml)$
23-
- repo: https://github.com/psf/black
24-
rev: 23.3.0
23+
- repo: https://github.com/astral-sh/ruff-pre-commit
24+
rev: v0.11.7
2525
hooks:
26-
- id: black
26+
- id: ruff
27+
args: [ --fix ]
2728
exclude: tests/|anylabeling/resources/resources.py
28-
- repo: https://github.com/psf/black
29-
rev: 23.3.0
30-
hooks:
31-
- id: black
29+
- id: ruff-format
3230
exclude: tests/|anylabeling/resources/resources.py
33-
args: [--check]
3431
- repo: https://github.com/rstcheck/rstcheck
3532
rev: v6.1.2
3633
hooks:

anylabeling/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from .app_info import __appdescription__, __appname__, __version__
2+
3+
__all__ = ["__appdescription__", "__appname__", "__version__"]

anylabeling/app.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
from anylabeling.views.labeling.utils import new_icon
2424
from anylabeling.resources import resources
2525

26+
__all__ = ["resources"]
27+
2628

2729
def main():
2830
parser = argparse.ArgumentParser()
29-
parser.add_argument(
30-
"--reset-config", action="store_true", help="reset qt config"
31-
)
31+
parser.add_argument("--reset-config", action="store_true", help="reset qt config")
3232
parser.add_argument(
3333
"--logger-level",
3434
default="info",
@@ -45,16 +45,11 @@ def main():
4545
"recognized as file, else as directory)"
4646
),
4747
)
48-
default_config_file = os.path.join(
49-
os.path.expanduser("~"), ".anylabelingrc"
50-
)
48+
default_config_file = os.path.join(os.path.expanduser("~"), ".anylabelingrc")
5149
parser.add_argument(
5250
"--config",
5351
dest="config",
54-
help=(
55-
"config file or yaml-format string (default:"
56-
f" {default_config_file})"
57-
),
52+
help=(f"config file or yaml-format string (default: {default_config_file})"),
5853
default=default_config_file,
5954
)
6055
# config for the gui
@@ -167,9 +162,7 @@ def main():
167162

168163
language = config.get("language", QtCore.QLocale.system().name())
169164
translator = QtCore.QTranslator()
170-
loaded_language = translator.load(
171-
":/languages/translations/" + language + ".qm"
172-
)
165+
loaded_language = translator.load(":/languages/translations/" + language + ".qm")
173166

174167
# Enable scaling for high dpi screens
175168
QtWidgets.QApplication.setAttribute(

anylabeling/config.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,11 @@ def get_default_config():
5454

5555
def validate_config_item(key, value):
5656
if key == "validate_label" and value not in [None, "exact"]:
57-
raise ValueError(
58-
f"Unexpected value for config key 'validate_label': {value}"
59-
)
57+
raise ValueError(f"Unexpected value for config key 'validate_label': {value}")
6058
if key == "shape_color" and value not in [None, "auto", "manual"]:
61-
raise ValueError(
62-
f"Unexpected value for config key 'shape_color': {value}"
63-
)
59+
raise ValueError(f"Unexpected value for config key 'shape_color': {value}")
6460
if key == "labels" and value is not None and len(value) != len(set(value)):
65-
raise ValueError(
66-
f"Duplicates are detected for config key 'labels': {value}"
67-
)
61+
raise ValueError(f"Duplicates are detected for config key 'labels': {value}")
6862

6963

7064
def get_config(config_file_or_yaml=None, config_from_args=None):
@@ -80,14 +74,10 @@ def get_config(config_file_or_yaml=None, config_from_args=None):
8074
with open(config_from_yaml) as f:
8175
logger.info("Loading config file from: %s", config_from_yaml)
8276
config_from_yaml = yaml.safe_load(f)
83-
update_dict(
84-
config, config_from_yaml, validate_item=validate_config_item
85-
)
77+
update_dict(config, config_from_yaml, validate_item=validate_config_item)
8678

8779
# 3. command line argument or specified config file
8880
if config_from_args is not None:
89-
update_dict(
90-
config, config_from_args, validate_item=validate_config_item
91-
)
81+
update_dict(config, config_from_args, validate_item=validate_config_item)
9282

9383
return config

anylabeling/services/auto_labeling/model.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,29 @@
11
import logging
22
import os
3-
import pathlib
43
import yaml
5-
import onnx
6-
import urllib.request
7-
from urllib.parse import urlparse
8-
9-
from PyQt5.QtCore import QCoreApplication
10-
11-
import ssl
12-
13-
ssl._create_default_https_context = (
14-
ssl._create_unverified_context
15-
) # Prevent issue when downloading models behind a proxy
16-
174
import socket
18-
19-
socket.setdefaulttimeout(240) # Prevent timeout when downloading models
20-
5+
import ssl
216
from abc import abstractmethod
227

23-
24-
from PyQt5.QtCore import QFile, QObject
8+
from PyQt5.QtCore import QCoreApplication, QFile, QObject
259
from PyQt5.QtGui import QImage
2610

2711
from .types import AutoLabelingResult
2812
from anylabeling.views.labeling.label_file import LabelFile, LabelFileError
2913

14+
# Prevent issue when downloading models behind a proxy
15+
os.environ["no_proxy"] = "*"
16+
17+
socket.setdefaulttimeout(240) # Prevent timeout when downloading models
18+
19+
20+
ssl._create_default_https_context = (
21+
ssl._create_unverified_context
22+
) # Prevent issue when downloading models behind a proxy
23+
3024

3125
class Model(QObject):
32-
BASE_DOWNLOAD_URL = (
33-
"https://github.com/vietanhdev/anylabeling-assets/raw/main/"
34-
)
26+
BASE_DOWNLOAD_URL = "https://github.com/vietanhdev/anylabeling-assets/raw/main/"
3527

3628
class Meta(QObject):
3729
required_config_names = []
@@ -82,9 +74,7 @@ def get_model_abs_path(self, model_config, model_path_field_name):
8274
config_folder = os.path.dirname(model_config["config_file"])
8375
model_path = model_config[model_path_field_name]
8476
if os.path.isfile(os.path.join(config_folder, model_path)):
85-
model_abs_path = os.path.abspath(
86-
os.path.join(config_folder, model_path)
87-
)
77+
model_abs_path = os.path.abspath(os.path.join(config_folder, model_path))
8878
return model_abs_path
8979

9080
# Try getting model from assets folder

anylabeling/services/auto_labeling/model_manager.py

Lines changed: 24 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@ def __init__(self):
6161
def load_model_configs(self):
6262
"""Load model configs"""
6363
# Load list of default models
64-
with pkg_resources.open_text(
65-
auto_labeling_configs, "models.yaml"
66-
) as f:
64+
with pkg_resources.open_text(auto_labeling_configs, "models.yaml") as f:
6765
model_list = yaml.safe_load(f)
6866
for model in model_list:
6967
model["is_custom_model"] = False
@@ -74,9 +72,7 @@ def load_model_configs(self):
7472
model_download_path = os.path.join(
7573
home_dir, "anylabeling_data", "models", model["name"]
7674
)
77-
pathlib.Path(model_download_path).mkdir(
78-
parents=True, exist_ok=True
79-
)
75+
pathlib.Path(model_download_path).mkdir(parents=True, exist_ok=True)
8076
config_file = os.path.join(model_download_path, "config.yaml")
8177
model["config_file"] = config_file
8278

@@ -126,9 +122,7 @@ def load_model_configs(self):
126122
if not model_config.get("is_custom_model", False):
127123
model_config["last_used"] = -i
128124
else:
129-
model_config["last_used"] = model_config.get(
130-
"last_used", time.time()
131-
)
125+
model_config["last_used"] = model_config.get("last_used", time.time())
132126
model_configs.sort(key=lambda x: x.get("last_used", 0), reverse=True)
133127

134128
self.model_configs = model_configs
@@ -147,9 +141,7 @@ def set_output_mode(self, mode):
147141
def on_model_download_finished(self):
148142
"""Handle model download thread finished"""
149143
if self.loaded_model_config and self.loaded_model_config["model"]:
150-
self.new_model_status.emit(
151-
self.tr("Model loaded. Ready for labeling.")
152-
)
144+
self.new_model_status.emit(self.tr("Model loaded. Ready for labeling."))
153145
self.model_loaded.emit(self.loaded_model_config)
154146
self.output_modes_changed.emit(
155147
self.loaded_model_config["model"].Meta.output_modes,
@@ -165,9 +157,7 @@ def load_custom_model(self, config_file):
165157
self.model_download_thread is not None
166158
and self.model_download_thread.isRunning()
167159
):
168-
print(
169-
"Another model is being loaded. Please wait for it to finish."
170-
)
160+
print("Another model is being loaded. Please wait for it to finish.")
171161
return
172162

173163
# Check config file path
@@ -191,33 +181,26 @@ def load_custom_model(self, config_file):
191181
"type" not in model_config
192182
or "display_name" not in model_config
193183
or "name" not in model_config
194-
or model_config["type"]
195-
not in ["segment_anything", "yolov5", "yolov8"]
184+
or model_config["type"] not in ["segment_anything", "yolov5", "yolov8"]
196185
):
197186
self.new_model_status.emit(
198-
self.tr(
199-
"Error in loading custom model: Invalid config file format."
200-
)
187+
self.tr("Error in loading custom model: Invalid config file format.")
201188
)
202189
return
203190

204191
# Add or replace custom model
205192
custom_models = get_config().get("custom_models", [])
206193
matched_index = None
207194
for i, model in enumerate(custom_models):
208-
if os.path.normpath(model["config_file"]) == os.path.normpath(
209-
config_file
210-
):
195+
if os.path.normpath(model["config_file"]) == os.path.normpath(config_file):
211196
matched_index = i
212197
break
213198
if matched_index is not None:
214199
model_config["last_used"] = time.time()
215200
custom_models[matched_index] = model_config
216201
else:
217202
if len(custom_models) >= self.MAX_NUM_CUSTOM_MODELS:
218-
custom_models.sort(
219-
key=lambda x: x.get("last_used", 0), reverse=True
220-
)
203+
custom_models.sort(key=lambda x: x.get("last_used", 0), reverse=True)
221204
removed_model = custom_models.pop()
222205
# Remove old model folder
223206
config_file = removed_model["config_file"]
@@ -245,9 +228,7 @@ def load_model(self, config_file):
245228
self.model_download_thread is not None
246229
and self.model_download_thread.isRunning()
247230
):
248-
print(
249-
"Another model is being loaded. Please wait for it to finish."
250-
)
231+
print("Another model is being loaded. Please wait for it to finish.")
251232
return
252233
if not config_file:
253234
if self.model_download_worker is not None:
@@ -280,16 +261,10 @@ def load_model(self, config_file):
280261
)
281262
)
282263
self.model_download_worker = GenericWorker(self._load_model, model_id)
283-
self.model_download_worker.finished.connect(
284-
self.on_model_download_finished
285-
)
286-
self.model_download_worker.finished.connect(
287-
self.model_download_thread.quit
288-
)
264+
self.model_download_worker.finished.connect(self.on_model_download_finished)
265+
self.model_download_worker.finished.connect(self.model_download_thread.quit)
289266
self.model_download_worker.moveToThread(self.model_download_thread)
290-
self.model_download_thread.started.connect(
291-
self.model_download_worker.run
292-
)
267+
self.model_download_thread.started.connect(self.model_download_worker.run)
293268
self.model_download_thread.start()
294269

295270
def _download_and_extract_model(self, model_config):
@@ -313,22 +288,16 @@ def _download_and_extract_model(self, model_config):
313288
# Download url
314289
ellipsis_download_url = download_url
315290
if len(download_url) > 40:
316-
ellipsis_download_url = (
317-
download_url[:20] + "..." + download_url[-20:]
318-
)
319-
logging.info(
320-
"Downloading %s to %s", ellipsis_download_url, zip_model_path
321-
)
291+
ellipsis_download_url = download_url[:20] + "..." + download_url[-20:]
292+
logging.info("Downloading %s to %s", ellipsis_download_url, zip_model_path)
322293
try:
323294
# Download and show progress
324295
def _progress(count, block_size, total_size):
325296
percent = int(count * block_size * 100 / total_size)
326297
self.new_model_status.emit(
327298
QCoreApplication.translate(
328299
"Model", "Downloading {download_url}: {percent}%"
329-
).format(
330-
download_url=ellipsis_download_url, percent=percent
331-
)
300+
).format(download_url=ellipsis_download_url, percent=percent)
332301
)
333302

334303
urllib.request.urlretrieve(
@@ -352,9 +321,7 @@ def _progress(count, block_size, total_size):
352321
model_folder = root
353322
break
354323
if model_folder is None:
355-
raise ValueError(
356-
self.tr("Could not find config.yaml in zip file.")
357-
)
324+
raise ValueError(self.tr("Could not find config.yaml in zip file."))
358325

359326
# Move model folder to correct location
360327
shutil.rmtree(extract_dir)
@@ -494,9 +461,9 @@ def predict_shapes(self, image, filename=None):
494461
self.prediction_finished.emit()
495462
return
496463
try:
497-
auto_labeling_result = self.loaded_model_config[
498-
"model"
499-
].predict_shapes(image, filename)
464+
auto_labeling_result = self.loaded_model_config["model"].predict_shapes(
465+
image, filename
466+
)
500467
self.new_auto_labeling_result.emit(auto_labeling_result)
501468
except Exception as e: # noqa
502469
print(f"Error in predict_shapes: {e}")
@@ -518,9 +485,7 @@ def predict_shapes_threading(self, image, filename=None):
518485
self.tr("Model is not loaded. Choose a mode to continue.")
519486
)
520487
return
521-
self.new_model_status.emit(
522-
self.tr("Inferencing AI model. Please wait...")
523-
)
488+
self.new_model_status.emit(self.tr("Inferencing AI model. Please wait..."))
524489
self.prediction_started.emit()
525490

526491
with self.model_execution_thread_lock:
@@ -530,8 +495,7 @@ def predict_shapes_threading(self, image, filename=None):
530495
):
531496
self.new_model_status.emit(
532497
self.tr(
533-
"Another model is being executed."
534-
" Please wait for it to finish."
498+
"Another model is being executed. Please wait for it to finish."
535499
)
536500
)
537501
self.prediction_finished.emit()
@@ -544,12 +508,8 @@ def predict_shapes_threading(self, image, filename=None):
544508
self.model_execution_worker.finished.connect(
545509
self.model_execution_thread.quit
546510
)
547-
self.model_execution_worker.moveToThread(
548-
self.model_execution_thread
549-
)
550-
self.model_execution_thread.started.connect(
551-
self.model_execution_worker.run
552-
)
511+
self.model_execution_worker.moveToThread(self.model_execution_thread)
512+
self.model_execution_thread.started.connect(self.model_execution_worker.run)
553513
self.model_execution_thread.start()
554514

555515
def on_next_files_changed(self, next_files):

0 commit comments

Comments
 (0)