Skip to content

Commit 03df97a

Browse files
feat: new PretrainedConfig reducing code repetition across model files
1 parent 813bc4d commit 03df97a

22 files changed

Lines changed: 440 additions & 458 deletions

src/Configs/PretrainedConfig.php

Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Configs;
6+
7+
use ArrayAccess;
8+
use Codewithkyrian\Transformers\Utils\Hub;
9+
use function Codewithkyrian\Transformers\Utils\array_pick;
10+
11+
/**
12+
* The base class that implements the common methods for loading a configuration either from a local file or directory,
13+
* or from a pretrained model configuration on the Hub.
14+
*
15+
* Common attributes present in all config classes are: hidden_size, num_attention_heads, and num_hidden_layers.
16+
* Text models further implement: vocab_size.
17+
*/
18+
class PretrainedConfig implements ArrayAccess
19+
{
20+
public ?string $modelType = null;
21+
22+
public bool $isEncoderDecoder;
23+
24+
public int $maxPositionEmbeddings;
25+
26+
public array $normalizedConfig;
27+
28+
private function __construct(public array $config)
29+
{
30+
$this->modelType = $config['model_type'] ?? null;
31+
$this->isEncoderDecoder = $config['is_encoder_decoder'] ?? false;
32+
$this->maxPositionEmbeddings = $config['max_position_embeddings'] ?? 0;
33+
34+
$this->normalizedConfig = $this->getNormalizedConfig($config);
35+
}
36+
37+
public static function fromPretrained(
38+
string $modelNameOrPath,
39+
?array $config = null,
40+
?string $cacheDir = null,
41+
string $revision = 'main',
42+
?callable $onProgress = null
43+
): self
44+
{
45+
$config ??= Hub::getJson(
46+
$modelNameOrPath,
47+
fileName: 'config.json',
48+
cacheDir: $cacheDir,
49+
revision: $revision,
50+
fatal: false,
51+
onProgress: $onProgress
52+
);
53+
54+
return new self($config);
55+
}
56+
57+
public function offsetExists(mixed $offset): bool
58+
{
59+
return isset($this->config[$offset]);
60+
}
61+
62+
public function offsetGet(mixed $offset): mixed
63+
{
64+
return $this->config[$offset];
65+
}
66+
67+
public function offsetSet(mixed $offset, mixed $value): void
68+
{
69+
$this->config[$offset] = $value;
70+
}
71+
72+
public function offsetUnset(mixed $offset): void
73+
{
74+
unset($this->config[$offset]);
75+
}
76+
77+
protected function getNormalizedConfig(array $config): array
78+
{
79+
$mapping = [];
80+
$normalizedConfig = [];
81+
82+
switch ($config['model_type']) {
83+
// Sub-configs
84+
case 'llava':
85+
case 'paligemma':
86+
case 'florence2':
87+
$normalizedConfig = $this->getNormalizedConfig($config['text_config']);
88+
break;
89+
case 'moondream1':
90+
$normalizedConfig = $this->getNormalizedConfig($config['phi_config']);
91+
break;
92+
case 'musicgen':
93+
$normalizedConfig = $this->getNormalizedConfig($config['decoder']);
94+
break;
95+
96+
// Decoder-only models
97+
case 'gpt2':
98+
case 'gptj':
99+
case 'jais':
100+
case 'codegen':
101+
case 'gpt_bigcode':
102+
$mapping = [
103+
'num_heads' => 'n_head',
104+
'num_layers' => 'n_layer',
105+
'hidden_size' => 'n_embd',
106+
];
107+
break;
108+
case 'gpt_neox':
109+
case 'stablelm':
110+
case 'opt':
111+
case 'phi':
112+
case 'phi3':
113+
case 'falcon':
114+
$mapping = [
115+
'num_heads' => 'num_attention_heads',
116+
'num_layers' => 'num_hidden_layers',
117+
'hidden_size' => 'hidden_size',
118+
];
119+
break;
120+
case 'llama':
121+
case 'olmo':
122+
case 'mobilellm':
123+
case 'granite':
124+
case 'cohere':
125+
case 'mistral':
126+
case 'starcoder2':
127+
case 'qwen2':
128+
$mapping = [
129+
'num_heads' => 'num_key_value_heads',
130+
'num_layers' => 'num_hidden_layers',
131+
'hidden_size' => 'hidden_size',
132+
'num_attention_heads' => 'num_attention_heads',
133+
];
134+
break;
135+
case 'gemma':
136+
case 'gemma2':
137+
$mapping = [
138+
'num_heads' => 'num_key_value_heads',
139+
'num_layers' => 'num_hidden_layers',
140+
'dim_kv' => 'head_dim',
141+
];
142+
break;
143+
case 'openelm':
144+
$mapping = [
145+
'num_heads' => 'num_kv_heads',
146+
'num_layers' => 'num_transformer_layers',
147+
'dim_kv' => 'head_dim',
148+
];
149+
break;
150+
case 'gpt_neo':
151+
case 'donut-swin':
152+
$mapping = [
153+
'num_heads' => 'num_heads',
154+
'num_layers' => 'num_layers',
155+
'hidden_size' => 'hidden_size',
156+
];
157+
break;
158+
case 'bloom':
159+
$mapping = [
160+
'num_heads' => 'n_head',
161+
'num_layers' => 'n_layer',
162+
'hidden_size' => 'hidden_size',
163+
];
164+
break;
165+
case 'mpt':
166+
$mapping = [
167+
'num_heads' => 'n_heads',
168+
'num_layers' => 'n_layers',
169+
'hidden_size' => 'd_model',
170+
];
171+
break;
172+
173+
// Encoder-decoder models
174+
case 't5':
175+
case 'mt5':
176+
case 'longt5':
177+
$mapping = [
178+
'num_decoder_layers' => 'num_decoder_layers',
179+
'num_decoder_heads' => 'num_heads',
180+
'decoder_dim_kv' => 'd_kv',
181+
'num_encoder_layers' => 'num_layers',
182+
'num_encoder_heads' => 'num_heads',
183+
'encoder_dim_kv' => 'd_kv',
184+
];
185+
break;
186+
case 'bart':
187+
case 'mbart':
188+
case 'marian':
189+
case 'whisper':
190+
case 'm2m_100':
191+
case 'blenderbot':
192+
case 'blenderbot-small':
193+
case 'florence2_language':
194+
$mapping = [
195+
'num_decoder_layers' => 'decoder_layers',
196+
'num_decoder_heads' => 'decoder_attention_heads',
197+
'decoder_hidden_size' => 'd_model',
198+
'num_encoder_layers' => 'encoder_layers',
199+
'num_encoder_heads' => 'encoder_attention_heads',
200+
'encoder_hidden_size' => 'd_model',
201+
];
202+
break;
203+
case 'speecht5':
204+
$mapping = [
205+
'num_decoder_layers' => 'decoder_layers',
206+
'num_decoder_heads' => 'decoder_attention_heads',
207+
'decoder_hidden_size' => 'hidden_size',
208+
'num_encoder_layers' => 'encoder_layers',
209+
'num_encoder_heads' => 'encoder_attention_heads',
210+
'encoder_hidden_size' => 'hidden_size',
211+
];
212+
break;
213+
case 'trocr':
214+
$mapping = [
215+
'num_encoder_layers' => 'decoder_layers',
216+
'num_decoder_heads' => 'decoder_attention_heads',
217+
'encoder_hidden_size' => 'd_model',
218+
];
219+
break;
220+
case 'musicgen_decoder':
221+
$mapping = [
222+
'num_encoder_layers' => 'num_hidden_layers',
223+
'num_encoder_heads' => 'num_attention_heads',
224+
'encoder_hidden_size' => 'hidden_size',
225+
];
226+
break;
227+
228+
case 'vision-encoder-decoder':
229+
$decoderConfig = $this->getNormalizedConfig($config['decoder']);
230+
$addEncoderPkv = array_key_exists('num_decoder_layers', $decoderConfig);
231+
$result = array_pick($config, ['model_type', 'is_encoder_decoder']);
232+
233+
if ($addEncoderPkv) {
234+
$result = array_merge($result, [
235+
'num_decoder_layers' => $decoderConfig['num_decoder_layers'],
236+
'num_decoder_heads' => $decoderConfig['num_decoder_heads'],
237+
'decoder_hidden_size' => $decoderConfig['decoder_hidden_size'],
238+
'num_encoder_layers' => $decoderConfig['num_encoder_layers'],
239+
'num_encoder_heads' => $decoderConfig['num_encoder_heads'],
240+
'encoder_hidden_size' => $decoderConfig['encoder_hidden_size'],
241+
]);
242+
} else {
243+
$result = array_merge($result, [
244+
'num_layers' => $decoderConfig['num_layers'],
245+
'num_heads' => $decoderConfig['num_heads'],
246+
'hidden_size' => $decoderConfig['hidden_size'],
247+
]);
248+
}
249+
return $result;
250+
}
251+
252+
// If `num_attention_heads` is not set, assume it's equal to `num_heads`
253+
$normalizedConfig = array_merge(
254+
$normalizedConfig,
255+
array_pick($config, ['model_type', 'multi_query', 'is_encoder_decoder'])
256+
);
257+
258+
foreach ($mapping as $key => $value) {
259+
$normalizedConfig[$key] = $config[$value];
260+
}
261+
262+
return $normalizedConfig;
263+
}
264+
265+
266+
public function getKeyValueShapes(string $prefix = 'past_key_values'): array
267+
{
268+
$decoderFeeds = [];
269+
270+
// TODO: Support batches (i.e., batchSize > 1)
271+
$batchSize = 1;
272+
273+
if (
274+
($this->normalizedConfig['is_encoder_decoder'] ?? false) &&
275+
isset($this->normalizedConfig['num_encoder_heads'], $this->normalizedConfig['num_decoder_heads'])
276+
) {
277+
$encoderDimKv = $this->normalizedConfig['encoder_dim_kv'] ?? (
278+
$this->normalizedConfig['encoder_hidden_size'] / $this->normalizedConfig['num_encoder_heads']
279+
);
280+
$decoderDimKv = $this->normalizedConfig['decoder_dim_kv'] ?? (
281+
$this->normalizedConfig['decoder_hidden_size'] / $this->normalizedConfig['num_decoder_heads']
282+
);
283+
284+
$encoderDims = [$batchSize, $this->normalizedConfig['num_encoder_heads'], 0, $encoderDimKv];
285+
$decoderDims = [$batchSize, $this->normalizedConfig['num_decoder_heads'], 0, $decoderDimKv];
286+
for ($i = 0; $i < $this->normalizedConfig['num_decoder_layers']; ++$i) {
287+
$decoderFeeds["{$prefix}.{$i}.encoder.key"] = $encoderDims;
288+
$decoderFeeds["{$prefix}.{$i}.encoder.value"] = $encoderDims;
289+
$decoderFeeds["{$prefix}.{$i}.decoder.key"] = $decoderDims;
290+
$decoderFeeds["{$prefix}.{$i}.decoder.value"] = $decoderDims;
291+
}
292+
} else { // Decoders
293+
$numHeads = $this->normalizedConfig['num_heads'];
294+
$numLayers = $this->normalizedConfig['num_layers'];
295+
$dimKv = $this->normalizedConfig['dim_kv'] ?? (
296+
$this->normalizedConfig['hidden_size'] /
297+
($this->normalizedConfig['num_attention_heads'] ?? $numHeads)
298+
);
299+
300+
if ($this->normalizedConfig['model_type'] === 'falcon') {
301+
$shape = [$batchSize * $numHeads, 0, $dimKv];
302+
for ($i = 0; $i < $numLayers; ++$i) {
303+
$decoderFeeds["{$prefix}.{$i}.key"] = $shape;
304+
$decoderFeeds["{$prefix}.{$i}.value"] = $shape;
305+
}
306+
} elseif ($this->config['multi_query'] ?? null) { // e.g., for `gpt_bigcode`
307+
$shape = [$batchSize * $numHeads, 0, 2 * $dimKv];
308+
for ($i = 0; $i < $numLayers; ++$i) {
309+
$decoderFeeds["{$prefix}.{$i}.key_value"] = $shape;
310+
}
311+
} elseif ($this->normalizedConfig['model_type'] === 'bloom') {
312+
$keyDims = [$batchSize * $numHeads, $dimKv, 0];
313+
$valueDims = [$batchSize * $numHeads, 0, $dimKv];
314+
for ($i = 0; $i < $numLayers; ++$i) {
315+
$decoderFeeds["{$prefix}.{$i}.key"] = $keyDims;
316+
$decoderFeeds["{$prefix}.{$i}.value"] = $valueDims;
317+
}
318+
} elseif ($this->normalizedConfig['model_type'] === 'openelm') {
319+
for ($i = 0; $i < $numLayers; ++$i) {
320+
$shape = [$batchSize, $numHeads[$i], 0, $dimKv];
321+
$decoderFeeds["{$prefix}.{$i}.key"] = $shape;
322+
$decoderFeeds["{$prefix}.{$i}.value"] = $shape;
323+
}
324+
} else { // Decoder-only
325+
$shape = [$batchSize, $numHeads, 0, $dimKv];
326+
for ($i = 0; $i < $numLayers; ++$i) {
327+
$decoderFeeds["{$prefix}.{$i}.key"] = $shape;
328+
$decoderFeeds["{$prefix}.{$i}.value"] = $shape;
329+
}
330+
}
331+
}
332+
333+
return $decoderFeeds;
334+
}
335+
}

