Skip to content

Commit a77b0a7

Browse files
feat: add allowed_suffix to limit input files
1 parent e22d4e5 commit a77b0a7

2 files changed

Lines changed: 32 additions & 9 deletions

File tree

graphgen/graphgen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async def read(self, read_config: Dict):
9696
"""
9797
read files from input sources
9898
"""
99-
data = read_files(read_config["input_file"], self.working_dir)
99+
data = read_files(**read_config, cache_dir=self.working_dir)
100100
if len(data) == 0:
101101
logger.warning("No data to process")
102102
return

graphgen/operators/read/read_files.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Dict, List
2+
from typing import Any, Dict, List, Optional
33

44
from graphgen.models import (
55
CSVReader,
@@ -34,26 +34,49 @@ def _build_reader(suffix: str, cache_dir: str | None):
3434
return _MAPPING[suffix]()
3535

3636

37-
def read_files(file_path: str, cache_dir: str | None = None) -> list[dict]:
38-
path = Path(file_path).expanduser()
37+
def read_files(
38+
input_file: str,
39+
allowed_suffix: Optional[List[str]] = None,
40+
cache_dir: Optional[str] = None,
41+
) -> list[dict]:
42+
path = Path(input_file).expanduser()
3943
if not path.exists():
40-
raise FileNotFoundError(f"input_path not found: {file_path}")
44+
raise FileNotFoundError(f"input_path not found: {input_file}")
4145

46+
if allowed_suffix is None:
47+
support_suffix = set(_MAPPING.keys())
48+
else:
49+
support_suffix = {s.lower().lstrip(".") for s in allowed_suffix}
50+
51+
# single file
4252
if path.is_file():
43-
suffix = path.suffix.lstrip(".")
53+
suffix = path.suffix.lstrip(".").lower()
54+
if suffix not in support_suffix:
55+
logger.warning(
56+
"Skip file %s (suffix '%s' not in allowed_suffix %s)",
57+
path,
58+
suffix,
59+
support_suffix,
60+
)
61+
return []
4462
reader = _build_reader(suffix, cache_dir)
4563
return reader.read(str(path))
4664

47-
support_suffix = set(_MAPPING.keys())
65+
# folder
4866
files_to_read = [
4967
p for p in path.rglob("*") if p.suffix.lstrip(".").lower() in support_suffix
5068
]
51-
logger.info("Found %d file(s) under folder %s", len(files_to_read), file_path)
69+
logger.info(
70+
"Found %d eligible file(s) under folder %s (allowed_suffix=%s)",
71+
len(files_to_read),
72+
input_file,
73+
support_suffix,
74+
)
5275

5376
all_docs: List[Dict[str, Any]] = []
5477
for p in files_to_read:
5578
try:
56-
suffix = p.suffix.lstrip(".")
79+
suffix = p.suffix.lstrip(".").lower()
5780
reader = _build_reader(suffix, cache_dir)
5881
all_docs.extend(reader.read(str(p)))
5982
except Exception as e: # pylint: disable=broad-except

0 commit comments

Comments
 (0)