|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 | import logging |
| 18 | +import platform |
| 19 | +import re |
18 | 20 | import unittest |
19 | 21 |
|
20 | | -from test_env import timestamp, TPP_TOKEN_URL, TPP_USER, TPP_PASSWORD, SSH_CADN |
| 22 | +from assets import SSH_CERT_DATA, SSH_PRIVATE_KEY, SSH_PUBLIC_KEY |
| 23 | +from test_env import timestamp, TPP_TOKEN_URL, TPP_USER, TPP_PASSWORD, TPP_SSH_CADN |
21 | 24 | from vcert import CommonConnection, SSHCertRequest, TPPTokenConnection, Authentication, \ |
22 | | - SCOPE_SSH, generate_ssh_keypair |
23 | | -from vcert.ssh_utils import SSHRetrieveResponse |
| 25 | + SCOPE_SSH, write_ssh_files |
| 26 | +from vcert.ssh_utils import SSHRetrieveResponse, SSHKeyPair |
24 | 27 |
|
25 | 28 | logging.basicConfig(level=logging.DEBUG) |
26 | 29 | logger = logging.getLogger('vcert-test') |
|
32 | 35 |
|
33 | 36 | class TestTPPSSHCertificate(unittest.TestCase): |
34 | 37 | def __init__(self, *args, **kwargs): |
35 | | - self.tpp_conn = TPPTokenConnection(url=TPP_TOKEN_URL, http_request_kwargs={"verify": "/tmp/chain.pem"}) |
| 38 | + self.tpp_conn = TPPTokenConnection(url=TPP_TOKEN_URL, http_request_kwargs={"verify": False}) |
36 | 39 | auth = Authentication(user=TPP_USER, password=TPP_PASSWORD, scope=SCOPE_SSH) |
37 | 40 | self.tpp_conn.get_access_token(auth) |
38 | 41 | super(TestTPPSSHCertificate, self).__init__(*args, **kwargs) |
39 | 42 |
|
40 | 43 | def test_enroll_local_generated_keypair(self): |
41 | | - keypair = generate_ssh_keypair(key_size=4096, passphrase="foobar") |
| 44 | + keypair = SSHKeyPair() |
| 45 | + keypair.generate(key_size=4096, passphrase="foobar") |
42 | 46 |
|
43 | | - request = SSHCertRequest(cadn=SSH_CADN, key_id=_random_key_id()) |
| 47 | + request = SSHCertRequest(cadn=TPP_SSH_CADN, key_id=_random_key_id()) |
44 | 48 | request.validity_period = "4h" |
45 | 49 | request.source_addresses = ["test.com"] |
46 | | - request.set_public_key_data(keypair.public_key) |
| 50 | + request.set_public_key_data(keypair.public_key()) |
47 | 51 | response = _enroll_ssh_cert(self.tpp_conn, request) |
48 | 52 | self.assertTrue(response.private_key_data is None, |
49 | 53 | SERVICE_GENERATED_NO_KEY_ERROR % ("Private", "not", request.key_id)) |
50 | 54 | self.assertTrue(response.public_key_data, SERVICE_GENERATED_NO_KEY_ERROR % ("Public", "", request.key_id)) |
51 | 55 | self.assertTrue(response.public_key_data == request.get_public_key_data(), |
52 | 56 | "Public key on response does not match request.\nExpected: %s\nGot: %s" |
53 | 57 | % (request.get_public_key_data(), response.public_key_data)) |
54 | | - self.assertTrue(response.cert_data, SSH_CERT_DATA_ERROR % request.key_id) |
| 58 | + self.assertTrue(response.certificate_data, SSH_CERT_DATA_ERROR % request.key_id) |
55 | 59 |
|
56 | 60 | def test_enroll_service_generated_keypair(self): |
57 | | - request = SSHCertRequest(cadn=SSH_CADN, key_id=_random_key_id()) |
| 61 | + request = SSHCertRequest(cadn=TPP_SSH_CADN, key_id=_random_key_id()) |
58 | 62 | request.validity_period = "4h" |
59 | 63 | request.source_addresses = ["test.com"] |
60 | 64 | response = _enroll_ssh_cert(self.tpp_conn, request) |
61 | 65 | self.assertTrue(response.private_key_data, SERVICE_GENERATED_NO_KEY_ERROR % ("Private", "", request.key_id)) |
62 | 66 | self.assertTrue(response.public_key_data, SERVICE_GENERATED_NO_KEY_ERROR % ("Public", "", request.key_id)) |
63 | | - self.assertTrue(response.cert_data, SSH_CERT_DATA_ERROR % request.key_id) |
| 67 | + self.assertTrue(response.certificate_data, SSH_CERT_DATA_ERROR % request.key_id) |
| 68 | + |
| 69 | + |
| 70 | +class TestSSHUtils(unittest.TestCase): |
| 71 | + |
| 72 | + def test_write_ssh_files(self): |
| 73 | + key_id = _random_key_id() |
| 74 | + normalized_name = re.sub(r"[^A-Za-z0-9]+", "_", key_id) |
| 75 | + full_path = "./" + normalized_name |
| 76 | + write_ssh_files("./", key_id, SSH_CERT_DATA, SSH_PRIVATE_KEY, SSH_PUBLIC_KEY) |
| 77 | + |
| 78 | + err_msg = "%s serialization does not match expected value" |
| 79 | + |
| 80 | + with open(full_path + "-cert.pub", "r") as cert_file: |
| 81 | + s_cert = cert_file.read() |
| 82 | + self.assertTrue(SSH_CERT_DATA == s_cert, err_msg % "SSH Certificate") |
| 83 | + |
| 84 | + with open(full_path, "r") as priv_key_file: |
| 85 | + s_priv_key = priv_key_file.read() |
| 86 | + expected_priv_key = SSH_PRIVATE_KEY |
| 87 | + if platform.system() is not "Windows": |
| 88 | + expected_priv_key = expected_priv_key.replace("\r\n", "\n") |
| 89 | + |
| 90 | + self.assertTrue(expected_priv_key == s_priv_key, err_msg % "SSH Private Key") |
| 91 | + |
| 92 | + with open(full_path + ".pub", "r") as pub_key_file: |
| 93 | + s_pub_key = pub_key_file.read() |
| 94 | + self.assertTrue(SSH_PUBLIC_KEY == s_pub_key, err_msg % "SSH Public Key") |
64 | 95 |
|
65 | 96 |
|
66 | 97 | def _enroll_ssh_cert(connector, request): |
|
0 commit comments