|
1 | 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates |
2 | 2 | # SPDX-License-Identifier: MIT |
3 | 3 |
|
4 | | -from typing import List, Dict, Union, Any, Optional |
5 | | -from langchain_core.outputs import LLMResult, Generation, ChatGeneration |
6 | | - |
# Optional dependency: tiktoken is used only for token accounting.
# The broad except is deliberate best-effort — get_encoding() can fail for
# reasons other than ImportError (e.g. the encoding file cannot be fetched),
# and in every failure mode we degrade gracefully by marking tiktoken absent.
try:
    import tiktoken
    # Pre-load the cl100k_base encoding once at import time; it is the
    # fallback used when a model name is not recognized.
    _cl100k_base_encoding = tiktoken.get_encoding('cl100k_base')
except Exception:
    tiktoken = None  # type: ignore[assignment]
    _cl100k_base_encoding = None
13 | | - |
14 | | - |
def calc_token_usage(inputs: Union[List[Dict], LLMResult], model: str = 'gpt-3.5-turbo-0613') -> int:
    """Estimate the number of tokens used by chat messages or an LLMResult.

    Follows the OpenAI cookbook accounting scheme: a fixed per-message
    overhead, plus the encoded length of each message field, plus three
    tokens priming the assistant reply.

    Args:
        inputs: Either a list of chat-message dicts (role/content/name keys)
            or a langchain ``LLMResult`` whose generations are counted.
        model: Model name used to pick the tiktoken encoding.

    Returns:
        The estimated token count, or 0 when tiktoken is unavailable.
    """
    if tiktoken is None:
        # tiktoken failed to import/initialize at module load; degrade to 0
        # rather than crash callers that only want best-effort accounting.
        return 0
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to the encoding shared by the
        # gpt-3.5 / gpt-4 family, preloaded at module import.
        print('Warning: model not found. Using cl100k_base encoding.')
        encoding = _cl100k_base_encoding
    if model == 'gpt-3.5-turbo-0301':
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    else:
        # All other models (gpt-3.5/gpt-4 family and anything unknown) share
        # the same overhead; the original had two byte-identical branches here.
        tokens_per_message = 3
        tokens_per_name = 1
    num_tokens = 0
    # Builtin `list`, not typing.List: isinstance with typing aliases is
    # deprecated; runtime behavior is identical.
    if isinstance(inputs, list):
        for message in inputs:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # NOTE(review): assumes every field value is a str —
                # non-string payloads would raise in encode(); confirm
                # against callers.
                num_tokens += len(encoding.encode(value))
                if key == 'name':
                    num_tokens += tokens_per_name
    elif isinstance(inputs, LLMResult):
        for inner_generations in inputs.generations:
            for generation in inner_generations:
                if isinstance(generation, ChatGeneration):
                    num_tokens += len(encoding.encode(generation.message.type))
                    num_tokens += len(encoding.encode(generation.message.content))
                elif isinstance(generation, Generation):
                    num_tokens += len(encoding.encode('ai')) + len(encoding.encode(generation.text))
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
50 | | - |
51 | | - |
| 4 | +from typing import List |
52 | 5 |
|
# Prefix marker for fornax prompt strings; the name suggests it is used in
# str.startswith() checks elsewhere — confirm against callers (not visible
# in this chunk).
_startswith = 'fornax_prompt_tag'
54 | 7 |
|
|
0 commit comments