Skip to content

Commit 66ea330

Browse files
committed
Expand arg passing test with host and wasm return values
There's enough test cases that this warrants being run in parallel via its own test target. Signed-off-by: Matt Leon <mattleon@google.com>
1 parent e703840 commit 66ea330

6 files changed

Lines changed: 235 additions & 56 deletions

File tree

include/proxy-wasm/wasm_vm.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,21 @@ template <size_t N>
6767
using WasmCallVoid = std::function<WasmCallInFuncType<N, void, ContextBase *, Word>>;
6868
template <size_t N>
6969
using WasmCallWord = std::function<WasmCallInFuncType<N, Word, ContextBase *, Word>>;
70+
// Callback used to test arg passing from host to wasm.
7071
using WasmCall_WWlfd = std::function<Word(ContextBase *, Word, uint64_t, float, double)>;
72+
// Types used to test return values. Floats are passed as parameters as these
73+
// do not conflict with ProxyWasm ABI signatures.
74+
using WasmCall_lf = std::function<uint64_t(ContextBase *, float)>;
75+
using WasmCall_fff = std::function<float(ContextBase *, float, float)>;
76+
using WasmCall_dfff = std::function<double(ContextBase *, float, float, float)>;
7177

7278
#define FOR_ALL_WASM_VM_EXPORTS(_f) \
7379
_f(proxy_wasm::WasmCallVoid<0>) _f(proxy_wasm::WasmCallVoid<1>) _f(proxy_wasm::WasmCallVoid<2>) \
7480
_f(proxy_wasm::WasmCallVoid<3>) _f(proxy_wasm::WasmCallVoid<5>) \
7581
_f(proxy_wasm::WasmCallWord<0>) _f(proxy_wasm::WasmCallWord<1>) \
7682
_f(proxy_wasm::WasmCallWord<2>) _f(proxy_wasm::WasmCallWord<3>) \
77-
_f(proxy_wasm::WasmCall_WWlfd)
83+
_f(proxy_wasm::WasmCall_WWlfd) _f(proxy_wasm::WasmCall_lf) \
84+
_f(proxy_wasm::WasmCall_fff) _f(proxy_wasm::WasmCall_dfff)
7885

7986
// These are templates and its helper for constructing signatures of functions callbacks from Wasm
8087
// VMs.

test/BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,22 @@ cc_test(
8585
],
8686
)
8787

