-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathbase_operator.py
More file actions
176 lines (143 loc) · 5.7 KB
/
base_operator.py
File metadata and controls
176 lines (143 loc) · 5.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from __future__ import annotations
import inspect
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Iterable, Tuple, Union
if TYPE_CHECKING:
import numpy as np
import pandas as pd
def convert_to_serializable(obj):
    """Recursively replace NumPy values with plain Python equivalents.

    Conversion rules:
      * ``np.ndarray`` -> nested list (``ndarray.tolist()``)
      * ``np.generic`` (NumPy scalar) -> Python scalar (``.item()``)
      * ``dict`` -> new dict with converted values (keys are left untouched)
      * ``list`` -> new list with converted items
      * ``tuple`` -> new tuple with converted items (namedtuple type is kept)

    Any other object is returned unchanged.

    :param obj: arbitrary, possibly nested, value
    :return: a structure of the same shape without NumPy arrays/scalars in
        the containers listed above
    """
    # Function-scope import: at module level numpy is only imported under
    # TYPE_CHECKING, so it is not available at runtime otherwise.
    import numpy as np

    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_to_serializable(v) for v in obj]
    if isinstance(obj, tuple):
        items = [convert_to_serializable(v) for v in obj]
        # namedtuple constructors take their fields positionally, whereas a
        # plain tuple is built from a single iterable.
        return type(obj)(*items) if hasattr(obj, "_fields") else tuple(items)
    return obj
