
Rewrite attention sink from eviction to ring buffer (#18821)

Merged
meta-codesync[bot] merged 1 commit into main from export-D100216687
Apr 13, 2026

Conversation

@kirklandsign
Contributor

@kirklandsign kirklandsign commented Apr 10, 2026

Summary:

Replace the eviction-based attention sink implementation with a torch.export
compatible ring buffer approach, and rewrite all tests.

Key changes:

  • RopeWithAttentionSink: simplified to pass through original positions (no more
    position shifting or k re-rotation)
  • KVCacheWithAttentionSink: uses ring buffer with index_copy_ instead of dynamic
    eviction (torch.cat/narrow/shift). Cache layout: [sink slots | ring buffer].
    Sets is_ring_buffer=True so AttentionMHA.forward handles masking natively.
  • CachePositionsManagerWithSink: new module that maps positions to cache indices,
    with sink tokens in fixed slots and window tokens in ring buffer region.
  • AttentionMHA.forward: ring buffer models skip start_pos bounds check and compute
    their own causal mask after KV cache update.
  • Remove eviction_batch_size from all interfaces (no longer needed).
  • Remove attention_sink_forward monkey-patch and rerotate_k dead code.
  • Add llama_attention_sink.yaml example config.
  • Rewrite 16 eviction-based tests with 18 ring buffer tests covering sink
    preservation, ring wrapping, causal masking, and degenerate cases.
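The sink-slot/ring-buffer layout described above can be sketched as a standalone snippet; the function name and sizes here are hypothetical illustrations, not the ExecuTorch implementation:

```python
import torch

def map_positions_to_slots(
    positions: torch.Tensor, sink_size: int, ring_size: int
) -> torch.Tensor:
    """Map original token positions to cache slots: the first `sink_size`
    positions get fixed sink slots; later positions wrap within the ring
    region [sink_size, sink_size + ring_size)."""
    return torch.where(
        positions < sink_size,
        positions,
        sink_size + (positions - sink_size) % ring_size,
    )

# With sink_size=2 and ring_size=4, positions 0-1 stay in the sink slots,
# and positions 2..9 cycle through slots 2..5.
slots = map_positions_to_slots(torch.arange(10), sink_size=2, ring_size=4)
# → tensor([0, 1, 2, 3, 4, 5, 2, 3, 4, 5])
```

Because the mapping is a fixed-shape `torch.where` over index tensors (paired with `index_copy_` for the actual writes), no dynamic `torch.cat`/`narrow` shapes are involved, which is what makes it torch.export friendly.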

Differential Revision: D100216687

Copilot AI review requested due to automatic review settings April 10, 2026 18:38
@pytorch-bot

pytorch-bot bot commented Apr 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18821

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 14 Pending

As of commit fa93c1d with merge base fe71bd4:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 10, 2026
@meta-codesync
Contributor

meta-codesync bot commented Apr 10, 2026

@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D100216687.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

Replaces the eviction-based attention sink implementation with a torch.export-compatible ring-buffer KV cache design, updates the attention path to rely on ring-buffer masking, and rewrites the associated tests/configuration.

Changes:

  • Update attention sink config parsing/validation to accept "<sink_size>,<window_size>" (removing eviction batch size).
  • Implement ring-buffer-based attention sink KV cache + cache-position management, and adjust AttentionMHA.forward to use ring-buffer masking after KV updates.
  • Rewrite attention sink tests and add an example YAML config; update BUCK deps for the new test behavior.
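A minimal sketch of the 2-field config parsing described above, using a hypothetical standalone helper (not the actual llm_config.py code):

```python
def parse_attention_sink(value: str) -> tuple[int, int]:
    """Parse the "<sink_size>,<window_size>" string into two ints."""
    parts = value.split(",")
    if len(parts) != 2:
        raise ValueError(
            "use_attention_sink must be structured like '<sink_size>,<window_size>'"
        )
    sink_size, window_size = (int(p.strip()) for p in parts)
    return sink_size, window_size

# e.g. parse_attention_sink("4,124") → (4, 124)
```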

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

Summary per file:

  • extension/llm/export/config/llm_config.py — Validates the 2-field use_attention_sink format and updates error messaging.
  • examples/models/llama/source_transformation/attention_sink.py — Replaces the eviction approach with a ring-buffer KV cache and sink-preserving cache index manager; removes the forward monkey-patch.
  • examples/models/llama/attention.py — Updates AttentionMHA.forward to treat ring-buffer caches specially (mask computed after KV update; bounds check skipped).
  • examples/models/llama/model.py — Parses the 2-field attention sink config and relaxes the RoPE max-context constraint to >= sink_size + window_size.
  • examples/models/llama/source_transformation/test_attention_sink.py — Rewrites tests to cover ring-buffer sink preservation, wrapping, and masking behaviors.
  • examples/models/llama/config/test_llm_config.py — Updates config validation tests for the new 2-field format.
  • examples/models/llama/config/llama_attention_sink.yaml — Adds an example configuration for attention sink with ring-buffer sizing guidance.
  • examples/models/llama/BUCK — Adjusts attention_sink_test deps/preloads to support the rewritten tests.


Comment on lines +84 to +88
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
delta = pos_q - cache_positions

# Valid if position is filled (>= 0) and causal (delta >= 0)
is_valid = (cache_positions >= 0) & (delta >= 0)

Copilot AI Apr 10, 2026


_create_causal_mask_for_attention_sink builds pos_q via torch.arange() on the default device. If cache_positions is moved to CUDA (e.g., model.to('cuda')), this will raise a device-mismatch error when computing delta = pos_q - cache_positions. Create pos_q (and any scalar constants used in torch.where) on cache_positions.device to keep masking device-agnostic.
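A sketch of the device-agnostic construction this comment asks for; the names mirror the snippet above, but this is an illustrative fix under the reviewer's assumptions, not the repository code:

```python
import torch

def causal_mask_for_sink(
    start_pos: int, seq_len: int, cache_positions: torch.Tensor
) -> torch.Tensor:
    # Allocate pos_q on cache_positions.device so the subtraction below
    # works whether the model lives on CPU or CUDA.
    pos_q = start_pos + torch.arange(
        seq_len, dtype=torch.long, device=cache_positions.device
    ).view(-1, 1)
    delta = pos_q - cache_positions
    # Valid if the slot is filled (>= 0) and causal (delta >= 0)
    return (cache_positions >= 0) & (delta >= 0)
```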

Comment on lines +138 to 143
# Sink tokens go to fixed slots; window tokens use ring buffer
indices = torch.where(
    orig_indices < self.sink_size,
    orig_indices,
    self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
)

Copilot AI Apr 10, 2026


The ring-buffer index expression computes (orig_indices - sink_size) % self.ring_size unconditionally (torch.where does not short-circuit). If window_size=0 (so ring_size==0), this will raise a modulo-by-zero error even when all orig_indices are sink tokens. Add an explicit guard for ring_size==0 (either disallow window_size=0 or handle the sink-only case without modulo).

Suggested change
-    # Sink tokens go to fixed slots; window tokens use ring buffer
-    indices = torch.where(
-        orig_indices < self.sink_size,
-        orig_indices,
-        self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
-    )
+    # torch.where does not short-circuit, so guard the sink-only case to
+    # avoid evaluating modulo by zero when ring_size == 0.
+    if self.ring_size == 0:
+        torch._check(
+            bool((orig_indices < self.sink_size).all().item()),
+            "Positions beyond sink_size are invalid when ring_size is 0",
+        )
+        indices = orig_indices
+    else:
+        # Sink tokens go to fixed slots; window tokens use ring buffer
+        indices = torch.where(
+            orig_indices < self.sink_size,
+            orig_indices,
+            self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
+        )

Comment on lines +145 to 150
# Update cache_positions exactly like original CachePositionsManager
full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
cache_positions = torch.where(
    arange_tensor < start_pos, self.cache_positions, full_t
)

Copilot AI Apr 10, 2026


calculate_positions_and_update_indices constructs full_t and arange_tensor on CPU by default; if this module is moved to another device, torch.where(...) will fail due to mixed-device inputs. Allocate these tensors on self.cache_positions.device (and similarly ensure orig_indices/indices are on a consistent device) to support model.to(device).

Comment on lines +62 to +63
# Use torch._check for export compatibility (data-dependent guard)
torch._check(input_pos[0].item() + seq_len <= self.max_context_length)

Copilot AI Apr 10, 2026


RopeWithAttentionSink.get_freqs drops Rope.get_freqs’ torch._check_is_size(input_pos_item) guard and uses input_pos[0] instead of input_pos[-1]. This can allow negative positions (or multi-element input_pos) to slip through and fail later during narrow/indexing. Consider mirroring Rope.get_freqs’ size check and using the same element convention as the base class.

Suggested change
-    # Use torch._check for export compatibility (data-dependent guard)
-    torch._check(input_pos[0].item() + seq_len <= self.max_context_length)
+    input_pos_item = input_pos[-1].item()
+    torch._check_is_size(input_pos_item)
+    # Use torch._check for export compatibility (data-dependent guard)
+    torch._check(input_pos_item + seq_len <= self.max_context_length)

Comment on lines 205 to +209
     attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
+    assert len(attention_sink_params) == 2, (
+        f"use_attention_sink expects exactly 2 comma-separated values "
+        f"(sink_size,window_size), got {len(attention_sink_params)}"
+    )

Copilot AI Apr 10, 2026


This PR changes use_attention_sink from 3 parameters to 2, but other call sites still assume 3 (e.g., examples/models/llama/eval_llama_lib.py asserts len==3 around line ~350, and examples/models/llama/export_llama_lib.py’s CLI help still documents 3 values around line ~594). Please update those to avoid runtime assertion failures / misleading CLI docs.

Comment on lines 218 to 224
 def _validate_attention_sink(self):
     if self.use_attention_sink:
         attention_sink_params = self.use_attention_sink.split(",")
-        if len(attention_sink_params) != 3:
+        if len(attention_sink_params) != 2:
             raise ValueError(
-                "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+                "The value of use_attention_sink must be structured like '<sink_size>,<window_size>'"
             )

Copilot AI Apr 10, 2026


ModelConfig’s docstring above still describes use_attention_sink as '<sink_size>,<window_size>,<batch_eviction_size>' (and gives a 3-value example), but validation now enforces exactly 2 values. Please update the documentation string/comments to match the new 2-field format to prevent confusion.

@meta-codesync meta-codesync bot changed the title Rewrite attention sink from eviction to ring buffer Rewrite attention sink from eviction to ring buffer (#18821) Apr 10, 2026
meta-codesync bot pushed a commit that referenced this pull request Apr 10, 2026
@meta-codesync meta-codesync bot force-pushed the export-D100216687 branch from 8355623 to 973c88b Compare April 10, 2026 19:04
meta-codesync bot pushed a commit that referenced this pull request Apr 13, 2026
Copilot AI review requested due to automatic review settings April 13, 2026 21:31
@meta-codesync meta-codesync bot force-pushed the export-D100216687 branch from 973c88b to a92ed43 Compare April 13, 2026 21:31
@kirklandsign kirklandsign review requested due to automatic review settings April 13, 2026 21:31
meta-codesync bot pushed a commit that referenced this pull request Apr 13, 2026
Copilot AI review requested due to automatic review settings April 13, 2026 21:42
@meta-codesync meta-codesync bot force-pushed the export-D100216687 branch from a92ed43 to d12c97a Compare April 13, 2026 21:42
@kirklandsign kirklandsign review requested due to automatic review settings April 13, 2026 21:42
Copilot AI review requested due to automatic review settings April 13, 2026 22:49
@meta-codesync meta-codesync bot force-pushed the export-D100216687 branch from d12c97a to fa93c1d Compare April 13, 2026 22:49
@kirklandsign kirklandsign review requested due to automatic review settings April 13, 2026 22:49
@lucylq
Contributor

lucylq commented Apr 13, 2026

@claude review this PR

@claude

claude bot commented Apr 13, 2026

Claude finished @lucylq's task in 1m 58s


Review of PR #18821: Rewrite attention sink from eviction to ring buffer

  • Read all changed files
  • Review attention sink implementation (ring buffer logic)
  • Review attention.py changes
  • Review model.py and config changes
  • Review test coverage
  • Check for consistency issues across codebase
  • Post review summary

This is a well-motivated rewrite. The old eviction-based approach used torch.cat/narrow/shift operations that are incompatible with torch.export, and the ring buffer design is the right architectural choice here. The implementation is clean and the test rewrite is thorough. I have a few issues to flag, roughly ordered by severity.


Issues

1. Device mismatch in _create_causal_mask_for_attention_sink and CachePositionsManagerWithSink

_create_causal_mask_for_attention_sink (attention_sink.py:82) creates pos_q via torch.arange() on the default (CPU) device. If cache_positions is on another device (e.g., CUDA), the subtraction pos_q - cache_positions will fail. The same issue applies to calculate_positions_and_update_indices (attention_sink.py:137,147-148) where orig_indices, full_t, and arange_tensor are all constructed on CPU.

For torch.export this may not matter today (tracing happens on CPU), but if the model is ever moved to another device before inference, these will break. Consider creating tensors on cache_positions.device.


2. RopeWithAttentionSink.get_freqs uses input_pos[0] instead of input_pos[-1]

The base Rope.get_freqs (rope.py) uses input_pos[-1] and calls torch._check_is_size. The new RopeWithAttentionSink.get_freqs (attention_sink.py:61) uses input_pos[0] and omits the torch._check_is_size guard. For single-element input_pos this is identical, but:

  • During prefill, input_pos is [0] and seq_len > 1, so input_pos[0] + seq_len correctly bounds the range.
  • The base class convention is input_pos[-1], which is the last position, not the first. With input_pos[0], the bound check input_pos[0] + seq_len <= max_context_length is correct only if input_pos is always a single element. This should be documented or the convention matched.

3. Modulo-by-zero when window_size=0 (sink-only configuration)

CachePositionsManagerWithSink.__init__ (attention_sink.py:112-113) asserts cache_size > sink_size, which means ring_size >= 1. However, if someone passes window_size=0 to KVCacheWithAttentionSink, then total_cache_size = sink_size + 0*2 = sink_size, and the assertion cache_size > sink_size will fail. This is correctly guarded by the assertion, but the error message doesn't explain why window_size=0 is invalid. Consider either:

  • Adding a clearer check in KVCacheWithAttentionSink.__init__ that window_size > 0.
  • Or documenting this requirement in the docstring.

4. cache_positions initialization difference: -1 vs 0

The original CachePositionsManager in attention.py:176 initializes cache_positions to zeros (torch.zeros), while the new CachePositionsManagerWithSink in attention_sink.py:121 initializes to -1 (torch.full(..., -1, ...)). Both use the same mask check cache_positions >= 0, so:

  • In the original: unfilled slots have cache_positions[i] = 0, which passes >= 0 and could make them visible to position 0 or later queries.
  • In the new: unfilled slots have cache_positions[i] = -1, which correctly fails >= 0 and is masked out.

The -1 initialization is the correct choice. The original CachePositionsManager may have a latent bug here (unfilled slots look like position 0), though it works in practice because position 0 is always filled first during prefill.
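The difference can be reproduced in isolation (standalone illustration, not library code):

```python
import torch

# -1 init: unfilled slots fail the `cache_positions >= 0` visibility check
# and are correctly masked out.
filled_init = torch.tensor([5, -1, -1])   # one real entry written at position 5
visible_with_neg1 = filled_init >= 0      # tensor([True, False, False])

# zero init: unfilled slots mimic position 0 and pass the check, so a query
# could attend to slots that were never written.
zero_init = torch.tensor([5, 0, 0])
visible_with_zero = zero_init >= 0        # tensor([True, True, True])
```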

5. eval_llama_lib.py:354 still asserts max_seq_length == sink_size + window_size

While the PR updated the param count check at eval_llama_lib.py:350 from == 3 to == 2, line 354 still asserts:

assert llm_config.export.max_seq_length == sink_size + window_size

With the ring buffer approach, max_context_length only needs to be >= sink_size + window_size (as correctly documented in model.py:216-221). This strict equality check in the eval path may be unnecessarily restrictive. It should probably be >= for consistency, or removed if max_seq_length isn't directly tied to the cache size anymore.


6. ModelConfig docstring still references 3-field format

The ModelConfig.use_attention_sink docstring at llm_config.py:179-182 now correctly shows 2 fields, but the old 3-field description was present before this PR. Worth confirming the docstring was updated -- it looks correct in the current version.


Design observations (non-blocking)

  • 2x ring buffer sizing (attention_sink.py:185): The comment explaining why the ring buffer needs 2 * window_size slots is helpful and the reasoning is correct. This matches the existing RingKVCache convention.

  • evict_tokens stub (attention_sink.py:258-264): This method always returns 0, which is correct for the ring buffer approach. If no callers use it, consider removing it entirely rather than keeping a no-op stub.

  • sink_size == 0 fast path (attention_sink.py:296-305): Delegating to RingKVCache when sink_size == 0 is the right approach, avoiding unnecessary overhead.

  • Test coverage: The 18 tests cover the key scenarios well: sink preservation, ring wrapping, causal masking, sequential generation, degenerate cases, and E2E model execution. The parameterized RoPE tests are a nice touch.


Summary

The core ring buffer + sink implementation is correct and well-structured. The main actionable items are the device-agnostic tensor creation (issue 1) and the eval assertion (issue 5). The rest are minor or defensive improvements.

@meta-codesync meta-codesync bot merged commit 56d6e4d into main Apr 13, 2026
185 of 191 checks passed
@meta-codesync meta-codesync bot deleted the export-D100216687 branch April 13, 2026 23:56
jpiat pushed a commit to jpiat/executorch that referenced this pull request Apr 14, 2026
Differential Revision: D100216687

Pull Request resolved: pytorch#18821

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported

Projects

None yet

Development


4 participants