|
| 1 | +import torch |
| 2 | +from torch import nn |
| 3 | + |
1 | 4 | from labml import logger, lab, monit |
2 | 5 | from labml.logger import Text, Style |
3 | | -from python_autocomplete.evaluate import Predictor |
4 | | -from python_autocomplete.evaluate.factory import get_predictor |
| 6 | +from labml_helpers.module import Module |
| 7 | +from python_autocomplete.dataset import Tokenizer |
| 8 | +from python_autocomplete.evaluate.factory import load_experiment |
| 9 | +from python_autocomplete.train import StateUpdater |
| 10 | + |
5 | 11 |
|
def anomalies(tokenizer: Tokenizer, text: str, model: Module, state_updater: StateUpdater, is_token_by_token: bool):
    """Log *text* line by line, color-coding each character by how confidently
    the model predicted the token it belongs to.

    High-probability tokens render subtly/success-colored; low-probability
    (anomalous) tokens render in increasingly strong warning/danger styles.

    :param tokenizer: maps text to token ids (``encode``) and ids back to
        strings (``itos``)
    :param text: source text to scan for anomalies
    :param model: autoregressive language model, called as ``model(prompt, state)``
    :param state_updater: merges the model's newly returned state into the
        running state between steps
    :param is_token_by_token: if ``True``, feed only the latest token each step
        (relying on the carried state); otherwise feed up to the last 512 tokens
    """
    tokens = tokenizer.encode(text)

    line_no = 1
    # The first token has no preceding context to predict it from, so it is
    # printed bold with no probability coloring.
    logs = [(f"{line_no: 4d}: ", Text.meta), (tokenizer.itos[tokens[0]], Style.bold)]

    # NOTE: rebinds `text` from str to a 1-D tensor of token ids.
    text = torch.tensor(tokens, dtype=torch.long, device=model.device)
    prompt = text[:1].unsqueeze(-1)

    state = None
    softmax = nn.Softmax(-1)

    i = 1

    # FIX: was `while i + 1 < len(text)`, which skipped the final token.
    # Each iteration scores token `i` using only tokens `< i` as context
    # (the prompt built in the previous iteration), so `i` may run through
    # the last index.
    while i < len(text):
        with torch.no_grad():
            prediction, new_state = model(prompt, state)

        state = state_updater(state, new_state)
        # Probability distribution over the vocabulary for the next token,
        # taken from the last time step of the (single) batch entry.
        prediction = softmax(prediction[-1, 0])

        # Build the prompt for the NEXT iteration (which will score token i+1).
        if is_token_by_token:
            prompt = text[i: i + 1].unsqueeze(-1)
        else:
            prompt = text[:i + 1]
            prompt = prompt[-512:].unsqueeze(-1)

        token_str = tokenizer.itos[text[i]]
        prob = prediction[text[i]].item()

        # Every character of the token shares the token's probability color.
        for c in token_str:
            if c == '\n':
                logger.log(logs)
                line_no += 1
                logs = [(f"{line_no: 4d}: ", Text.meta)]
            elif c == '\r':
                continue  # carriage returns are dropped from the output
            else:
                # Thresholds: likely → subtle/success, unlikely (anomalous)
                # → warning/danger, with bold marking the extremes.
                if prob > 0.9:
                    logs.append((c, [Text.subtle, Style.underline]))
                elif prob > 0.75:
                    logs.append((c, [Text.success, Style.underline]))
                elif prob > 0.2:
                    logs.append(c)
                elif prob > 0.1:
                    logs.append((c, [Text.warning, Style.underline]))
                elif prob > 0.01:
                    logs.append((c, [Style.bold, Text.warning, Style.underline]))
                elif prob > 0.001:
                    logs.append((c, [Text.danger, Style.underline]))
                else:
                    logs.append((c, [Style.bold, Text.danger, Style.underline]))

        i += 1

    # Flush whatever accumulated on the final (unterminated) line.
    logger.log(logs)
48 | 68 |
|
49 | 69 |
|
def main():
    """Load the trained experiment and highlight anomalies in the sample file."""
    conf = load_experiment()

    sample_path = lab.get_data_path() / 'sample.py'
    with open(str(sample_path), 'r') as f:
        sample = f.read()

    with monit.section('Anomalies'):
        anomalies(conf.text.tokenizer, sample, conf.model, conf.state_updater, conf.is_token_by_token)
57 | 77 |
|
58 | 78 |
|
59 | 79 | if __name__ == '__main__': |
|
0 commit comments