88+
cc_test(
89+
name = "arg_passing_test",
90+
timeout = "long",
91+
srcs = ["arg_passing_test.cc"],
92+
data = [
93+
"//test/test_data:arg_passing.wasm",
94+
],
95+
linkstatic = 1,
96+
deps = [
97+
":utility_lib",
98+
"//:lib",
99+
"@com_google_googletest//:gtest",
100+
"@com_google_googletest//:gtest_main",
101+
],
102+
)
103+
88104
cc_test(
89105
name = "exports_test",
90106
srcs = ["exports_test.cc"],

test/arg_passing_test.cc

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "gtest/gtest.h"
16+
#include "gmock/gmock.h"
17+
18+
#include <optional>
19+
20+
#include "include/proxy-wasm/context.h"
21+
#include "include/proxy-wasm/wasm.h"
22+
23+
#include "test/utility.h"
24+
25+
namespace proxy_wasm {
26+
namespace {
27+
28+
class ArgPassingContext : public TestContext {
29+
public:
30+
using TestContext::TestContext;
31+
WasmResult getHeaderMapPairs(WasmHeaderMapType /* type */, Pairs * /* result */) override {
32+
return static_cast<WasmResult>(3333333333U);
33+
}
34+
};
35+
36+
class ArgPassingWasm : public TestWasm {
37+
public:
38+
using TestWasm::TestWasm;
39+
ContextBase *createVmContext() override { return new ArgPassingContext(this); };
40+
};
41+
42+
class ArgPassingTest : public TestVm {
43+
public:
44+
void SetUp() {
45+
auto source = readTestWasmFile("arg_passing.wasm");
46+
ASSERT_FALSE(source.empty());
47+
wasm_.emplace(std::move(vm_));
48+
ASSERT_TRUE(wasm_->load(source, false));
49+
ASSERT_TRUE(wasm_->initialize());
50+
context_ = dynamic_cast<ArgPassingContext *>(wasm_->vm_context());
51+
ASSERT_NE(context_, nullptr);
52+
}
53+
54+
std::optional<ArgPassingWasm> wasm_;
55+
ArgPassingContext *context_;
56+
};
57+
58+
INSTANTIATE_TEST_SUITE_P(WasmEngines, ArgPassingTest, testing::ValuesIn(getWasmEngines()),
59+
[](const testing::TestParamInfo<std::string> &info) {
60+
return info.param;
61+
});
62+
63+
TEST_P(ArgPassingTest, WasmCallReturnsWordValue) {
64+
WasmCallWord<0> test_return_u32;
65+
wasm_->wasm_vm()->getFunction("test_return_u32", &test_return_u32);
66+
67+
EXPECT_EQ(test_return_u32(context_), 3333333333U) << context_->getLog();
68+
}
69+
70+
TEST_P(ArgPassingTest, WasmCallReturnsLongValue) {
71+
WasmCall_lf test_return_u64;
72+
wasm_->wasm_vm()->getFunction("test_return_u64", &test_return_u64);
73+
74+
EXPECT_EQ(test_return_u64(context_, 1.0), 11111111111111111111UL) << context_->getLog();
75+
}
76+
77+
TEST_P(ArgPassingTest, WasmCallReturnsFloatValue) {
78+
WasmCall_fff test_return_f32;
79+
wasm_->wasm_vm()->getFunction("test_return_f32", &test_return_f32);
80+
81+
EXPECT_THAT(test_return_f32(context_, 1.0, 1.0),
82+
testing::AllOf(testing::Lt(1112.0), testing::Gt(1110.0)))
83+
<< context_->getLog();
84+
}
85+
86+
TEST_P(ArgPassingTest, WasmCallReturnsDoubleValue) {
87+
WasmCall_dfff test_return_f64;
88+
wasm_->wasm_vm()->getFunction("test_return_f64", &test_return_f64);
89+
90+
EXPECT_THAT(test_return_f64(context_, 1.0, 1.0, 1.0),
91+
testing::AllOf(testing::Lt(1111111112.0), testing::Gt(1111111110.0)))
92+
<< context_->getLog();
93+
}
94+
95+
TEST_P(ArgPassingTest, HostCallReturnsWordValue) {
96+
WasmCallWord<0> test_host_return;
97+
wasm_->wasm_vm()->getFunction("test_host_return", &test_host_return);
98+
99+
EXPECT_TRUE(test_host_return(context_)) << context_->getLog();
100+
}
101+
102+
TEST_P(ArgPassingTest, HostPassesPrimitiveValues) {
103+
WasmCall_WWlfd test_primitives;
104+
wasm_->wasm_vm()->getFunction("test_primitives", &test_primitives);
105+
106+
ASSERT_TRUE(test_primitives(context_, 3333333333U, 11111111111111111111UL, 1111, 1111111111))
107+
<< context_->getLog();
108+
}
109+
110+
TEST_P(ArgPassingTest, HostPassesNegativePrimitiveValues) {
111+
WasmCall_WWlfd test_negative_primitives;
112+
wasm_->wasm_vm()->getFunction("test_negative_primitives", &test_negative_primitives);
113+
114+
ASSERT_TRUE(
115+
test_negative_primitives(context_, -1111111111, -1111111111111111111, -1111, -1111111111))
116+
<< context_->getLog();
117+
}
118+
119+
TEST_P(ArgPassingTest, HostReadsPointersToWasmMemory) {
120+
WasmCallWord<0> test_buffer_from_wasm;
121+
wasm_->wasm_vm()->getFunction("test_buffer_from_wasm", &test_buffer_from_wasm);
122+
123+
ASSERT_TRUE(test_buffer_from_wasm(context_)) << context_->getLog();
124+
125+
context_->isLogged("hello from wasm land!");
126+
}
127+
128+
TEST_P(ArgPassingTest, WasmCallReadsBufferPassedByHost) {
129+
context_->setBuffer(0, "hello from host land!");
130+
WasmCallWord<0> test_buffer_from_host;
131+
wasm_->wasm_vm()->getFunction("test_buffer_from_host", &test_buffer_from_host);
132+
133+
ASSERT_TRUE(test_buffer_from_host(context_)) << context_->getLog();
134+
}
135+
136+
} // namespace
137+
} // namespace proxy_wasm

test/runtime_test.cc

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -190,34 +190,6 @@ TEST_P(TestVm, Trap2) {
190190
}
191191
}
192192

