SSH utility functions for key manipulation
[ganeti-github.git] / test / py / ganeti.ssh_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007, 2008 Google Inc.
5 # All rights reserved.
6 #
7 # Redistribution and use in source and binary forms, with or without
8 # modification, are permitted provided that the following conditions are
9 # met:
10 #
11 # 1. Redistributions of source code must retain the above copyright notice,
12 # this list of conditions and the following disclaimer.
13 #
14 # 2. Redistributions in binary form must reproduce the above copyright
15 # notice, this list of conditions and the following disclaimer in the
16 # documentation and/or other materials provided with the distribution.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
19 # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
20 # TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
21 # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22 # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23 # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24 # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25 # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30
31 """Script for unittesting the ssh module"""
32
33 import os
34 import tempfile
35 import unittest
36 import shutil
37
38 import testutils
39 import mocks
40
41 from ganeti import constants
42 from ganeti import utils
43 from ganeti import ssh
44 from ganeti import errors
45
46
class TestKnownHosts(testutils.GanetiTestCase):
  """Checks the known_hosts file produced by ssh.WriteKnownHostsFile."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpfile = self._CreateTempFile()

  def test(self):
    config = mocks.FakeConfig()
    ssh.WriteKnownHostsFile(config, self.tmpfile)

    # One entry per host key algorithm (RSA first, then DSA), each naming the
    # cluster and carrying the fake cluster key.
    cluster_name = config.GetClusterName()
    expected = "".join(["%s %s %s\n" % (cluster_name, algorithm,
                                        mocks.FAKE_CLUSTER_KEY)
                        for algorithm in ("ssh-rsa", "ssh-dss")])
    self.assertFileContent(self.tmpfile, expected)
61
62
63 class TestGetUserFiles(unittest.TestCase):
64 def setUp(self):
65 self.tmpdir = tempfile.mkdtemp()
66
67 def tearDown(self):
68 shutil.rmtree(self.tmpdir)
69
70 @staticmethod
71 def _GetNoHomedir(_):
72 return None
73
74 def _GetTempHomedir(self, _):
75 return self.tmpdir
76
77 def testNonExistantUser(self):
78 for kind in constants.SSHK_ALL:
79 self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
80 kind=kind, _homedir_fn=self._GetNoHomedir)
81
82 def testUnknownKind(self):
83 kind = "something-else"
84 assert kind not in constants.SSHK_ALL
85 self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
86 kind=kind, _homedir_fn=self._GetTempHomedir)
87
88 self.assertEqual(os.listdir(self.tmpdir), [])
89
90 def testNoSshDirectory(self):
91 for kind in constants.SSHK_ALL:
92 self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
93 kind=kind, _homedir_fn=self._GetTempHomedir)
94 self.assertEqual(os.listdir(self.tmpdir), [])
95
96 def testSshIsFile(self):
97 utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
98 for kind in constants.SSHK_ALL:
99 self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
100 kind=kind, _homedir_fn=self._GetTempHomedir)
101 self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
102
103 def testMakeSshDirectory(self):
104 sshdir = os.path.join(self.tmpdir, ".ssh")
105
106 self.assertEqual(os.listdir(self.tmpdir), [])
107
108 for kind in constants.SSHK_ALL:
109 ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
110 _homedir_fn=self._GetTempHomedir)
111 self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
112 self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
113
114 def testFilenames(self):
115 sshdir = os.path.join(self.tmpdir, ".ssh")
116
117 os.mkdir(sshdir)
118
119 for kind in constants.SSHK_ALL:
120 result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
121 _homedir_fn=self._GetTempHomedir)
122 self.assertEqual(result, [
123 os.path.join(self.tmpdir, ".ssh", "id_%s" % kind),
124 os.path.join(self.tmpdir, ".ssh", "id_%s.pub" % kind),
125 os.path.join(self.tmpdir, ".ssh", "authorized_keys"),
126 ])
127
128 self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
129 self.assertEqual(os.listdir(sshdir), [])
130
131 def testNoDirCheck(self):
132 self.assertEqual(os.listdir(self.tmpdir), [])
133
134 for kind in constants.SSHK_ALL:
135 ssh.GetUserFiles("example14528", mkdir=False, dircheck=False, kind=kind,
136 _homedir_fn=self._GetTempHomedir)
137 self.assertEqual(os.listdir(self.tmpdir), [])
138
139 def testGetAllUserFiles(self):
140 result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
141 _homedir_fn=self._GetTempHomedir)
142 self.assertEqual(result,
143 (os.path.join(self.tmpdir, ".ssh", "authorized_keys"), {
144 constants.SSHK_RSA:
145 (os.path.join(self.tmpdir, ".ssh", "id_rsa"),
146 os.path.join(self.tmpdir, ".ssh", "id_rsa.pub")),
147 constants.SSHK_DSA:
148 (os.path.join(self.tmpdir, ".ssh", "id_dsa"),
149 os.path.join(self.tmpdir, ".ssh", "id_dsa.pub")),
150 constants.SSHK_ECDSA:
151 (os.path.join(self.tmpdir, ".ssh", "id_ecdsa"),
152 os.path.join(self.tmpdir, ".ssh", "id_ecdsa.pub")),
153 }))
154 self.assertEqual(os.listdir(self.tmpdir), [])
155
156 def testGetAllUserFilesNoDirectoryNoMkdir(self):
157 self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
158 "example17270", mkdir=False, dircheck=True,
159 _homedir_fn=self._GetTempHomedir)
160 self.assertEqual(os.listdir(self.tmpdir), [])
161
162
class TestSshKeys(testutils.GanetiTestCase):
  """Tests for the authorized_keys helpers of the ssh module.

  Covers HasAuthorizedKey, AddAuthorizedKey, AddAuthorizedKeys and
  RemoveAuthorizedKey. Every test starts from a file holding L{KEY_A} and
  L{KEY_B}, one per line.

  """
  KEY_A = "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a"
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
           "ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b")

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    with open(self.tmpname, "w") as handle:
      handle.write("%s\n" % TestSshKeys.KEY_A)
      handle.write("%s\n" % TestSshKeys.KEY_B)

  def testHasAuthorizedKey(self):
    self.assertTrue(ssh.HasAuthorizedKey(self.tmpname, self.KEY_A))
    self.assertFalse(ssh.HasAuthorizedKey(
      self.tmpname, "I am the key of the pink bunny!"))

  def testAddingNewKey(self):
    ssh.AddAuthorizedKey(self.tmpname,
                         "ssh-dss AAAAB3NzaC1kc3MAAACB root@test")

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")

  def testAddingDuplicateKeys(self):
    # A key that is already present, or repeated within one call, must end up
    # in the file only once
    ssh.AddAuthorizedKey(self.tmpname,
                         "ssh-dss AAAAB3NzaC1kc3MAAACB root@test")
    ssh.AddAuthorizedKeys(self.tmpname,
                          ["ssh-dss AAAAB3NzaC1kc3MAAACB root@test",
                           "ssh-dss AAAAB3NzaC1kc3MAAACB root@test"])

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")

  def testAddingSeveralKeysAtOnce(self):
    ssh.AddAuthorizedKeys(self.tmpname, ["aaa", "bbb", "ccc"])
    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "aaa\nbbb\nccc\n")
    # "bbb" is already present, so only "ddd" and "eee" are appended
    ssh.AddAuthorizedKeys(self.tmpname, ["bbb", "ddd", "eee"])
    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "aaa\nbbb\nccc\nddd\neee\n")

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    ssh.AddAuthorizedKey(self.tmpname,
                         "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test")

    # Only significant fields are compared, therefore the key won't be
    # updated/added
    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    # KEY_A with extra whitespace between its fields must be recognized as
    # already present and leave the file unchanged
    ssh.AddAuthorizedKey(self.tmpname,
                         "ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a")
    # KEY_B's key without its options is not considered a duplicate of the
    # KEY_B line and gets appended
    ssh.AddAuthorizedKey(self.tmpname,
                         "ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22")

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22\n")

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    # Extra whitespace between the fields must not prevent removal of KEY_A
    ssh.RemoveAuthorizedKey(self.tmpname,
                            "ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a")

    self.assertFileContent(self.tmpname,
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def testRemovingNonExistingKey(self):
    ssh.RemoveAuthorizedKey(self.tmpname,
                            "ssh-dss AAAAB3Nsdfj230xxjxJjsjwjsjdjU root@test")

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def testAddingNewKeys(self):
    ssh.AddAuthorizedKeys(self.tmpname,
                          ["ssh-dss AAAAB3NzaC1kc3MAAACB root@test"])
    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")

    ssh.AddAuthorizedKeys(self.tmpname,
                          ["ssh-dss AAAAB3asdfasdfaYTUCB laracroft@test",
                           "ssh-dss AasdfliuobaosfMAAACB frodo@test"])
    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n"
      "ssh-dss AAAAB3asdfasdfaYTUCB laracroft@test\n"
      "ssh-dss AasdfliuobaosfMAAACB frodo@test\n")

  def testOtherKeyTypes(self):
    key_rsa = "ssh-rsa AAAAimnottypingallofthathere0jfJs22 test@test"
    key_ed25519 = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOlcZ6cpQTGow0LZECRHWn9"\
        "7Yvn16J5un501T/RcbfuF fast@secure"
    key_ecdsa = "ecdsa-sha2-nistp256 AAAAE2VjZHNtoolongk/TNhVbEg= secure@secure"

    def _ToFileContent(keys):
      """Renders keys as authorized_keys content, one key per line."""
      return '\n'.join(keys) + '\n'

    ssh.AddAuthorizedKeys(self.tmpname, [key_rsa, key_ed25519, key_ecdsa])
    self.assertFileContent(self.tmpname,
                           _ToFileContent([self.KEY_A, self.KEY_B, key_rsa,
                                           key_ed25519, key_ecdsa]))

    ssh.RemoveAuthorizedKey(self.tmpname, key_ed25519)
    self.assertFileContent(self.tmpname,
                           _ToFileContent([self.KEY_A, self.KEY_B, key_rsa,
                                           key_ecdsa]))

    ssh.RemoveAuthorizedKey(self.tmpname, key_rsa)
    ssh.RemoveAuthorizedKey(self.tmpname, key_ecdsa)
    self.assertFileContent(self.tmpname,
                           _ToFileContent([self.KEY_A, self.KEY_B]))
305
306
class TestPublicSshKeys(testutils.GanetiTestCase):
  """Test case for the handling of the list of public ssh keys.

  The key file holds one C{"<uuid-or-name> <key>"} entry per line.

  """
  KEY_A = "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a"
  KEY_B = "ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b"
  UUID_1 = "123-456"
  UUID_2 = "789-ABC"

  def testAddingAndRemovingPubKey(self):
    pub_key_file = self._CreateTempFile()
    ssh.AddPublicKey(self.UUID_1, self.KEY_A, key_file=pub_key_file)
    ssh.AddPublicKey(self.UUID_2, self.KEY_B, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      "789-ABC ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n")

    ssh.RemovePublicKey(self.UUID_2, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n")

  def testAddingExistingPubKey(self):
    expected_file_content = \
      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n" + \
      "789-ABC ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n"
    pub_key_file = self._CreateTempFile()
    ssh.AddPublicKey(self.UUID_1, self.KEY_A, key_file=pub_key_file)
    ssh.AddPublicKey(self.UUID_2, self.KEY_B, key_file=pub_key_file)
    self.assertFileContent(pub_key_file, expected_file_content)

    # Re-adding an identical (uuid, key) pair is a no-op
    ssh.AddPublicKey(self.UUID_1, self.KEY_A, key_file=pub_key_file)
    self.assertFileContent(pub_key_file, expected_file_content)

    # The same key under a different uuid is a new entry
    ssh.AddPublicKey(self.UUID_1, self.KEY_B, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      "789-ABC ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n"
      "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n")

  def testRemoveNonexistingKey(self):
    pub_key_file = self._CreateTempFile()
    ssh.AddPublicKey(self.UUID_1, self.KEY_B, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n")

    ssh.RemovePublicKey(self.UUID_2, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n")

  def testRemoveAllExistingKeys(self):
    pub_key_file = self._CreateTempFile()
    ssh.AddPublicKey(self.UUID_1, self.KEY_A, key_file=pub_key_file)
    ssh.AddPublicKey(self.UUID_1, self.KEY_B, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n")

    # Removing a uuid drops every key stored for it
    ssh.RemovePublicKey(self.UUID_1, key_file=pub_key_file)
    self.assertFileContent(pub_key_file, "")

  def testRemoveKeyFromEmptyFile(self):
    pub_key_file = self._CreateTempFile()
    ssh.RemovePublicKey(self.UUID_2, key_file=pub_key_file)
    self.assertFileContent(pub_key_file, "")

  def testRetrieveKeys(self):
    pub_key_file = self._CreateTempFile()
    ssh.AddPublicKey(self.UUID_1, self.KEY_A, key_file=pub_key_file)
    ssh.AddPublicKey(self.UUID_2, self.KEY_B, key_file=pub_key_file)
    result = ssh.QueryPubKeyFile(self.UUID_1, key_file=pub_key_file)
    self.assertEqual([self.KEY_A], result[self.UUID_1])

    # Unknown uuids are silently left out of the result
    target_uuids = [self.UUID_1, self.UUID_2, "non-existing-UUID"]
    result = ssh.QueryPubKeyFile(target_uuids, key_file=pub_key_file)
    self.assertEqual([self.KEY_A], result[self.UUID_1])
    self.assertEqual([self.KEY_B], result[self.UUID_2])
    self.assertEqual(2, len(result))

    # Query all keys
    target_uuids = None
    result = ssh.QueryPubKeyFile(target_uuids, key_file=pub_key_file)
    self.assertEqual([self.KEY_A], result[self.UUID_1])
    self.assertEqual([self.KEY_B], result[self.UUID_2])
    self.assertEqual(2, len(result))

  def testReplaceNameByUuid(self):
    pub_key_file = self._CreateTempFile()
    name = "my.precious.node"
    ssh.AddPublicKey(name, self.KEY_A, key_file=pub_key_file)
    ssh.AddPublicKey(self.UUID_2, self.KEY_A, key_file=pub_key_file)
    ssh.AddPublicKey(name, self.KEY_B, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "my.precious.node ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      "789-ABC ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      "my.precious.node ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n")

    ssh.ReplaceNameByUuid(self.UUID_1, name, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      "789-ABC ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n")

  def testParseEmptyLines(self):
    pub_key_file = self._CreateTempFile()
    ssh.AddPublicKey(self.UUID_1, self.KEY_A, key_file=pub_key_file)

    # Add an empty line
    with open(pub_key_file, 'a') as fd:
      fd.write("\n")

    ssh.AddPublicKey(self.UUID_2, self.KEY_B, key_file=pub_key_file)

    # Add a whitespace line
    with open(pub_key_file, 'a') as fd:
      fd.write(" \n")

    result = ssh.QueryPubKeyFile(self.UUID_1, key_file=pub_key_file)
    self.assertEqual([self.KEY_A], result[self.UUID_1])

    # The entry written after the empty line must be parsed as well
    result = ssh.QueryPubKeyFile(self.UUID_2, key_file=pub_key_file)
    self.assertEqual([self.KEY_B], result[self.UUID_2])

  def testClearPubKeyFile(self):
    pub_key_file = self._CreateTempFile()
    ssh.AddPublicKey(self.UUID_2, self.KEY_A, key_file=pub_key_file)
    ssh.ClearPubKeyFile(key_file=pub_key_file)
    self.assertFileContent(pub_key_file, "")

  def testOverridePubKeyFile(self):
    pub_key_file = self._CreateTempFile()
    key_map = {self.UUID_1: [self.KEY_A, self.KEY_B],
               self.UUID_2: [self.KEY_A]}
    ssh.OverridePubKeyFile(key_map, key_file=pub_key_file)
    self.assertFileContent(pub_key_file,
      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n"
      "789-ABC ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n")
444
445
class TestInitSSHSetup(testutils.GanetiTestCase):
  """Tests for ssh.InitSSHSetup.

  Named distinctly from L{TestGetUserFiles}: reusing that name would rebind
  the module attribute and hide the other class's tests from the test loader.

  """
  _PRIV_KEY = "my private key"
  _PUB_KEY = "my public key"
  _AUTH_KEYS = "a\nb\nc"

  def _setUpFakeKeys(self):
    """Writes placeholder DSA key files and an authorized_keys file."""
    ssh_tmpdir = os.path.join(self.tmpdir, ".ssh")
    os.makedirs(ssh_tmpdir)

    self.priv_filename = os.path.join(ssh_tmpdir, "id_dsa")
    utils.WriteFile(self.priv_filename, data=self._PRIV_KEY)

    self.pub_filename = os.path.join(ssh_tmpdir, "id_dsa.pub")
    utils.WriteFile(self.pub_filename, data=self._PUB_KEY)

    self.auth_filename = os.path.join(ssh_tmpdir, "authorized_keys")
    utils.WriteFile(self.auth_filename, data=self._AUTH_KEYS)

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()
    self._setUpFakeKeys()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)
    testutils.GanetiTestCase.tearDown(self)

  def _GetTempHomedir(self, _):
    """Home directory lookup which always yields the temporary directory."""
    return self.tmpdir

  def testNewKeysOverrideOldKeys(self):
    # Without a suffix the placeholder key files get replaced
    ssh.InitSSHSetup("dsa", 1024, _homedir_fn=self._GetTempHomedir)
    self.assertFileContentNotEqual(self.priv_filename, self._PRIV_KEY)
    self.assertFileContentNotEqual(self.pub_filename, self._PUB_KEY)

  def testSuffix(self):
    # With a suffix, new key files are written next to the existing ones,
    # which keep their content
    suffix = "_pinkbunny"
    ssh.InitSSHSetup("dsa", 1024, _homedir_fn=self._GetTempHomedir,
                     _suffix=suffix)
    self.assertFileContent(self.priv_filename, self._PRIV_KEY)
    self.assertFileContent(self.pub_filename, self._PUB_KEY)
    self.assertTrue(os.path.exists(self.priv_filename + suffix))
    self.assertTrue(os.path.exists(self.priv_filename + suffix + ".pub"))
489
490
class TestDetermineKeyBits(testutils.GanetiTestCase):
  """Tests for ssh.DetermineKeyBits and ssh.SSH_KEY_VALID_BITS."""

  def testCompleteness(self):
    # Every supported key type must have a size validator
    self.assertEqual(constants.SSHK_ALL,
                     frozenset(ssh.SSH_KEY_VALID_BITS.keys()))

  def testAdoptDefault(self):
    # Without a requested or previous size, a per-type default is used
    self.assertEqual(2048, ssh.DetermineKeyBits("rsa", None, None, None))
    self.assertEqual(1024, ssh.DetermineKeyBits("dsa", None, None, None))

  def testAdoptOldKeySize(self):
    # The old size is only reused when the key type does not change
    self.assertEqual(4098, ssh.DetermineKeyBits("rsa", None, "rsa", 4098))
    self.assertEqual(2048, ssh.DetermineKeyBits("rsa", None, "dsa", 1024))

  def testDsaSpecificValues(self):
    self.assertRaises(errors.OpPrereqError, ssh.DetermineKeyBits, "dsa", 2048,
                      None, None)
    self.assertRaises(errors.OpPrereqError, ssh.DetermineKeyBits, "dsa", 512,
                      None, None)
    self.assertEqual(1024, ssh.DetermineKeyBits("dsa", None, None, None))

  def testEcdsaSpecificValues(self):
    self.assertRaises(errors.OpPrereqError, ssh.DetermineKeyBits, "ecdsa", 2048,
                      None, None)
    for bits in [256, 384, 521]:
      self.assertEqual(bits, ssh.DetermineKeyBits("ecdsa", bits, None, None))

  def testRsaSpecificValues(self):
    # The key type must be "rsa" here: with "dsa" the call would fail because
    # of the DSA size rules and the RSA lower bound would go untested
    self.assertRaises(errors.OpPrereqError, ssh.DetermineKeyBits, "rsa", 766,
                      None, None)
    for bits in [768, 769, 2048, 2049, 4096]:
      self.assertEqual(bits, ssh.DetermineKeyBits("rsa", bits, None, None))
522
523
class TestManageLocalSshPubKeys(testutils.GanetiTestCase):
  """Test class for several methods handling local SSH keys.

  Methods covered are:
    - GetSshKeyFilenames
    - GetSshPubKeyFilename
    - ReplaceSshKeys
    - ReadLocalSshPubKeys

  These methods are covered in one test class, because the preparations for
  their tests are identical and thus can be reused.

  """
  VISIBILITY_PRIVATE = "private"
  VISIBILITY_PUBLIC = "public"
  VISIBILITIES = frozenset([VISIBILITY_PRIVATE, VISIBILITY_PUBLIC])

  def _GenerateKey(self, key_id, visibility):
    """Returns the fake key content for a key id and visibility."""
    assert visibility in self.VISIBILITIES
    return "I am the %s %s SSH key." % (visibility, key_id)

  def _GetKeyPath(self, key_file_basename):
    """Returns the path of a key file inside the temporary directory."""
    return os.path.join(self.tmpdir, key_file_basename)

  def _SetUpKeys(self):
    """Creates a fake SSH key for each type and with/without suffix.

    Only the unsuffixed key pairs are recorded in C{self._key_file_dict}
    (key type -> (private key path, public key path)); the tests hand this
    dictionary out through a mocked C{ssh.GetAllUserFiles}.

    """
    self._key_file_dict = {}
    for key_type in constants.SSHK_ALL:
      for suffix in ["", self._suffix]:
        pub_key_filename = "id_%s%s.pub" % (key_type, suffix)
        priv_key_filename = "id_%s%s" % (key_type, suffix)

        pub_key_path = self._GetKeyPath(pub_key_filename)
        priv_key_path = self._GetKeyPath(priv_key_filename)

        utils.WriteFile(
          priv_key_path,
          data=self._GenerateKey(key_type + suffix, self.VISIBILITY_PRIVATE))

        utils.WriteFile(
          pub_key_path,
          data=self._GenerateKey(key_type + suffix, self.VISIBILITY_PUBLIC))

        # Fill key dict only for non-suffix keys
        # (as this is how it will be in the code)
        if not suffix:
          self._key_file_dict[key_type] = \
            (priv_key_path, pub_key_path)

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()
    self._suffix = "_suffix"
    self._SetUpKeys()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)
    testutils.GanetiTestCase.tearDown(self)

  @testutils.patch_object(ssh, "GetAllUserFiles")
  def testReadAllPublicKeyFiles(self, mock_getalluserfiles):
    mock_getalluserfiles.return_value = (None, self._key_file_dict)

    # An empty list of key types means "all key types"
    keys = ssh.ReadLocalSshPubKeys([], suffix="")

    self.assertEqual(len(constants.SSHK_ALL), len(keys))
    for key_type in constants.SSHK_ALL:
      self.assertTrue(
        self._GenerateKey(key_type, self.VISIBILITY_PUBLIC) in keys)

  @testutils.patch_object(ssh, "GetAllUserFiles")
  def testReadOnePublicKeyFile(self, mock_getalluserfiles):
    mock_getalluserfiles.return_value = (None, self._key_file_dict)

    keys = ssh.ReadLocalSshPubKeys([constants.SSHK_DSA], suffix="")

    self.assertEqual(1, len(keys))
    self.assertEqual(
      self._GenerateKey(constants.SSHK_DSA, self.VISIBILITY_PUBLIC),
      keys[0])

  @testutils.patch_object(ssh, "GetAllUserFiles")
  def testReadPublicKeyFilesWithSuffix(self, mock_getalluserfiles):
    key_types = [constants.SSHK_DSA, constants.SSHK_ECDSA]

    mock_getalluserfiles.return_value = (None, self._key_file_dict)

    keys = ssh.ReadLocalSshPubKeys(key_types, suffix=self._suffix)

    self.assertEqual(2, len(keys))
    for key_id in [key_type + self._suffix for key_type in key_types]:
      self.assertTrue(
        self._GenerateKey(key_id, self.VISIBILITY_PUBLIC) in keys)

  @testutils.patch_object(ssh, "GetAllUserFiles")
  def testGetSshKeyFilenames(self, mock_getalluserfiles):
    mock_getalluserfiles.return_value = (None, self._key_file_dict)

    priv, pub = ssh.GetSshKeyFilenames(constants.SSHK_DSA)

    # Full paths are returned, not bare file names
    self.assertEqual("id_dsa", os.path.basename(priv))
    self.assertNotEqual("id_dsa", priv)
    self.assertEqual("id_dsa.pub", os.path.basename(pub))
    self.assertNotEqual("id_dsa.pub", pub)

  @testutils.patch_object(ssh, "GetAllUserFiles")
  def testGetSshKeyFilenamesWithSuffix(self, mock_getalluserfiles):
    mock_getalluserfiles.return_value = (None, self._key_file_dict)

    priv, pub = ssh.GetSshKeyFilenames(constants.SSHK_RSA, suffix=self._suffix)

    self.assertEqual("id_rsa_suffix", os.path.basename(priv))
    self.assertNotEqual("id_rsa_suffix", priv)
    self.assertEqual("id_rsa_suffix.pub", os.path.basename(pub))
    self.assertNotEqual("id_rsa_suffix.pub", pub)

  @testutils.patch_object(ssh, "GetAllUserFiles")
  def testGetPubSshKeyFilename(self, mock_getalluserfiles):
    mock_getalluserfiles.return_value = (None, self._key_file_dict)

    pub = ssh.GetSshPubKeyFilename(constants.SSHK_DSA)
    pub_suffix = ssh.GetSshPubKeyFilename(
      constants.SSHK_DSA, suffix=self._suffix)

    self.assertEqual("id_dsa.pub", os.path.basename(pub))
    self.assertNotEqual("id_dsa.pub", pub)
    self.assertEqual("id_dsa_suffix.pub", os.path.basename(pub_suffix))
    self.assertNotEqual("id_dsa_suffix.pub", pub_suffix)

  @testutils.patch_object(ssh, "GetAllUserFiles")
  def testReplaceSshKeys(self, mock_getalluserfiles):
    """Replace SSH keys without suffixes.

    Note: usually it does not really make sense to replace the DSA key
    by the RSA key. This is just to test the function without suffixes.

    """
    mock_getalluserfiles.return_value = (None, self._key_file_dict)

    ssh.ReplaceSshKeys(constants.SSHK_RSA, constants.SSHK_DSA)

    priv_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][0])
    pub_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][1])

    self.assertEqual("I am the private rsa SSH key.", priv_key)
    self.assertEqual("I am the public rsa SSH key.", pub_key)

  @testutils.patch_object(ssh, "GetAllUserFiles")
  def testReplaceSshKeysBySuffixedKeys(self, mock_getalluserfiles):
    """Replace SSH keys with keys from suffixed files.

    The unsuffixed DSA key pair is overwritten by the DSA key pair stored
    in the files carrying the test suffix.

    """
    mock_getalluserfiles.return_value = (None, self._key_file_dict)

    ssh.ReplaceSshKeys(constants.SSHK_DSA, constants.SSHK_DSA,
                       src_key_suffix=self._suffix)

    priv_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][0])
    pub_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][1])

    self.assertEqual("I am the private dsa_suffix SSH key.", priv_key)
    self.assertEqual("I am the public dsa_suffix SSH key.", pub_key)
688
689
# When executed directly, hand control to testutils.GanetiTestProgram to run
# the test cases defined in this module.
if __name__ == "__main__":
  testutils.GanetiTestProgram()