Skip to content
This repository was archived by the owner on Dec 22, 2021. It is now read-only.

Commit 403859f

Browse files
ngzhianrossberg
andauthored
Implement JS output for SIMD (#331)
* Implement JS output for SIMD V128 is not exposed to JS at all, so transform all functions with V128 in the signatures into consts, compare with the expected value, and reduce it to a int32. An assertion against a SimdResult needs to be converted into a plain Const containing a v128. This conversion is tricky since SimdResult can contain both LitPat and NanPat. NaNs need special treatment to mask and compare to a canonical value. For simplicity, we build a mask for all the patterns in a SimdResult (even for literals, which will have a mask with all bits set). That way the test is consistent: - v128.const(mask) - v128.and - v128.const(expected) - i8x16.eq - i8x16.all_true - br_if 0 to unreachable * Formatting fixes Co-authored-by: Andreas Rossberg <rossberg@mpi-sws.org> * Remove redundant prefixes Co-authored-by: Andreas Rossberg <rossberg@mpi-sws.org>
1 parent e24bf43 commit 403859f

3 files changed

Lines changed: 68 additions & 20 deletions

File tree

interpreter/exec/simd.ml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ sig
142142
val to_i16x8 : t -> I16.t list
143143
val to_i32x4 : t -> I32.t list
144144

145+
val of_i8x16 : I32.t list -> t
146+
val of_i16x8 : I32.t list -> t
147+
val of_i32x4 : I32.t list -> t
148+
val of_i64x2 : I64.t list -> t
149+
145150
(* We need type t = t to ensure that all submodule types are S.t,
146151
* then callers don't have to change *)
147152
module I8x16 : Int with type t = t and type lane = I8.t
@@ -195,6 +200,11 @@ struct
195200
let to_i16x8 = Rep.to_i16x8
196201
let to_i32x4 = Rep.to_i32x4
197202

203+
let of_i8x16 = Rep.of_i8x16
204+
let of_i16x8 = Rep.of_i16x8
205+
let of_i32x4 = Rep.of_i32x4
206+
let of_i64x2 = Rep.of_i64x2
207+
198208
module V128 : Vec with type t = Rep.t = struct
199209
type t = Rep.t
200210
let to_shape = Rep.to_i64x2

interpreter/script/js.ml

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ let run ts at =
231231

232232
let assert_return ress ts at =
233233
let test res =
234+
let nan_bitmask_of = function
235+
| CanonicalNan -> abs_mask_of (* must only differ from the canonical NaN in its sign bit *)
236+
| ArithmeticNan -> canonical_nan_of (* can be any NaN that's one everywhere the canonical NaN is one *)
237+
in
234238
match res.it with
235239
| NumResult { it = LitPat lit; _ } ->
236240
let t', reinterpret = reinterpret_of (Values.type_of lit.it) in
@@ -246,21 +250,53 @@ let assert_return ress ts at =
246250
| Values.I32 _ | Values.I64 _ | Values.V128 _ -> assert false
247251
| Values.F32 n | Values.F64 n -> n
248252
in
249-
let nan_bitmask_of =
250-
match nan with
251-
| CanonicalNan -> abs_mask_of (* must only differ from the canonical NaN in its sign bit *)
252-
| ArithmeticNan -> canonical_nan_of (* can be any NaN that's one everywhere the canonical NaN is one *)
253-
in
254253
let t = Values.type_of nanop.it in
255254
let t', reinterpret = reinterpret_of t in
256255
[ reinterpret @@ at;
257-
Const (nan_bitmask_of t' @@ at) @@ at;
256+
Const (nan_bitmask_of nan t' @@ at) @@ at;
258257
Binary (and_of t') @@ at;
259258
Const (canonical_nan_of t' @@ at) @@ at;
260259
Compare (eq_of t') @@ at;
261260
Test (Values.I32 I32Op.Eqz) @@ at;
262261
BrIf (0l @@ at) @@ at ]
263-
| SimdResult _ -> failwith "unimplemented"
262+
| SimdResult (shape, pats) ->
263+
(* SimdResult is a list of NumPat or LitPat. For float shapes, we can have a mix of literals
264+
* and NaNs. For NaNs, we need to mask it and compare with a canonical NaN. To simplify
265+
* comparison, we build masks even for literals (will just be all set), collect them into
266+
* a v128, then compare the entire 128 bits.
267+
*)
268+
let mask_and_canonical = function
269+
| LitPat {it = Values.I32 _ as i; _} -> Values.I32 (Int32.minus_one), i
270+
| LitPat {it = Values.I64 _ as i; _} -> Values.I64 (Int64.minus_one), i
271+
| LitPat {it = Values.F32 f; _} -> Values.I32 (Int32.minus_one), Values.I32 (I32_convert.reinterpret_f32 f)
272+
| LitPat {it = Values.F64 f; _} -> Values.I64 (Int64.minus_one), Values.I64 (I64_convert.reinterpret_f64 f)
273+
| NanPat {it = Values.F32 nan; _} -> nan_bitmask_of nan I32Type, canonical_nan_of I32Type
274+
| NanPat {it = Values.F64 nan; _} -> nan_bitmask_of nan I64Type, canonical_nan_of I64Type
275+
| _ -> assert false
276+
in
277+
let masks, canons = List.split (List.map (fun p -> mask_and_canonical p.it) pats) in
278+
let all_ones = V128.of_i32x4 (List.init 4 (fun _ -> Int32.minus_one)) in
279+
let mask, expected = match shape with
280+
| Simd.I8x16 -> all_ones, V128.of_i8x16 (List.map Values.I32Value.of_value canons)
281+
| Simd.I16x8 -> all_ones, V128.of_i16x8 (List.map Values.I32Value.of_value canons)
282+
| Simd.I32x4 -> all_ones, V128.of_i32x4 (List.map Values.I32Value.of_value canons)
283+
| Simd.I64x2 -> all_ones, V128.of_i64x2 (List.map Values.I64Value.of_value canons)
284+
| Simd.F32x4 ->
285+
V128.of_i32x4 (List.map Values.I32Value.of_value masks),
286+
V128.of_i32x4 (List.map Values.I32Value.of_value canons)
287+
| Simd.F64x2 ->
288+
V128.of_i64x2 (List.map Values.I64Value.of_value masks),
289+
V128.of_i64x2 (List.map Values.I64Value.of_value canons)
290+
in
291+
[
292+
Const (Values.V128 mask @@ at) @@ at;
293+
Binary (Values.V128 V128Op.(V128 And)) @@ at;
294+
Const (Values.V128 expected @@ at) @@ at;
295+
Binary (Values.V128 V128Op.(I8x16 Eq)) @@ at;
296+
(* If all lanes are non-zero, then they are equal *)
297+
Test (Values.V128 V128Op.(I8x16 AllTrue)) @@ at;
298+
Test (Values.I32 I32Op.Eqz) @@ at;
299+
BrIf (0l @@ at) @@ at ]
264300
in [], List.flatten (List.rev_map test ress)
265301

266302
let wrap module_name item_name wrap_action wrap_assertion at =
@@ -332,21 +368,25 @@ let of_literal lit =
332368
| Values.I64 i -> "int64(\"" ^ I64.to_string_s i ^ "\")"
333369
| Values.F32 z -> of_float (F32.to_float z)
334370
| Values.F64 z -> of_float (F64.to_float z)
335-
| Values.V128 v -> failwith "TODO v128" (* FIXME should this be even valid *)
371+
| Values.V128 v -> "v128(\"" ^ V128.to_string v ^ "\")"
336372

337373
let of_nan = function
338374
| CanonicalNan -> "nan:canonical"
339375
| ArithmeticNan -> "nan:arithmetic"
340376

341-
let of_result res =
342-
match res.it with
343-
| NumResult { it = LitPat lit; _ } -> of_literal lit
344-
| SimdResult _ -> failwith "unimplemented"
345-
| NumResult { it = NanPat nanop; _ } ->
377+
let of_numpat = function
378+
| LitPat lit -> of_literal lit
379+
| NanPat nanop ->
346380
match nanop.it with
347381
| Values.I32 _ | Values.I64 _ | Values.V128 _ -> assert false
348382
| Values.F32 n | Values.F64 n -> of_nan n
349383

384+
let of_result res =
385+
match res.it with
386+
| NumResult n -> of_numpat n.it
387+
| SimdResult (shape, pats) ->
388+
Printf.sprintf "v128(\"%s\")" (String.concat " " (List.map (fun x -> of_numpat x.it) pats))
389+
350390
let rec of_definition def =
351391
match def.it with
352392
| Textual m -> of_bytes (Encode.encode m)

test/core/run.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,11 @@ def _runTestFile(self, inputPath):
9292
self._runCommand(('%s -d "%s" -o "%s"') % (wasmCommand, wasm2Path, wast2Path), logPath)
9393
self._compareFile(wastPath, wast2Path)
9494

95-
# Convert to JavaScript, SIMD has no JS support at all, so don't generate JS files.
96-
if 'simd' not in outputPath:
97-
jsPath = self._auxFile(outputPath.replace(".wast", ".js"))
98-
logPath = self._auxFile(jsPath + ".log")
99-
self._runCommand(('%s -d "%s" -o "%s"') % (wasmCommand, inputPath, jsPath), logPath)
100-
if jsCommand != None:
101-
self._runCommand(('%s "%s"') % (jsCommand, jsPath), logPath)
95+
jsPath = self._auxFile(outputPath.replace(".wast", ".js"))
96+
logPath = self._auxFile(jsPath + ".log")
97+
self._runCommand(('%s -d "%s" -o "%s"') % (wasmCommand, inputPath, jsPath), logPath)
98+
if jsCommand != None:
99+
self._runCommand(('%s "%s"') % (jsCommand, jsPath), logPath)
102100

103101

104102
if __name__ == "__main__":

0 commit comments

Comments
 (0)