193-
TEST_P(TestVm, PassingValuesAcrossWasmBoundary) {
194-
auto source = readTestWasmFile("arg_passing.wasm");
195-
ASSERT_FALSE(source.empty());
196-
auto wasm = TestWasm(std::move(vm_));
197-
ASSERT_TRUE(wasm.load(source, false));
198-
ASSERT_TRUE(wasm.initialize());
199-
auto *context = dynamic_cast<TestContext *>(wasm.vm_context());
200-
ASSERT_NE(context, nullptr);
201-
WasmCall_WWlfd test_primitives;
202-
wasm.wasm_vm()->getFunction("test_primitives", &test_primitives);
203-
WasmCall_WWlfd test_negative_primitives;
204-
wasm.wasm_vm()->getFunction("test_negative_primitives", &test_negative_primitives);
205-
WasmCallWord<0> test_buffer_from_wasm;
206-
wasm.wasm_vm()->getFunction("test_buffer_from_wasm", &test_buffer_from_wasm);
207-
WasmCallWord<0> test_buffer_from_host;
208-
wasm.wasm_vm()->getFunction("test_buffer_from_host", &test_buffer_from_host);
209-
210-
ASSERT_FALSE(test_primitives(context, 3333333333U, 11111111111111111111UL, 1111, 1111111111));
211-
ASSERT_FALSE(
212-
test_negative_primitives(context, -1111111111, -1111111111111111111, -1111, -1111111111));
213-
214-
ASSERT_FALSE(test_buffer_from_wasm(context));
215-
context->isLogged("hello from wasm land!");
216-
217-
context->setBuffer(0, "hello from host land!");
218-
ASSERT_FALSE(test_buffer_from_host(context));
219-
}
220-
221193
class TestCounterContext : public TestContext {
222194
public:
223195
TestCounterContext(WasmBase *wasm) : TestContext(wasm) {}

test/test_data/arg_passing.rs

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,28 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::time::{SystemTime, UNIX_EPOCH};
16+
17+
extern "C" {
18+
fn proxy_log(level: u32, message_data: *const u8, message_size: usize) -> bool;
19+
}
20+
21+
fn log(message: &str) {
22+
unsafe {
23+
proxy_log(/*error*/ 4, message.as_bytes().as_ptr(), message.len());
24+
}
25+
}
26+
27+
#[no_mangle]
28+
pub extern "C" fn _initialize() {
29+
std::panic::set_hook(Box::new(|panic_info| {
30+
log(&format!(
31+
"panic message: {}",
32+
panic_info.payload_as_str().unwrap_or("")
33+
));
34+
}));
35+
}
36+
1537
#[no_mangle]
1638
pub extern "C" fn proxy_abi_version_0_2_0() {}
1739

@@ -26,34 +48,60 @@ pub extern "C" fn proxy_on_memory_allocate(size: usize) -> *mut u8 {
2648
}
2749

2850
extern "C" {
29-
fn proxy_log(level: u32, message_data: *const u8, message_size: usize) -> bool;
51+
// Used by test_host_return to assert on values returned from imports from the host.
52+
fn proxy_get_header_map_pairs(
53+
map_type: u32,
54+
return_map_data: *mut *mut u8,
55+
return_map_size: *mut usize,
56+
) -> u32;
3057
}
3158

32-
fn log(message: &str) {
59+
#[no_mangle]
60+
pub extern "C" fn test_return_u32() -> u32 {
61+
return 3333333333;
62+
}
63+
64+
#[no_mangle]
65+
pub extern "C" fn test_return_u64(_: f32) -> u64 {
66+
return 11111111111111111111;
67+
}
68+
69+
#[no_mangle]
70+
pub extern "C" fn test_return_f32(_: f32, _: f32) -> f32 {
71+
return 1111.0f32;
72+
}
73+
74+
#[no_mangle]
75+
pub extern "C" fn test_return_f64(_: f32, _: f32, _: f32) -> f64 {
76+
return 1111111111.0f64;
77+
}
78+
79+
#[no_mangle]
80+
pub extern "C" fn test_host_return() -> u32 {
3381
unsafe {
34-
proxy_log(/*error*/ 4, message.as_bytes().as_ptr(), message.len());
82+
let ret = proxy_get_header_map_pairs(0, std::ptr::null_mut(), std::ptr::null_mut());
83+
if ret != 3333333333u32 {
84+
panic!("unexpected get_header_map_pairs return value: {}", ret);
85+
}
3586
}
87+
return 1;
3688
}
3789

3890
#[no_mangle]
3991
pub extern "C" fn test_primitives(uint32: u32, uint64: u64, float32: f32, float64: f64) -> i32 {
4092
if uint32 != 3333333333 {
41-
log(&format!("unexpected uint32 value: {}", uint32));
42-
return 1;
93+
panic!("unexpected uint32 value: {}", uint32);
4394
}
4495
if uint64 != 11111111111111111111 {
45-
log(&format!("unexpected uint64 value: {}", uint32));
46-
return 2;
96+
panic!("unexpected uint64 value: {}", uint64);
4797
}
4898
if float32 < 1110.0 || float32 > 1112.0 {
49-
log(&format!("unexpected float32 value: {}", float32));
50-
return 3;
99+
panic!("unexpected float32 value: {}", float32);
51100
}
52101
if float64 < 1111111110.0 || float64 > 1111111112.0 {
53-
log(&format!("unexpected float64 value: {}", float64));
54-
return 4;
102+
panic!("unexpected float64 value: {}", float64);
55103
}
56-
return 0;
104+
return 1;
57105
}
58106

59107
#[no_mangle]
@@ -64,30 +112,26 @@ pub extern "C" fn test_negative_primitives(
64112
float64: f64,
65113
) -> i32 {
66114
if int32 != -1111111111 {
67-
log(&format!("unexpected int32 value: {}", int32));
68-
return 1;
115+
panic!("unexpected int32 value: {}", int32);
69116
}
70117
if int64 != -1111111111111111111 {
71-
log(&format!("unexpected int64 value: {}", int32));
72-
return 2;
118+
panic!("unexpected int64 value: {}", int64);
73119
}
74120
if float32 > -1110.0 || float32 < -1112.0 {
75-
log(&format!("unexpected float32 value: {}", float32));
76-
return 3;
121+
panic!("unexpected float32 value: {}", float32);
77122
}
78123
if float64 > -1111111110.0 || float64 < -1111111112.0 {
79-
log(&format!("unexpected float64 value: {}", float64));
80-
return 4;
124+
panic!("unexpected float64 value: {}", float64);
81125
}
82-
return 0;
126+
return 1;
83127
}
84128

85129
#[no_mangle]
86130
pub extern "C" fn test_buffer_from_wasm() -> bool {
87131
let message = "hello from wasm land!";
88132
unsafe {
89133
match proxy_log(/*info*/ 2, message.as_ptr(), message.len()) {
90-
false => false,
134+
false => true,
91135
status => panic!("unexpected status: {}", status as u32),
92136
}
93137
}
@@ -108,17 +152,18 @@ pub extern "C" fn test_buffer_from_host() -> bool {
108152
let mut return_data: *mut u8 = std::ptr::null_mut();
109153
let mut return_size: usize = 0;
110154
unsafe {
111-
match proxy_get_buffer_bytes(0, 0, 10, &mut return_data, &mut return_size) {
155+
match proxy_get_buffer_bytes(0, 0, 30, &mut return_data, &mut return_size) {
112156
false => {
113157
if return_data.is_null() {
114158
panic!("return_data was null");
115159
}
116160
let result =
117-
String::from_utf8(Vec::from_raw_parts(return_data, return_size, return_size));
118-
if result.unwrap() != "hello from host land" {
119-
panic!("message did not match expectation");
161+
String::from_utf8(Vec::from_raw_parts(return_data, return_size, return_size))
162+
.unwrap();
163+
if result != "hello from host land!\0" {
164+
panic!("message {} did not match expectation", result);
120165
}
121-
false
166+
true
122167
}
123168
status => panic!("unexpected status: {}", status as u32),
124169
}

test/utility.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ class TestContext : public ContextBase {
102102
return WasmResult::Ok;
103103
}
104104

105+
std::string_view getLog() const { return log_; }
106+
105107
WasmResult getProperty(std::string_view path, std::string *result) override {
106108
if (path == "plugin_root_id") {
107109
*result = root_id_;

0 commit comments

Comments
 (0)