Skip to content

Commit f8fc0a0

Browse files
zhihaoshan-googleZhihao Shan
andauthored
refactor slice_to_num_chips to adapt to Cloud config (#65)
Co-authored-by: Zhihao Shan <zhihaoshan@google.com>
1 parent f6751d2 commit f8fc0a0

2 files changed

Lines changed: 41 additions & 6 deletions

File tree

jetstream/core/config_lib.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import dataclasses
1818
import functools
19-
import math
2019
from typing import Any, Callable, List, Tuple, Type
2120

2221
from jetstream.engine import engine_api
@@ -85,11 +84,12 @@ class InterleavedCPUTestServer(ServerConfig):
8584

8685

8786
def slice_to_num_chips(s: str) -> int:
88-
"""Converts a TPU spec like v5e=4x2 to the number of chips, 8."""
89-
# Account for the case where it is written 'v5e:4x2'.
90-
delim = "=" if "=" in s else ":"
91-
i = math.prod([int(c) for c in s.split(delim)[1].split("x")])
92-
return i
87+
"""Converts a TPU spec like v5e-8 or v5e=8 to the number of chips, 8."""
88+
# Account for the case where it is written 'tpu=8' for compatibility.
89+
delim = "-" if "-" in s else "="
90+
# TODO: Support more accelerator type check.
91+
accelerator_type, num_devices = s.split(delim)
92+
return int(num_devices) if accelerator_type != "v4" else int(num_devices) // 2
9393

9494

9595
def _split_devices_by_slices(
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit test for config_lib.py."""
16+
17+
from absl.testing import absltest, parameterized
18+
from jetstream.core import config_lib
19+
20+
21+
class TestConfigLib(parameterized.TestCase):
22+
23+
@parameterized.parameters(
24+
("tpu=8", 8),
25+
("v5e-8", 8),
26+
("v5e=4", 4),
27+
("v4-8", 4),
28+
)
29+
def test_slice_to_num_chips(self, accelerator_slice, expected_num_devices):
30+
got = config_lib.slice_to_num_chips(accelerator_slice)
31+
self.assertEqual(got, expected_num_devices)
32+
33+
34+
if __name__ == "__main__":
35+
absltest.main()

0 commit comments

Comments
 (0)