@@ -167,16 +167,21 @@ def test_async_submit():
167167 assert isinstance (kern2Kernel , dpctl_prog .SyclKernel )
168168
169169 status_complete = dpctl .event_status_type .complete
170- n = 256 * 1024
171- X = dpt .empty ((3 , n ), dtype = "u4" , usm_type = "device" , sycl_queue = q )
170+
171+ # choose input size based on capability of the device
172+ f = q .sycl_device .max_work_group_size
173+ n = f * 1024
174+ n_alloc = 4 * n
175+
176+ X = dpt .empty ((3 , n_alloc ), dtype = "u4" , usm_type = "device" , sycl_queue = q )
172177 first_row = dpctl_mem .as_usm_memory (X [0 ])
173178 second_row = dpctl_mem .as_usm_memory (X [1 ])
174179 third_row = dpctl_mem .as_usm_memory (X [2 ])
175180
176181 p1 , p2 = 17 , 27
177182
178183 async_detected = False
179- for _ in range (5 ):
184+ for attempt in range (5 ):
180185 e1 = q .submit (
181186 kern1Kernel ,
182187 [
@@ -209,19 +214,22 @@ def test_async_submit():
209214 e3_st = e3 .execution_status
210215 e2_st = e2 .execution_status
211216 e1_st = e1 .execution_status
212- if not all (
213- [
214- e == status_complete
215- for e in (
216- e1_st ,
217- e2_st ,
218- e3_st ,
219- )
220- ]
221- ):
217+ are_complete = [
218+ e == status_complete
219+ for e in (
220+ e1_st ,
221+ e2_st ,
222+ e3_st ,
223+ )
224+ ]
225+ e3 . wait ()
226+ if not all ( are_complete ):
222227 async_detected = True
223- e3 .wait ()
224228 break
229+ else :
230+ n = n * (1 if attempt % 2 == 0 else 2 )
231+ if n > n_alloc :
232+ break
225233
226234 assert async_detected , "No evidence of async submission detected, unlucky?"
227235 Xnp = dpt .asnumpy (X )
@@ -231,4 +239,4 @@ def test_async_submit():
231239 Xref [1 , i ] = (i * i * i ) % p2
232240 Xref [2 , i ] = min (Xref [0 , i ], Xref [1 , i ])
233241
234- assert np .array_equal (Xnp , Xref )
242+ assert np .array_equal (Xnp [:, : n ], Xref [:, : n ] )
0 commit comments