class BaseOperator(ABC):
    """Abstract base class for batch operators with KV-backed result recovery.

    Calling an instance on a batch (a DataFrame with a ``_trace_id`` column)
    does the following:

    1. rows whose ``_trace_id`` is already recorded in the ``_meta_forward``
       entry of the KV storage are not processed again; their stored results
       are loaded and yielded instead;
    2. the remaining rows are handed to :meth:`process`;
    3. every produced result is yielded and persisted via :meth:`store`.

    Subclasses only need to implement :meth:`process`.
    """

    def __init__(
        self,
        working_dir: str = "cache",
        kv_backend: str = "rocksdb",
        op_name: Union[str, None] = None,
    ):
        """
        Set up the operator's KV storage and its per-worker logger.

        :param working_dir: base directory; log files go to ``<working_dir>/logs``
        :param kv_backend: backend name forwarded to ``init_storage``
        :param op_name: storage namespace and logger prefix; defaults to the
            concrete class name
        """
        # lazy import to avoid circular import
        from graphgen.common.init_storage import init_storage
        from graphgen.utils import set_logger

        log_dir = os.path.join(working_dir, "logs")
        self.op_name = op_name or self.__class__.__name__
        self.working_dir = working_dir
        self.kv_backend = kv_backend
        self.kv_storage = init_storage(
            backend=kv_backend, working_dir=working_dir, namespace=self.op_name
        )

        # Best effort: derive a short worker tag from the Ray runtime context.
        # Any failure (Ray missing, not initialised, ...) falls back to "local".
        try:
            import ray

            ctx = ray.get_runtime_context()
            worker_id = ctx.get_actor_id() or ctx.get_worker_id()
            worker_id_short = worker_id[-6:] if worker_id else "driver"
        except Exception as e:
            print(
                "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:",
                e,
            )
            worker_id_short = "local"

        # e.g. cache/logs/ChunkService_a1b2c3.log
        log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log")
        self.logger = set_logger(
            log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True
        )
        self.logger.info(
            "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short
        )

    def __call__(
        self, batch: "pd.DataFrame"
    ) -> Union["pd.DataFrame", Iterable["pd.DataFrame"]]:
        """
        Process one batch, yielding result DataFrames.

        This is a generator function: calling it returns an iterator, and no
        work happens until that iterator is consumed. Note that each result is
        stored only after it has been yielded, i.e. when the consumer asks for
        the next item (or exhausts the iterator).

        :param batch: input rows; must contain a ``_trace_id`` column
        :return: iterator over DataFrames (recovered results first, then new ones)
        """
        # lazy import to avoid circular import
        import pandas as pd

        from graphgen.utils import CURRENT_LOGGER_VAR

        # Publish this operator's logger for the duration of the call.
        logger_token = CURRENT_LOGGER_VAR.set(self.logger)
        try:
            self.kv_storage.reload()
            to_process, recovered = self.split(batch)

            # yield recovered results first
            if not recovered.empty:
                yield recovered

            if to_process.empty:
                return

            data = to_process.to_dict(orient="records")
            result, meta_update = self.process(data)

            if inspect.isgenerator(result):
                # Streaming mode: emit and store results one at a time. The
                # whole meta_update is written together with the first result;
                # the storage is flushed once after the stream is exhausted.
                is_first = True
                for res in result:
                    yield pd.DataFrame([res])
                    self.store(
                        [res], meta_update if is_first else {}, flush=False
                    )
                    is_first = False
                self.kv_storage.index_done_callback()
            else:
                yield pd.DataFrame(result)
                self.store(result, meta_update)
        finally:
            CURRENT_LOGGER_VAR.reset(logger_token)

    def get_logger(self):
        """Return the logger created for this operator instance."""
        return self.logger

    def get_meta_forward(self) -> dict:
        """Return the stored ``_meta_forward`` entry, or ``{}`` if it is absent/empty."""
        return self.kv_storage.get_by_id("_meta_forward") or {}

    def get_meta_inverse(self) -> dict:
        """Return the stored ``_meta_inverse`` entry, or ``{}`` if it is absent/empty."""
        return self.kv_storage.get_by_id("_meta_inverse") or {}

    def get_trace_id(self, content: dict) -> str:
        """
        Compute a trace id for ``content`` via ``compute_dict_hash``, prefixed
        with ``"<op_name>-"``.
        """
        from graphgen.utils import compute_dict_hash

        return compute_dict_hash(content, prefix=f"{self.op_name}-")

    def split(self, batch: "pd.DataFrame") -> tuple["pd.DataFrame", "pd.DataFrame"]:
        """
        Split the input batch into to_process & recovered based on the
        ``_meta_forward`` data in the KV storage.

        :param batch: input rows with a ``_trace_id`` column; an empty batch
            (even one without any columns) is accepted
        :return:
            to_process: rows of ``batch`` whose ``_trace_id`` is not in ``_meta_forward``
            recovered: DataFrame of the stored results belonging to the
                already-processed rows (empty if there are none)
        """
        import pandas as pd

        # An empty batch may carry no columns at all, in which case the
        # "_trace_id" lookup below would raise KeyError. Nothing to do anyway.
        if batch.empty:
            return batch, pd.DataFrame()

        meta_forward = self.get_meta_forward()
        meta_ids = set(meta_forward.keys())
        mask = batch["_trace_id"].isin(meta_ids)
        to_process = batch[~mask]
        processed = batch[mask]

        if processed.empty:
            return to_process, pd.DataFrame()

        # _meta_forward maps an input trace id to the ids of its results.
        all_ids = [
            pid for tid in processed["_trace_id"] for pid in meta_forward.get(tid, [])
        ]

        recovered_chunks = self.kv_storage.get_by_ids(all_ids)
        # Ids that are no longer present in the storage come back as None.
        recovered_chunks = [c for c in recovered_chunks if c is not None]
        return to_process, pd.DataFrame(recovered_chunks)

    def store(self, results: list, meta_update: dict, flush: bool = True):
        """
        Persist results and the associated trace metadata.

        :param results: list of result dicts, each with a ``_trace_id`` key
        :param meta_update: mapping of input trace id -> list of result trace
            ids; merged into ``_meta_forward`` and, reversed, into ``_meta_inverse``
        :param flush: when True, call ``index_done_callback`` on the storage afterwards
        """
        results = convert_to_serializable(results)
        meta_update = convert_to_serializable(meta_update)

        self.kv_storage.upsert({res["_trace_id"]: res for res in results})

        # Skip the read-modify-write of both metadata maps when there is
        # nothing to merge (every streamed result after the first one).
        if meta_update:
            # update forward meta
            forward_meta = self.get_meta_forward()
            forward_meta.update(meta_update)
            self.kv_storage.update({"_meta_forward": forward_meta})

            # update inverse meta
            inverse_meta = self.get_meta_inverse()
            for k, v_list in meta_update.items():
                for v in v_list:
                    inverse_meta[v] = k
            self.kv_storage.update({"_meta_inverse": inverse_meta})

        if flush:
            self.kv_storage.index_done_callback()

    @abstractmethod
    def process(self, batch: list) -> Tuple[Union[list, Iterable[dict]], dict]:
        """
        Process the input batch and return the result.

        :param batch: list of row dicts (``DataFrame.to_dict(orient="records")``)
        :return:
            result: list of result dicts, or a generator yielding result
                dicts one by one; every result dict needs a ``_trace_id`` key
            meta_update: dict of meta data to be updated (input trace id ->
                list of result trace ids)
        """