@@ -802,3 +802,255 @@ int tq_generate_continue(tq_model_t* model,
     free(new_tokens);
     return generated;
 }
+
+/* ============================================================================
+ * tq_generate_chat_text — text-prefix matching for chat reuse
+ *
+ * Solves the BPE re-tokenization issue: when the model generates response
+ * tokens via sample_topp, those token IDs may not match what tq_encode()
+ * produces from the same response text in the next turn's prompt. The
+ * token-level LCP in tq_generate_continue truncates at that boundary.
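+ * (Illustrative example, not tied to any specific vocabulary: the model
+ * might emit "Hi" and "!" as two sampled tokens, while tq_encode() later
+ * merges the same text into a single "Hi!" token; identical text, but
+ * different token IDs, so a token-level compare truncates right there.)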
+ *
+ * This function tracks the *text* of the last prompt (which includes the
+ * model's response from previous turns, accumulated by the caller). On the
+ * next call, if the new prompt starts with cached_text byte-for-byte, the
+ * entire cached state is valid — we tokenize only the new SUFFIX text and
+ * prefill those tokens at positions [n_cached..]. No LCP, no truncation.
+ *
+ * After generation, *cached_text_io is updated to:
+ *     prompt + (generated tokens decoded back to text)
+ * so the next call can fast-path again.
+ *
+ * Caller owns *cached_text_io (must free with free()).
+ * Pass cached_text_io == NULL to disable text-prefix tracking and behave
+ * exactly like tq_generate_continue.
+ * ============================================================================ */
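+
+/* Usage sketch (illustrative only; build_prompt()/append_history() are
+ * hypothetical caller-side helpers, and the config setup is elided):
+ *
+ *     char *cached_text = NULL;
+ *     int  *cached_toks = NULL;
+ *     int   n_cached = 0, cached_cap = 0;
+ *     char  reply[8192];
+ *     while (chat_running) {
+ *         const char *prompt = build_prompt(history);   // full transcript
+ *         int n = tq_generate_chat_text(model, tokenizer, state, prompt,
+ *                                       &config, &cached_text, &cached_toks,
+ *                                       &n_cached, &cached_cap,
+ *                                       reply, sizeof reply);
+ *         if (n < 0) break;
+ *         append_history(history, reply);  // reply text must reappear
+ *                                          // verbatim in the next prompt
+ *     }
+ *     free(cached_text);
+ *     free(cached_toks);
+ */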
+
+typedef struct {
+    char   *buf;
+    size_t  len;
+    size_t  cap;
+    void  (*user_cb)(const char *, void *);
+    void   *user_data;
+} chat_accum_t;
+
+static void chat_accum_callback(const char *tok, void *u) {
+    chat_accum_t *ctx = (chat_accum_t *)u;
+    if (!tok) return;
+    size_t tlen = strlen(tok);
+    if (ctx->len + tlen + 1 > ctx->cap) {
+        size_t new_cap = (ctx->cap + tlen + 64) * 2;
+        char *nb = (char *)realloc(ctx->buf, new_cap);
+        if (!nb) {
+            /* Out of memory: still forward the token so the user's stream
+             * stays intact; only the cached-text copy loses this piece. */
+            if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
+            return;
+        }
+        ctx->buf = nb;
+        ctx->cap = new_cap;
+    }
+    memcpy(ctx->buf + ctx->len, tok, tlen);
+    ctx->len += tlen;
+    ctx->buf[ctx->len] = '\0';
+    if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
+}
+
+int tq_generate_chat_text(tq_model_t *model,
+                          tq_tokenizer_t *tokenizer,
+                          tq_state_t *state,
+                          const char *prompt,
+                          tq_gen_config_t *config,
+                          char **cached_text_io,
+                          int **cached_tokens_io,
+                          int *n_cached_io,
+                          int *cached_capacity_io,
+                          char *output, int output_size) {
+    if (!model || !state || !config || !prompt ||
+        !cached_tokens_io || !n_cached_io || !cached_capacity_io) {
+        return -1;
+    }
+
+    /* --- 1. Check for text-level prefix match --- */
+    int matched_text_len = 0;
+    int prefix_pos = 0;  /* tokens already in KV cache that we trust */
+
+    if (cached_text_io && *cached_text_io && *n_cached_io > 0) {
+        size_t cached_len = strlen(*cached_text_io);
+        if (cached_len > 0 && strncmp(*cached_text_io, prompt, cached_len) == 0) {
+            matched_text_len = (int)cached_len;
+            prefix_pos = *n_cached_io;
+        } else if (getenv("TQ_CHAT_DEBUG")) {
+            /* Find where they diverge to help diagnose */
+            size_t diverge = 0;
+            size_t plen = strlen(prompt);
+            size_t lim = cached_len < plen ? cached_len : plen;
+            while (diverge < lim && (*cached_text_io)[diverge] == prompt[diverge])
+                diverge++;
+            fprintf(stderr,
+                    "[chat-text] no match: cached_len=%zu prompt_len=%zu diverge_at=%zu\n"
+                    "  cached[%zu..]: %.40s\n"
+                    "  prompt[%zu..]: %.40s\n",
+                    cached_len, plen, diverge,
+                    diverge, *cached_text_io + diverge,
+                    diverge, prompt + diverge);
+        }
+    }
+
+    /* Wrap the user callback to capture generated text into a buffer for
+     * the next call's cached_text update. */
+    chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0,
+                           .user_cb = config->on_token,
+                           .user_data = config->user_data };
+    void (*orig_cb)(const char *, void *) = config->on_token;
+    void *orig_ud = config->user_data;
+    config->on_token = chat_accum_callback;
+    config->user_data = &accum;
+
+    int generated = 0;
+
+    if (matched_text_len > 0) {
+        /* --- Fast path: text prefix matches --- */
+        const char *suffix = prompt + matched_text_len;
+        int max_prompt = model->config.max_seq_len > 0
+                             ? model->config.max_seq_len : 4096;
+        int *suffix_toks = (int *)malloc((size_t)max_prompt * sizeof(int));
+        if (!suffix_toks) {
+            config->on_token = orig_cb;
+            config->user_data = orig_ud;
+            return -1;
+        }
+        int n_suffix = 0;
+        if (tokenizer && *suffix != '\0') {
+            n_suffix = tq_encode(tokenizer, suffix, suffix_toks, max_prompt, 0);
+            if (n_suffix < 0) n_suffix = 0;
+        }
+
+        /* Sliding window if needed (drop from start of cached) */
+        int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
+        if (prefix_pos + n_suffix + reserve + 32 > max_prompt) {
+            /* Force a full reprefill — simpler than a partial cache shift.
+             * Keep the accumulating callback installed so the generated
+             * text still reaches the cached_text update below. */
+            free(suffix_toks);
+            *n_cached_io = 0;
+            if (cached_text_io && *cached_text_io) {
+                free(*cached_text_io);
+                *cached_text_io = NULL;
+            }
+            generated = tq_generate_continue(model, tokenizer, state, prompt,
+                                             config, cached_tokens_io,
+                                             n_cached_io, cached_capacity_io,
+                                             output, output_size);
+            goto update_cache;
+        }
+
+        /* Grow cache buffer */
+        int needed = prefix_pos + n_suffix + reserve + 16;
+        if (*cached_capacity_io < needed) {
+            int new_cap = needed < 4096 ? 4096 : needed;
+            int *nb = (int *)realloc(*cached_tokens_io,
+                                     (size_t)new_cap * sizeof(int));
+            if (!nb) {
+                free(suffix_toks);
+                config->on_token = orig_cb;
+                config->user_data = orig_ud;
+                return -1;
+            }
+            *cached_tokens_io = nb;
+            *cached_capacity_io = new_cap;
+        }
+
+        /* Append suffix tokens to cache + prefill at correct positions */
+        int *cached = *cached_tokens_io;
+        for (int i = 0; i < n_suffix; i++) {
+            cached[prefix_pos + i] = suffix_toks[i];
+            tq_forward(model, state, suffix_toks[i], prefix_pos + i);
+        }
+        *n_cached_io = prefix_pos + n_suffix;
+        free(suffix_toks);
+
+        if (getenv("TQ_CHAT_DEBUG")) {
+            fprintf(stderr, "[chat-text] FAST text_match=%d new_suffix_tokens=%d\n",
+                    matched_text_len, n_suffix);
+        }
+
+        /* --- Run generation loop directly --- */
+        int vocab_size = model->config.vocab_size;
+        int n_cached = *n_cached_io;
+        int pos = n_cached;
+        int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1;
+
+        unsigned long long rng_state = config->rng_seed
+            ? (unsigned long long)config->rng_seed
+            : (unsigned long long)time(NULL);
+        int next_token = tq_sample_topp(state->logits, vocab_size,
+                                        config->temperature, config->top_p,
+                                        &rng_state);
+
+        int output_pos = 0;
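+        /* Stop-token IDs pooled across the tokenizer families this code
+         * appears to target (the 1280xx entries are Llama-3 special
+         * markers); IDs outside a given model's vocab simply never match. */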
+        int eos_tokens[] = { 1, 2, 106, 128001, 128006, 128007, 128008,
+                             128009, 248044, 248046 };
+        int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]);
+
+        while (generated < config->max_tokens) {
+            int is_eos = 0;
+            for (int e = 0; e < n_eos; e++) {
+                if (next_token == eos_tokens[e]) { is_eos = 1; break; }
+            }
+            if (is_eos) break;
+            if (pos >= model->config.max_seq_len) break;
+
+            const char *piece = tokenizer ? tq_decode(tokenizer, prev_token, next_token) : "";
+            int should_stop = 0;
+            if (piece) {
+                if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") ||
+                    strstr(piece, "<|start_header_id|>")) {
+                    should_stop = 1;
+                    piece = "";
+                }
+            }
+            if (should_stop) break;
+
+            int piece_len = (int)strlen(piece ? piece : "");
+            if (config->on_token && piece) config->on_token(piece, config->user_data);
+            if (output && piece && output_pos + piece_len < output_size - 1) {
+                memcpy(output + output_pos, piece, piece_len);
+                output_pos += piece_len;
+            }
+
+            if (n_cached < *cached_capacity_io) {
+                cached[n_cached++] = next_token;
+                *n_cached_io = n_cached;
+            }
+
+            prev_token = next_token;
+            tq_forward(model, state, next_token, pos);
+            pos++;
+            generated++;
+
+            next_token = tq_sample_topp(state->logits, vocab_size,
+                                        config->temperature, config->top_p,
+                                        &rng_state);
+        }
+
+        if (output && output_size > 0) {
+            output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
+        }
+    } else {
+        /* --- Slow path: no text-prefix match, use token LCP fallback --- */
+        if (getenv("TQ_CHAT_DEBUG")) {
+            fprintf(stderr, "[chat-text] SLOW no text-prefix match, full tokenize\n");
+        }
+        generated = tq_generate_continue(
+            model, tokenizer, state, prompt, config,
+            cached_tokens_io, n_cached_io, cached_capacity_io,
+            output, output_size);
+    }
+
+update_cache:
+    /* Restore the original callback before returning to the caller */
+    config->on_token = orig_cb;
+    config->user_data = orig_ud;
+
+    /* Update cached_text = prompt + generated text. The next call can
+     * fast-path against this if its prompt starts with this string. On
+     * failure, drop the cached text instead: the token cache may be in a
+     * partial state, so a text-prefix match could no longer be trusted. */
+    if (cached_text_io) {
+        if (generated >= 0) {
+            size_t plen = strlen(prompt);
+            size_t glen = accum.len;
+            size_t new_len = plen + glen;
+            char *nt = (char *)malloc(new_len + 1);
+            if (nt) {
+                memcpy(nt, prompt, plen);
+                if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen);
+                nt[new_len] = '\0';
+                if (*cached_text_io) free(*cached_text_io);
+                *cached_text_io = nt;
+            }
+        } else if (*cached_text_io) {
+            free(*cached_text_io);
+            *cached_text_io = NULL;
+        }
+    }
+    if (accum.buf) free(accum.buf);
+
+    return generated;
+}