@@ -198,11 +198,17 @@ static void request_cleanup(py_request_t *req) {
198198static void process_request (py_request_t * req ) {
199199 ErlNifEnv * env = req -> env ;
200200 py_worker_t * worker = req -> worker ;
201+ py_context_t * context = req -> context ;
202+
203+ /* Extract globals/locals from context or worker */
204+ PyObject * globals = context ? context -> globals : (worker ? worker -> globals : NULL );
205+ PyObject * locals = context ? context -> locals : (worker ? worker -> locals : NULL );
201206
202207 switch (req -> type ) {
203208 case PY_REQ_CALL : {
204- /* Set thread-local worker context for callbacks */
209+ /* Set thread-local worker/ context for callbacks */
205210 tl_current_worker = worker ;
211+ tl_current_context = context ;
206212 tl_callback_env = env ;
207213 tl_allow_suspension = false; /* Blocking mode - code runs once, no replay */
208214
@@ -217,20 +223,20 @@ static void process_request(py_request_t *req) {
217223
218224 PyObject * func = NULL ;
219225
220- /* Special handling for __main__ - look in worker's namespace */
226+ /* Special handling for __main__ - look in globals/locals namespace first */
221227 if (strcmp (module_name , "__main__" ) == 0 ) {
222- func = PyDict_GetItemString (worker -> locals , func_name );
228+ func = PyDict_GetItemString (locals , func_name );
223229 if (func == NULL ) {
224- func = PyDict_GetItemString (worker -> globals , func_name );
230+ func = PyDict_GetItemString (globals , func_name );
225231 }
226232 if (func != NULL ) {
227233 Py_INCREF (func );
228- } else {
229- PyErr_Format (PyExc_NameError , "name '%s' is not defined" , func_name );
230- req -> result = make_py_error (env );
231- goto call_cleanup ;
232234 }
233- } else {
235+ /* If not found in namespace, fall through to module import below */
236+ }
237+
238+ if (func == NULL ) {
239+ /* Import module and get attribute */
234240 PyObject * module = PyImport_ImportModule (module_name );
235241 if (module == NULL ) {
236242 req -> result = make_py_error (env );
@@ -351,6 +357,7 @@ static void process_request(py_request_t *req) {
351357
352358 call_cleanup :
353359 tl_current_worker = NULL ;
360+ tl_current_context = NULL ;
354361 tl_callback_env = NULL ;
355362 tl_allow_suspension = false;
356363 enif_free (module_name );
@@ -360,6 +367,7 @@ static void process_request(py_request_t *req) {
360367
361368 case PY_REQ_EVAL : {
362369 tl_current_worker = worker ;
370+ tl_current_context = context ;
363371 tl_callback_env = env ;
364372 tl_allow_suspension = true; /* Allow suspension - we replay on resume */
365373
@@ -373,7 +381,7 @@ static void process_request(py_request_t *req) {
373381 if (enif_is_map (env , req -> locals_term )) {
374382 PyObject * new_locals = term_to_py (env , req -> locals_term );
375383 if (new_locals != NULL && PyDict_Check (new_locals )) {
376- PyDict_Update (worker -> locals , new_locals );
384+ PyDict_Update (locals , new_locals );
377385 Py_DECREF (new_locals );
378386 }
379387 }
@@ -388,7 +396,7 @@ static void process_request(py_request_t *req) {
388396 stop_timeout ();
389397 req -> result = make_py_error (env );
390398 } else {
391- PyObject * py_result = PyEval_EvalCode (compiled , worker -> globals , worker -> locals );
399+ PyObject * py_result = PyEval_EvalCode (compiled , globals , locals );
392400 Py_DECREF (compiled );
393401 stop_timeout ();
394402
@@ -451,6 +459,7 @@ static void process_request(py_request_t *req) {
451459 }
452460
453461 tl_current_worker = NULL ;
462+ tl_current_context = NULL ;
454463 tl_callback_env = NULL ;
455464 tl_allow_suspension = false;
456465 enif_free (code );
@@ -459,6 +468,7 @@ static void process_request(py_request_t *req) {
459468
460469 case PY_REQ_EXEC : {
461470 tl_current_worker = worker ;
471+ tl_current_context = context ;
462472 tl_callback_env = env ;
463473 /* Note: tl_allow_suspension stays false for exec - suspension not allowed */
464474
@@ -476,7 +486,7 @@ static void process_request(py_request_t *req) {
476486 /* Use globals for both to ensure imports are visible to defined functions.
477487 * When using separate dicts, imports go to locals but function closures
478488 * only see globals, causing "name X is not defined" errors. */
479- PyObject * py_result = PyEval_EvalCode (compiled , worker -> globals , worker -> globals );
489+ PyObject * py_result = PyEval_EvalCode (compiled , globals , globals );
480490 Py_DECREF (compiled );
481491
482492 if (py_result == NULL ) {
@@ -488,6 +498,7 @@ static void process_request(py_request_t *req) {
488498 }
489499
490500 tl_current_worker = NULL ;
501+ tl_current_context = NULL ;
491502 tl_callback_env = NULL ;
492503 enif_free (code );
493504 break ;
@@ -792,12 +803,14 @@ static int executor_enqueue(py_request_t *req) {
792803 case PY_MODE_MULTI_EXECUTOR :
793804 if (atomic_load (& g_multi_executor_initialized )) {
794805 /* Route to multi-executor pool.
795- * Use worker's assigned executor for thread affinity if available.
806+ * Use worker's or context's assigned executor for thread affinity if available.
796807 * This ensures libraries like numpy/torch that have thread-local
797- * state always run on the same thread for a given worker. */
808+ * state always run on the same thread for a given worker/context . */
798809 int exec_id ;
799810 if (req -> worker != NULL && req -> worker -> executor_id >= 0 ) {
800811 exec_id = req -> worker -> executor_id % g_num_executors ;
812+ } else if (req -> context != NULL && req -> context -> executor_id >= 0 ) {
813+ exec_id = req -> context -> executor_id % g_num_executors ;
801814 } else {
802815 exec_id = select_executor ();
803816 }
@@ -1092,3 +1105,117 @@ static void multi_executor_stop(void) {
10921105 * in executor_enqueue() using PyGILState_Ensure/Release which are no-ops
10931106 * in free-threaded builds but still work correctly.
10941107 */
1108+
1109+ /* ============================================================================
1110+ * Context dispatch to executor
1111+ *
1112+ * When a context has thread affinity (executor_id >= 0), we dispatch
1113+ * operations through the executor queue instead of executing directly
1114+ * on the dirty scheduler. This ensures numpy/torch thread-local state
1115+ * consistency.
1116+ * ============================================================================ */
1117+
1118+ /**
1119+ * Dispatch a context call operation to the executor.
1120+ *
1121+ * @param env Caller's NIF environment
1122+ * @param ctx Context with thread affinity
1123+ * @param module_bin Module name binary
1124+ * @param func_bin Function name binary
1125+ * @param args_term Arguments list
1126+ * @param kwargs_term Keyword arguments map
1127+ * @return Result term
1128+ */
1129+ ERL_NIF_TERM context_dispatch_call (ErlNifEnv * env , py_context_t * ctx ,
1130+ ErlNifBinary * module_bin , ErlNifBinary * func_bin ,
1131+ ERL_NIF_TERM args_term , ERL_NIF_TERM kwargs_term ) {
1132+ py_request_t req ;
1133+ request_init (& req );
1134+
1135+ req .type = PY_REQ_CALL ;
1136+ req .env = env ;
1137+ req .worker = NULL ;
1138+ req .context = ctx ;
1139+ req .module_bin = * module_bin ;
1140+ req .func_bin = * func_bin ;
1141+ req .args_term = args_term ;
1142+ req .kwargs_term = kwargs_term ;
1143+ req .timeout_ms = 0 ;
1144+
1145+ if (executor_enqueue (& req ) < 0 ) {
1146+ request_cleanup (& req );
1147+ return make_error (env , "executor_shutdown" );
1148+ }
1149+
1150+ executor_wait (& req );
1151+ ERL_NIF_TERM result = req .result ;
1152+ request_cleanup (& req );
1153+
1154+ return result ;
1155+ }
1156+
1157+ /**
1158+ * Dispatch a context eval operation to the executor.
1159+ *
1160+ * @param env Caller's NIF environment
1161+ * @param ctx Context with thread affinity
1162+ * @param code_bin Code string binary
1163+ * @param locals_term Local variables map
1164+ * @return Result term
1165+ */
1166+ ERL_NIF_TERM context_dispatch_eval (ErlNifEnv * env , py_context_t * ctx ,
1167+ ErlNifBinary * code_bin , ERL_NIF_TERM locals_term ) {
1168+ py_request_t req ;
1169+ request_init (& req );
1170+
1171+ req .type = PY_REQ_EVAL ;
1172+ req .env = env ;
1173+ req .worker = NULL ;
1174+ req .context = ctx ;
1175+ req .code_bin = * code_bin ;
1176+ req .locals_term = locals_term ;
1177+ req .timeout_ms = 0 ;
1178+
1179+ if (executor_enqueue (& req ) < 0 ) {
1180+ request_cleanup (& req );
1181+ return make_error (env , "executor_shutdown" );
1182+ }
1183+
1184+ executor_wait (& req );
1185+ ERL_NIF_TERM result = req .result ;
1186+ request_cleanup (& req );
1187+
1188+ return result ;
1189+ }
1190+
1191+ /**
1192+ * Dispatch a context exec operation to the executor.
1193+ *
1194+ * @param env Caller's NIF environment
1195+ * @param ctx Context with thread affinity
1196+ * @param code_bin Code string binary
1197+ * @return Result term
1198+ */
1199+ ERL_NIF_TERM context_dispatch_exec (ErlNifEnv * env , py_context_t * ctx ,
1200+ ErlNifBinary * code_bin ) {
1201+ py_request_t req ;
1202+ request_init (& req );
1203+
1204+ req .type = PY_REQ_EXEC ;
1205+ req .env = env ;
1206+ req .worker = NULL ;
1207+ req .context = ctx ;
1208+ req .code_bin = * code_bin ;
1209+ req .timeout_ms = 0 ;
1210+
1211+ if (executor_enqueue (& req ) < 0 ) {
1212+ request_cleanup (& req );
1213+ return make_error (env , "executor_shutdown" );
1214+ }
1215+
1216+ executor_wait (& req );
1217+ ERL_NIF_TERM result = req .result ;
1218+ request_cleanup (& req );
1219+
1220+ return result ;
1221+ }
0 commit comments