|
1 | 1 | """This module partially implements crypto for HAP.""" |
2 | 2 | import logging |
3 | 3 | import struct |
4 | | - |
| 4 | +from functools import partial |
| 5 | +from typing import List |
| 6 | +from struct import Struct |
5 | 7 | from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable as ChaCha20Poly1305 |
6 | 8 | from cryptography.hazmat.backends import default_backend |
7 | 9 | from cryptography.hazmat.primitives import hashes |
|
11 | 13 |
|
12 | 14 | CRYPTO_BACKEND = default_backend() |
13 | 15 |
|
| 16 | +PACK_NONCE = partial(Struct("<LQ").pack, 0) |
| 17 | +PACK_LENGTH = Struct("H").pack |
| 18 | + |
14 | 19 |
|
15 | 20 | class HAP_CRYPTO: |
16 | 21 | HKDF_KEYLEN = 32 # bytes, length of expanded HKDF keys |
@@ -74,51 +79,54 @@ def decrypt(self) -> bytes: |
74 | 79 | blocks are buffered locally. |
75 | 80 | """ |
76 | 81 | result = b"" |
| 82 | + crypt_in_buffer = self._crypt_in_buffer |
| 83 | + length_length = self.LENGTH_LENGTH |
| 84 | + tag_length = HAP_CRYPTO.TAG_LENGTH |
77 | 85 |
|
78 | | - while len(self._crypt_in_buffer) > self.MIN_BLOCK_LENGTH: |
79 | | - block_length_bytes = self._crypt_in_buffer[: self.LENGTH_LENGTH] |
| 86 | + while len(crypt_in_buffer) > self.MIN_BLOCK_LENGTH: |
| 87 | + block_length_bytes = crypt_in_buffer[:length_length] |
80 | 88 | block_size = struct.unpack("H", block_length_bytes)[0] |
81 | | - block_size_with_length = ( |
82 | | - self.LENGTH_LENGTH + block_size + HAP_CRYPTO.TAG_LENGTH |
83 | | - ) |
| 89 | + block_size_with_length = length_length + block_size + tag_length |
84 | 90 |
|
85 | | - if len(self._crypt_in_buffer) < block_size_with_length: |
| 91 | + if len(crypt_in_buffer) < block_size_with_length: |
86 | 92 | logger.debug("Incoming buffer does not have the full block") |
87 | 93 | return result |
88 | 94 |
|
89 | 95 | # Trim off the length |
90 | | - del self._crypt_in_buffer[: self.LENGTH_LENGTH] |
| 96 | + del crypt_in_buffer[:length_length] |
91 | 97 |
|
92 | | - data_size = block_size + HAP_CRYPTO.TAG_LENGTH |
93 | | - nonce = pad_tls_nonce(struct.pack("Q", self._in_count)) |
| 98 | + data_size = block_size + tag_length |
| 99 | + nonce = PACK_NONCE(self._in_count) |
94 | 100 |
|
95 | 101 | result += self._in_cipher.decrypt( |
96 | 102 | nonce, |
97 | | - bytes(self._crypt_in_buffer[:data_size]), |
| 103 | + bytes(crypt_in_buffer[:data_size]), |
98 | 104 | bytes(block_length_bytes), |
99 | 105 | ) |
100 | 106 |
|
101 | 107 | self._in_count += 1 |
102 | 108 |
|
103 | 109 | # Now trim out the decrypted data |
104 | | - del self._crypt_in_buffer[:data_size] |
| 110 | + del crypt_in_buffer[:data_size] |
105 | 111 |
|
106 | 112 | return result |
107 | 113 |
|
108 | 114 | def encrypt(self, data: bytes) -> bytes: |
109 | 115 | """Encrypt and send the return bytes.""" |
110 | | - result = b"" |
| 116 | + result: List[bytes] = [] |
111 | 117 | offset = 0 |
112 | 118 | total = len(data) |
113 | 119 | while offset < total: |
114 | 120 | length = min(total - offset, self.MAX_BLOCK_LENGTH) |
115 | | - length_bytes = struct.pack("H", length) |
| 121 | + length_bytes = PACK_LENGTH(length) |
116 | 122 | block = bytes(data[offset : offset + length]) |
117 | | - nonce = pad_tls_nonce(struct.pack("Q", self._out_count)) |
118 | | - ciphertext = length_bytes + self._out_cipher.encrypt( |
119 | | - nonce, block, length_bytes |
120 | | - ) |
| 123 | + nonce = PACK_NONCE(self._out_count) |
| 124 | + result.append(length_bytes) |
| 125 | + result.append(self._out_cipher.encrypt(nonce, block, length_bytes)) |
121 | 126 | offset += length |
122 | 127 | self._out_count += 1 |
123 | | - result += ciphertext |
124 | | - return result |
| 128 | + |
| 129 | + # Join the result once instead of concatenating each time |
| 130 | + # as this is much faster than generating an new immutable |
| 131 | + # byte string each time. |
| 132 | + return b"".join(result) |
0 commit comments