src/Models/ModelArchitecture.php

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ protected function decoderForward(PretrainedModel $model, array $modelInputs, $i
132132
$pastKeyValues = array_pop_key($modelInputs, 'past_key_values');
133133

134134
if (in_array('use_cache_branch', $inputNames)) {
135-
$modelInputs['use_cache_branch'] = new Tensor([!empty($pastKeyValues)], Tensor::bool);
135+
$modelInputs['use_cache_branch'] = new Tensor([!empty($pastKeyValues)], Tensor::bool, [1]);
136136
}
137137

138138
if (
@@ -157,10 +157,11 @@ protected function seq2seqForward(PretrainedModel $model, array $modelInputs): a
157157
$inputIds = array_pop_key($decoderFeeds, 'input_ids');
158158
$decoderInputIds = array_pop_key($decoderFeeds, 'decoder_input_ids');
159159

160-
$inputNames = array_column($model->decoderMergedSession->inputs(), 'name');
160+
161161

162162
// Encode if needed
163163
if (!$encoderOutputs) {
164+
$inputNames = array_column($model->session->inputs(), 'name');
164165
// Pick necessary encoder inputs
165166
$encoderInputs = array_pick($modelInputs, $inputNames);
166167
// Encoder outputs are not given, so we must compute them
@@ -171,6 +172,8 @@ protected function seq2seqForward(PretrainedModel $model, array $modelInputs): a
171172
$decoderFeeds['input_ids'] = $decoderInputIds;
172173
$decoderFeeds['encoder_hidden_states'] = $encoderOutputs;
173174

175+
$inputNames = array_column($model->decoderMergedSession->inputs(), 'name');
176+
174177
if (in_array('encoder_attention_mask', $inputNames)) {
175178
$decoderFeeds['encoder_attention_mask'] = $modelInputs['attention_mask'];
176179
}

0 commit comments

Comments
 (0)