Renew SSH keys and upgrade
[ganeti-github.git] / test / py / ganeti.backend_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2013 Google Inc.
5 # All rights reserved.
6 #
7 # Redistribution and use in source and binary forms, with or without
8 # modification, are permitted provided that the following conditions are
9 # met:
10 #
11 # 1. Redistributions of source code must retain the above copyright notice,
12 # this list of conditions and the following disclaimer.
13 #
14 # 2. Redistributions in binary form must reproduce the above copyright
15 # notice, this list of conditions and the following disclaimer in the
16 # documentation and/or other materials provided with the distribution.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
19 # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
20 # TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
21 # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22 # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23 # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24 # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25 # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30
31 """Script for testing ganeti.backend"""
32
33 import copy
34 import mock
35 import os
36 import shutil
37 import tempfile
38 import testutils
39 import unittest
40
41 from ganeti import backend
42 from ganeti import constants
43 from ganeti import errors
44 from ganeti import hypervisor
45 from ganeti import netutils
46 from ganeti import objects
47 from ganeti import pathutils
48 from ganeti import ssh
49 from ganeti import utils
50
51
class TestX509Certificates(unittest.TestCase):
  """Tests for creating and removing X509 certificates in a crypto dir."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def test(self):
    """Create two certificates, check the files of one, then remove both."""
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
                                                 backend._X509_CERT_FILE)),
                     cert_pem)
    # assertTrue instead of the deprecated assert_ alias
    self.assertTrue(0 < os.path.getsize(os.path.join(self.tmpdir, name,
                                                     backend._X509_KEY_FILE)))

    (name2, _) = \
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)

    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])

  def testNonEmpty(self):
    """Removing a certificate dir that contains extra files must fail."""
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
                    data="Hello World")

    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
                      name, cryptodir=self.tmpdir)

    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
86
87
class TestGetCryptoTokens(testutils.GanetiTestCase):
  """Tests for backend.GetCryptoTokens with the SSL helpers mocked out."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self._get_digest_fn_orig = utils.GetCertificateDigest
    self._create_digest_fn_orig = utils.GenerateNewSslCert
    self._ssl_digest = "12345"
    utils.GetCertificateDigest = mock.Mock(
      return_value=self._ssl_digest)
    utils.GenerateNewSslCert = mock.Mock()

  def tearDown(self):
    utils.GetCertificateDigest = self._get_digest_fn_orig
    utils.GenerateNewSslCert = self._create_digest_fn_orig
    testutils.GanetiTestCase.tearDown(self)

  def _AssertSslCertGeneratedOnce(self):
    """Check that the mocked certificate generator was called exactly once.

    C{Mock.assert_calls()} is not a mock assertion method. Older mock
    versions silently return another (always truthy) mock for it, and newer
    ones raise AttributeError. The call count is therefore checked
    explicitly.

    """
    self.assertEqual(1, utils.GenerateNewSslCert.call_count)

  def testGetSslToken(self):
    result = backend.GetCryptoTokens(
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_GET, None)])
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
                    in result)

  def testCreateSslToken(self):
    result = backend.GetCryptoTokens(
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
        {constants.CRYPTO_OPTION_SERIAL_NO: 42})])
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
                    in result)
    self._AssertSslCertGeneratedOnce()

  def testCreateSslTokenDifferentFilename(self):
    result = backend.GetCryptoTokens(
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
        {constants.CRYPTO_OPTION_CERT_FILE:
          pathutils.NODED_CLIENT_CERT_FILE_TMP,
         constants.CRYPTO_OPTION_SERIAL_NO: 42})])
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
                    in result)
    self._AssertSslCertGeneratedOnce()

  def testCreateSslTokenSerialNo(self):
    result = backend.GetCryptoTokens(
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
        {constants.CRYPTO_OPTION_SERIAL_NO: 42})])
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
                    in result)
    self._AssertSslCertGeneratedOnce()

  def testUnknownTokenType(self):
    self.assertRaises(errors.ProgrammerError,
                      backend.GetCryptoTokens,
                      [("pink_bunny", constants.CRYPTO_ACTION_GET, None)])

  def testUnknownAction(self):
    self.assertRaises(errors.ProgrammerError,
                      backend.GetCryptoTokens,
                      [(constants.CRYPTO_TYPE_SSL_DIGEST, "illuminate", None)])
143
144
class TestNodeVerify(testutils.GanetiTestCase):
  """Tests for backend.VerifyNode and its verification helpers."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self._mock_hv = None

  def _GetHypervisor(self, hv_name):
    """Return a hypervisor whose check methods are replaced by mocks.

    The object is also stored in C{self._mock_hv} so that tests can inspect
    the calls made on it.

    """
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
    self._mock_hv.ValidateParameters = mock.Mock()
    self._mock_hv.Verify = mock.Mock()
    return self._mock_hv

  def testMasterIPLocalhost(self):
    # this is a real functional test, but requires localhost to be reachable
    local_data = (netutils.Hostname.GetSysName(),
                  constants.IP4_ADDRESS_LOCALHOST)
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data},
                                None, {}, {}, {})
    self.assertTrue(constants.NV_MASTERIP in result,
                    "Master IP data not returned")
    self.assertTrue(result[constants.NV_MASTERIP], "Cannot reach localhost")

  def testMasterIPUnreachable(self):
    # Network 192.0.2.0/24 is reserved for test/documentation as per
    # RFC 5737
    bad_data = ("master.example.com", "192.0.2.1")
    # we just test that whatever TcpPing returns, VerifyNode returns too;
    # the real TcpPing is restored afterwards so that the stub cannot leak
    # into other tests running in the same process
    orig_tcp_ping = netutils.TcpPing
    netutils.TcpPing = lambda a, b, source=None: False
    try:
      result = backend.VerifyNode({constants.NV_MASTERIP: bad_data},
                                  None, {}, {}, {})
    finally:
      netutils.TcpPing = orig_tcp_ping
    self.assertTrue(constants.NV_MASTERIP in result,
                    "Master IP data not returned")
    self.assertFalse(result[constants.NV_MASTERIP],
                     "Result from netutils.TcpPing corrupted")

  def testVerifyHvparams(self):
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    test_what = {constants.NV_HVPARAMS:
                   [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
    result = {}
    backend._VerifyHvparams(test_what, True, result,
                            get_hv_fn=self._GetHypervisor)
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)

  def testVerifyHypervisors(self):
    hvname = constants.HT_XEN_PVM
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    all_hvparams = {hvname: hvparams}
    test_what = {constants.NV_HYPERVISOR: [hvname]}
    result = {}
    backend._VerifyHypervisors(
        test_what, True, result, all_hvparams=all_hvparams,
        get_hv_fn=self._GetHypervisor)
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)

  @testutils.patch_object(utils, "VerifyCertificate")
  def testVerifyClientCertificateSuccess(self, verif_cert):
    # mock the underlying x509 verification because the test cert is expired
    verif_cert.return_value = (None, None)
    cert_file = testutils.TestDataFilename("cert2.pem")
    (errcode, digest) = backend._VerifyClientCertificate(cert_file=cert_file)
    self.assertEqual(None, errcode)
    self.assertTrue(isinstance(digest, str))

  @testutils.patch_object(utils, "VerifyCertificate")
  def testVerifyClientCertificateFailed(self, verif_cert):
    expected_errcode = 666
    verif_cert.return_value = (expected_errcode,
                               "The devil created this certificate.")
    cert_file = testutils.TestDataFilename("cert2.pem")
    (errcode, _) = backend._VerifyClientCertificate(cert_file=cert_file)
    self.assertEqual(expected_errcode, errcode)

  def testVerifyClientCertificateNoCert(self):
    cert_file = testutils.TestDataFilename("cert-that-does-not-exist.pem")
    (errcode, _) = backend._VerifyClientCertificate(cert_file=cert_file)
    self.assertEqual(constants.CV_ERROR, errcode)
222
223
224 def _DefRestrictedCmdOwner():
225 return (os.getuid(), os.getgid())
226
227
class TestVerifyRestrictedCmdName(unittest.TestCase):
  """Tests for backend._VerifyRestrictedCmdName."""

  def _AssertRejected(self, names, expected_msg):
    """Check that every name in C{names} is rejected with C{expected_msg}."""
    for name in names:
      (status, msg) = backend._VerifyRestrictedCmdName(name)
      self.assertFalse(status)
      self.assertEqual(msg, expected_msg)

  def testAcceptableName(self):
    transforms = (lambda s: s, lambda s: s.upper(), lambda s: s.title())
    for name in ["foo", "bar", "z1", "000first", "hello-world"]:
      for transform in transforms:
        (status, msg) = backend._VerifyRestrictedCmdName(transform(name))
        self.assertTrue(status)
        self.assertTrue(msg is None)

  def testEmptyAndSpace(self):
    self._AssertRejected(["", " ", "\t", "\n"], "Missing command name")

  def testNameWithSlashes(self):
    self._AssertRejected(["/", "./foo", "../moo", "some/name"],
                         "Invalid command name")

  def testForbiddenCharacters(self):
    self._AssertRejected(["#", ".", "..", "bash -c ls", "'"],
                         "Command name contains forbidden characters")
253
254
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
  """Tests for backend._VerifyRestrictedCmdDirectory."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _SubPath(self):
    """Return the path of the entry under test inside the temp dir."""
    return utils.PathJoin(self.tmpdir, "foobar")

  def testCanNotStat(self):
    path = self._SubPath()
    self.assertFalse(os.path.exists(path))
    (status, msg) = backend._VerifyRestrictedCmdDirectory(
      path, _owner=NotImplemented)
    self.assertFalse(status)
    self.assertTrue(msg.startswith("Can't stat(2) '"))

  def testTooPermissive(self):
    path = self._SubPath()
    os.mkdir(path)

    for mode in (0o777, 0o706, 0o760, 0o722):
      os.chmod(path, mode)
      self.assertTrue(os.path.isdir(path))
      (status, msg) = backend._VerifyRestrictedCmdDirectory(
        path, _owner=NotImplemented)
      self.assertFalse(status)
      self.assertTrue(msg.startswith("Permissions on '"))

  def testNoDirectory(self):
    path = self._SubPath()
    utils.WriteFile(path, data="empty\n")
    self.assertTrue(os.path.isfile(path))
    (status, msg) = backend._VerifyRestrictedCmdDirectory(
      path, _owner=_DefRestrictedCmdOwner())
    self.assertFalse(status)
    self.assertTrue(msg.endswith("is not a directory"))

  def testNormal(self):
    path = self._SubPath()
    os.mkdir(path)
    os.chmod(path, 0o755)
    self.assertTrue(os.path.isdir(path))
    (status, msg) = backend._VerifyRestrictedCmdDirectory(
      path, _owner=_DefRestrictedCmdOwner())
    self.assertTrue(status)
    self.assertTrue(msg is None)
302
303
class TestVerifyRestrictedCmd(unittest.TestCase):
  """Tests for backend._VerifyRestrictedCmd."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _Verify(self, cmd, owner):
    """Run the verification for C{cmd} inside the temporary directory."""
    return backend._VerifyRestrictedCmd(self.tmpdir, cmd, _owner=owner)

  def testCanNotStat(self):
    missing = utils.PathJoin(self.tmpdir, "helloworld")
    self.assertFalse(os.path.exists(missing))
    (status, msg) = self._Verify("helloworld", NotImplemented)
    self.assertFalse(status)
    self.assertTrue(msg.startswith("Can't stat(2) '"))

  def testNotExecutable(self):
    utils.WriteFile(utils.PathJoin(self.tmpdir, "cmdname"), data="empty\n")
    (status, msg) = self._Verify("cmdname", _DefRestrictedCmdOwner())
    self.assertFalse(status)
    self.assertTrue(msg.startswith("access(2) thinks '"))

  def testExecutable(self):
    cmd_path = utils.PathJoin(self.tmpdir, "cmdname")
    utils.WriteFile(cmd_path, data="empty\n", mode=0o700)
    (status, executable) = self._Verify("cmdname", _DefRestrictedCmdOwner())
    self.assertTrue(status)
    self.assertEqual(executable, cmd_path)
337
338
class TestPrepareRestrictedCmd(unittest.TestCase):
  """Tests for backend._PrepareRestrictedCmd with injected verifiers."""

  _TEST_PATH = "/tmp/some/test/path"

  @staticmethod
  def _Accept(_):
    """Verifier stub that always reports success."""
    return (True, None)

  def testDirFails(self):
    def verify_dir(path):
      self.assertEqual(path, self._TEST_PATH)
      return (False, "test error 31420")

    (status, msg) = backend._PrepareRestrictedCmd(
      self._TEST_PATH, "cmd21152",
      _verify_dir=verify_dir,
      _verify_name=NotImplemented,
      _verify_cmd=NotImplemented)
    self.assertFalse(status)
    self.assertEqual(msg, "test error 31420")

  def testNameFails(self):
    def verify_name(cmd):
      self.assertEqual(cmd, "cmd4617")
      return (False, "test error 591")

    (status, msg) = backend._PrepareRestrictedCmd(
      self._TEST_PATH, "cmd4617",
      _verify_dir=self._Accept,
      _verify_name=verify_name,
      _verify_cmd=NotImplemented)
    self.assertFalse(status)
    self.assertEqual(msg, "test error 591")

  def testCommandFails(self):
    def verify_cmd(path, cmd):
      self.assertEqual(path, self._TEST_PATH)
      self.assertEqual(cmd, "cmd17577")
      return (False, "test error 25524")

    (status, msg) = backend._PrepareRestrictedCmd(
      self._TEST_PATH, "cmd17577",
      _verify_dir=self._Accept,
      _verify_name=self._Accept,
      _verify_cmd=verify_cmd)
    self.assertFalse(status)
    self.assertEqual(msg, "test error 25524")

  def testSuccess(self):
    (status, executable) = backend._PrepareRestrictedCmd(
      self._TEST_PATH, "cmd22633",
      _verify_dir=self._Accept,
      _verify_name=self._Accept,
      _verify_cmd=lambda path, cmd: (True, utils.PathJoin(path, cmd)))
    self.assertTrue(status)
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
393
394
395 def _SleepForRestrictedCmd(duration):
396 assert duration > 5
397
398
399 def _GenericRestrictedCmdError(cmd):
400 return "Executing command '%s' failed" % cmd
401
402
class TestRunRestrictedCmd(unittest.TestCase):
  """Tests for backend.RunRestrictedCmd with injected helper functions."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testNonExistantLockDirectory(self):
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    self.assertFalse(os.path.exists(lockfile))
    self.assertRaises(backend.RPCFail,
                      backend.RunRestrictedCmd, "test",
                      _lock_timeout=NotImplemented,
                      _lock_file=lockfile,
                      _path=NotImplemented,
                      _sleep_fn=sleep_fn,
                      _prepare_fn=NotImplemented,
                      _runcmd_fn=NotImplemented,
                      _enabled=True)
    self.assertEqual(sleep_fn.Count(), 1)

  @staticmethod
  def _TryLock(lockfile):
    """Try to run a command while C{lockfile} is held by another process.

    This is executed via C{utils.RunInSeparateProcess}, hence the plain
    asserts instead of TestCase assertion methods.

    @return: True if the call failed with the generic error message, False
      if it did not fail at all

    """
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)

    result = False
    try:
      backend.RunRestrictedCmd("test22717",
                               _lock_timeout=0.1,
                               _lock_file=lockfile,
                               _path=NotImplemented,
                               _sleep_fn=sleep_fn,
                               _prepare_fn=NotImplemented,
                               _runcmd_fn=NotImplemented,
                               _enabled=True)
    except backend.RPCFail as err:
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
        "Did not fail with generic error message"
      result = True

    assert sleep_fn.Count() == 1

    return result

  def testLockHeldByOtherProcess(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    lock = utils.FileLock.Open(lockfile)
    lock.Exclusive(blocking=True, timeout=1.0)
    try:
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
    finally:
      lock.Close()

  @staticmethod
  def _PrepareRaisingException(path, cmd):
    assert cmd == "test23122"
    raise Exception("test")

  def testPrepareRaisesException(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)

    try:
      backend.RunRestrictedCmd("test23122",
                               _lock_timeout=1.0, _lock_file=lockfile,
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
                               _enabled=True)
    except backend.RPCFail as err:
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
    else:
      self.fail("Didn't fail")

    self.assertEqual(sleep_fn.Count(), 1)
    self.assertEqual(prepare_fn.Count(), 1)

  @staticmethod
  def _PrepareFails(path, cmd):
    """Simulate a preparation step that reports failure.

    This uses the same C{(status, value)} convention as
    C{backend._PrepareRestrictedCmd}, where a false status signals failure.
    A non-empty string in the status position would be truthy. The command
    would then be treated as prepared, and the failure branch would not be
    exercised.

    """
    assert cmd == "test29327"
    return (False, "some error message")

  def testPrepareFails(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    prepare_fn = testutils.CallCounter(self._PrepareFails)

    try:
      backend.RunRestrictedCmd("test29327",
                               _lock_timeout=1.0, _lock_file=lockfile,
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
                               _enabled=True)
    except backend.RPCFail as err:
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
    else:
      self.fail("Didn't fail")

    self.assertEqual(sleep_fn.Count(), 1)
    self.assertEqual(prepare_fn.Count(), 1)

  @staticmethod
  def _SuccessfulPrepare(path, cmd):
    return (True, utils.PathJoin(path, cmd))

  def testRunCmdFails(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    def fn(args, env=NotImplemented, reset_env=NotImplemented,
           postfork_fn=NotImplemented):
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
      self.assertEqual(env, {})
      self.assertTrue(reset_env)
      self.assertTrue(callable(postfork_fn))

      trylock = utils.FileLock.Open(lockfile)
      try:
        # See if lockfile is still held
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)

        # Call back to release lock
        postfork_fn(NotImplemented)

        # See if lockfile can be acquired
        trylock.Exclusive(blocking=False)
      finally:
        trylock.Close()

      # Simulate a failed command
      return utils.RunResult(constants.EXIT_FAILURE, None,
                             "stdout", "stderr406328567",
                             utils.ShellQuoteArgs(args),
                             NotImplemented, NotImplemented)

    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
    runcmd_fn = testutils.CallCounter(fn)

    try:
      backend.RunRestrictedCmd("test3079",
                               _lock_timeout=1.0, _lock_file=lockfile,
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
                               _enabled=True)
    except backend.RPCFail as err:
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
                                          " failed:"))
      self.assertTrue("stderr406328567" in str(err),
                      msg="Error did not include output")
    else:
      self.fail("Didn't fail")

    self.assertEqual(sleep_fn.Count(), 0)
    self.assertEqual(prepare_fn.Count(), 1)
    self.assertEqual(runcmd_fn.Count(), 1)

  def testRunCmdSucceeds(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    def fn(args, env=NotImplemented, reset_env=NotImplemented,
           postfork_fn=NotImplemented):
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
      self.assertEqual(env, {})
      self.assertTrue(reset_env)

      # Call back to release lock
      postfork_fn(NotImplemented)

      # Simulate a successful command
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
                             utils.ShellQuoteArgs(args),
                             NotImplemented, NotImplemented)

    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
    runcmd_fn = testutils.CallCounter(fn)

    result = backend.RunRestrictedCmd("test5667",
                                      _lock_timeout=1.0, _lock_file=lockfile,
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
                                      _sleep_fn=sleep_fn,
                                      _prepare_fn=prepare_fn,
                                      _enabled=True)
    self.assertEqual(result, "stdout14463")

    self.assertEqual(sleep_fn.Count(), 0)
    self.assertEqual(prepare_fn.Count(), 1)
    self.assertEqual(runcmd_fn.Count(), 1)

  def testCommandsDisabled(self):
    try:
      backend.RunRestrictedCmd("test",
                               _lock_timeout=NotImplemented,
                               _lock_file=NotImplemented,
                               _path=NotImplemented,
                               _sleep_fn=NotImplemented,
                               _prepare_fn=NotImplemented,
                               _runcmd_fn=NotImplemented,
                               _enabled=False)
    except backend.RPCFail as err:
      self.assertEqual(str(err),
                       "Restricted commands disabled at configure time")
    else:
      self.fail("Did not raise exception")
611
612
class TestSetWatcherPause(unittest.TestCase):
  """Tests for backend.SetWatcherPause using a temporary pause file."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.filename = utils.PathJoin(self.tmpdir, "pause")

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testUnsetNonExisting(self):
    self.assertFalse(os.path.exists(self.filename))
    backend.SetWatcherPause(None, _filename=self.filename)
    self.assertFalse(os.path.exists(self.filename))

  def testSetNonNumeric(self):
    invalid_durations = ["", [], {}, "Hello World", "0", "1.0"]
    for duration in invalid_durations:
      self.assertFalse(os.path.exists(self.filename))

      try:
        backend.SetWatcherPause(duration, _filename=self.filename)
      except backend.RPCFail as err:
        self.assertEqual(str(err), "Duration must be numeric")
      else:
        self.fail("Did not raise exception")

      self.assertFalse(os.path.exists(self.filename))

  def testSet(self):
    self.assertFalse(os.path.exists(self.filename))

    for duration in range(10):
      backend.SetWatcherPause(duration, _filename=self.filename)
      expected_content = "%s\n" % duration
      self.assertEqual(utils.ReadFile(self.filename), expected_content)
      permissions = os.stat(self.filename).st_mode & 0o777
      self.assertEqual(permissions, 0o644)
646
647
class TestGetBlockDevSymlinkPath(unittest.TestCase):
  """Tests for backend._GetBlockDevSymlinkPath."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _Test(self, name, idx):
    """Check the symlink path built for disk C{idx} of instance C{name}."""
    actual = backend._GetBlockDevSymlinkPath(name, idx, _dir=self.tmpdir)
    expected = "%s/%s%s%s" % (self.tmpdir, name, constants.DISK_SEPARATOR,
                              idx)
    self.assertEqual(actual, expected)

  def test(self):
    for idx in range(100):
      self._Test("inst1.example.com", idx)
664
665
class TestGetInstanceList(unittest.TestCase):
  """Tests for backend.GetInstanceList."""

  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
    def __init__(self):
      hypervisor.hv_base.BaseHypervisor.__init__(self)

  def setUp(self):
    instance_names = ["instance1", "instance2", "instance3"]
    self._test_hv = self._TestHypervisor()
    self._test_hv.ListInstances = mock.Mock(return_value=instance_names)

  def _GetHypervisor(self, name):
    # every hypervisor name resolves to the same test hypervisor
    return self._test_hv

  def testHvparams(self):
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    all_hvparams = {constants.HT_FAKE: fake_hvparams}
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=all_hvparams,
                            get_hv_fn=self._GetHypervisor)
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
686
687
class TestInstanceConsoleInfo(unittest.TestCase):
  """Tests for backend.GetInstanceConsoleInfo."""

  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
    def __init__(self):
      hypervisor.hv_base.BaseHypervisor.__init__(self)

  def setUp(self):
    self._test_hv_a = self._TestHypervisor()
    self._test_hv_a.GetInstanceConsole = mock.Mock(
      return_value=objects.InstanceConsole(instance="inst", kind="aHy"))
    self._test_hv_b = self._TestHypervisor()
    self._test_hv_b.GetInstanceConsole = mock.Mock(
      return_value=objects.InstanceConsole(instance="inst", kind="bHy"))

  def _GetHypervisor(self, name):
    """Return hypervisor "a" for the name "a", hypervisor "b" otherwise."""
    if name == "a":
      return self._test_hv_a
    else:
      return self._test_hv_b

  @staticmethod
  def _MakeInstanceInfo(hv_name):
    """Build the per-instance data dict passed to GetInstanceConsoleInfo."""
    return {
      "instance": {"hypervisor": hv_name},
      "node": {},
      "group": {},
      "hvParams": {},
      "beParams": {},
      }

  def testRightHypervisor(self):
    call = {
      "i1": self._MakeInstanceInfo("a"),
      "i2": self._MakeInstanceInfo("b"),
      }

    res = backend.GetInstanceConsoleInfo(call, get_hv_fn=self._GetHypervisor)

    # assertEqual reports the actual value on failure, unlike assertTrue(==)
    self.assertEqual(res["i1"]["kind"], "aHy")
    self.assertEqual(res["i2"]["kind"], "bHy")
728
729
class TestGetHvInfo(unittest.TestCase):
  """Tests for backend._GetHvInfoAll."""

  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
    def __init__(self):
      hypervisor.hv_base.BaseHypervisor.__init__(self)

  def setUp(self):
    self._test_hv = self._TestHypervisor()
    self._test_hv.GetNodeInfo = mock.Mock()

  def _GetHypervisor(self, name):
    # every hypervisor name resolves to the same test hypervisor
    return self._test_hv

  def testGetHvInfoAllNone(self):
    self.assertTrue(backend._GetHvInfoAll(None) is None)

  def testGetHvInfoAll(self):
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    hv_specs = [(constants.HT_XEN_PVM, hvparams)]

    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
754
755
class TestApplyStorageInfoFunction(unittest.TestCase):
  """Tests for backend._ApplyStorageInfoFunction.

  Each test installs its own C{backend._STORAGE_TYPE_INFO_FN} table. The
  original table is saved in setUp and restored in tearDown, so a failing
  assertion cannot leak the replacement into other tests.

  """

  _STORAGE_KEY = "some_key"
  _SOME_ARGS = ["some_args"]

  def setUp(self):
    self.mock_storage_fn = mock.Mock()
    self._info_fn_orig = backend._STORAGE_TYPE_INFO_FN

  def tearDown(self):
    backend._STORAGE_TYPE_INFO_FN = self._info_fn_orig

  def testApplyValidStorageType(self):
    storage_type = constants.ST_LVM_VG
    backend._STORAGE_TYPE_INFO_FN = {
      storage_type: self.mock_storage_fn
      }

    backend._ApplyStorageInfoFunction(
      storage_type, self._STORAGE_KEY, self._SOME_ARGS)

    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)

  def testApplyInValidStorageType(self):
    storage_type = "invalid_storage_type"
    backend._STORAGE_TYPE_INFO_FN = {}

    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)

  def testApplyNotImplementedStorageType(self):
    storage_type = "not_implemented_storage_type"
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}

    self.assertRaises(NotImplementedError,
                      backend._ApplyStorageInfoFunction,
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
795
796
class TestGetLvmVgSpaceInfo(unittest.TestCase):
  """Tests for backend._GetLvmVgSpaceInfo."""

  def testValid(self):
    path = "somepath"
    excl_stor = True
    orig_fn = backend._GetVgInfo
    backend._GetVgInfo = mock.Mock()
    try:
      backend._GetLvmVgSpaceInfo(path, [excl_stor])
      backend._GetVgInfo.assert_called_with(path, excl_stor)
    finally:
      # restore even if the assertion fails, so other tests see the real one
      backend._GetVgInfo = orig_fn

  def testNoExclStorageNotBool(self):
    path = "somepath"
    excl_stor = "123"
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
                      path, [excl_stor])

  def testNoExclStorageNotInList(self):
    path = "somepath"
    excl_stor = "123"
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
                      path, excl_stor)


class TestGetLvmPvSpaceInfo(unittest.TestCase):
  """Tests for backend._GetLvmPvSpaceInfo."""

  def testValid(self):
    path = "somepath"
    excl_stor = True
    orig_fn = backend._GetVgSpindlesInfo
    backend._GetVgSpindlesInfo = mock.Mock()
    try:
      backend._GetLvmPvSpaceInfo(path, [excl_stor])
      backend._GetVgSpindlesInfo.assert_called_with(path, excl_stor)
    finally:
      # restore even if the assertion fails, so other tests see the real one
      backend._GetVgSpindlesInfo = orig_fn
830
831
class TestCheckStorageParams(unittest.TestCase):
  """Tests for backend._CheckStorageParams."""

  def _AssertInvalid(self, params, num_params):
    """Check that the given parameters are rejected as a ProgrammerError."""
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
                      params, num_params)

  def testParamsNone(self):
    self._AssertInvalid(None, NotImplemented)

  def testParamsWrongType(self):
    self._AssertInvalid("string", NotImplemented)

  def testParamsEmpty(self):
    backend._CheckStorageParams([], 0)

  def testParamsValidNumber(self):
    backend._CheckStorageParams(["a", True], 2)

  def testParamsInvalidNumber(self):
    self._AssertInvalid(["b", False], 3)
851
852
class TestGetVgSpindlesInfo(unittest.TestCase):
  """Tests for backend._GetVgSpindlesInfo with an injected info function."""

  def setUp(self):
    self.vg_free = 13
    self.vg_size = 31
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))

  def testValidInput(self):
    name = "myvg"
    excl_stor = True
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
    self.mock_fn.assert_called_with(name)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_PV, result["type"])
    self.assertEqual(self.vg_free, result["storage_free"])
    self.assertEqual(self.vg_size, result["storage_size"])

  def testNoExclStor(self):
    name = "myvg"
    excl_stor = False
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
    # Check the "called" flag rather than Mock.assert_not_called(), which
    # older mock releases do not provide; there it silently becomes a no-op
    # auto-created attribute.
    self.assertFalse(self.mock_fn.called)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_PV, result["type"])
    self.assertEqual(0, result["storage_free"])
    self.assertEqual(0, result["storage_size"])
879
880
class TestGetVgInfo(unittest.TestCase):
  """Tests for backend._GetVgInfo with an injected info function.

  This class was previously also named TestGetVgSpindlesInfo. That silently
  replaced the earlier class of the same name, so the earlier class's tests
  never ran.

  """

  def testValidInput(self):
    vg_free = 13
    vg_size = 31
    mock_fn = mock.Mock(return_value=[(vg_free, vg_size)])
    name = "myvg"
    excl_stor = True
    result = backend._GetVgInfo(name, excl_stor, info_fn=mock_fn)
    mock_fn.assert_called_with([name], excl_stor)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_VG, result["type"])
    self.assertEqual(vg_free, result["storage_free"])
    self.assertEqual(vg_size, result["storage_size"])

  def testNoExclStor(self):
    # NOTE(review): despite its name, this test covers the info function
    # returning None (excl_stor is True), not disabled exclusive storage
    name = "myvg"
    excl_stor = True
    mock_fn = mock.Mock(return_value=None)
    result = backend._GetVgInfo(name, excl_stor, info_fn=mock_fn)
    mock_fn.assert_called_with([name], excl_stor)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_VG, result["type"])
    self.assertEqual(None, result["storage_free"])
    self.assertEqual(None, result["storage_size"])
906
907
class TestGetNodeInfo(unittest.TestCase):
  """Tests for how backend.GetNodeInfo dispatches per storage unit."""

  _SOME_RESULT = None

  def testApplyStorageInfoFunction(self):
    orig_fn = backend._ApplyStorageInfoFunction
    backend._ApplyStorageInfoFunction = mock.Mock(
      return_value=self._SOME_RESULT)
    try:
      storage_units = [(st, st + "_key", [st + "_params"]) for st in
                       constants.STORAGE_TYPES]

      backend.GetNodeInfo(storage_units, None)

      call_args_list = backend._ApplyStorageInfoFunction.call_args_list
      self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
      for call_args in call_args_list:
        storage_type, storage_key, storage_params = call_args[0]
        self.assertEqual(storage_type + "_key", storage_key)
        self.assertEqual([storage_type + "_params"], storage_params)
        self.assertTrue(storage_type in constants.STORAGE_TYPES)
    finally:
      # restore even if an assertion fails, so other tests see the real one
      backend._ApplyStorageInfoFunction = orig_fn
929
930
class TestSpaceReportingConstants(unittest.TestCase):
  """Ensures consistency between STS_REPORT and backend.

  STS_REPORT lists the storage types for which space reporting is
  implemented. These tests check that exactly those types have an entry
  other than None in C{backend._STORAGE_TYPE_INFO_FN}. Once space
  reporting is available for all types, the constant can be removed along
  with these tests.

  """

  REPORTING = set(constants.STS_REPORT)
  NOT_REPORTING = set(constants.STORAGE_TYPES) - REPORTING

  def testAllReportingTypesHaveAReportingFunction(self):
    for storage_type in TestSpaceReportingConstants.REPORTING:
      info_fn = backend._STORAGE_TYPE_INFO_FN[storage_type]
      self.assertTrue(info_fn is not None)

  def testAllNotReportingTypesDontHaveFunction(self):
    for storage_type in TestSpaceReportingConstants.NOT_REPORTING:
      info_fn = backend._STORAGE_TYPE_INFO_FN[storage_type]
      self.assertEqual(None, info_fn)
951
952
class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
  """Tests SSH key generation, distribution and removal in backend.

  The backend functions are driven with a mocked ssconf store and a mocked
  command runner (C{run_cmd_fn}); the assertions inspect the update data each
  node would have been sent.

  """

  _CLUSTER_NAME = "mycluster"
  _SSH_PORT = 22

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Patch out the authorized_keys helpers of the ssh module.
    self._ssh_add_authorized_patcher = testutils \
      .patch_object(ssh, "AddAuthorizedKeys")
    self._ssh_remove_authorized_patcher = testutils \
      .patch_object(ssh, "RemoveAuthorizedKeys")
    self._ssh_add_authorized_mock = self._ssh_add_authorized_patcher.start()

    self._ssconf_mock = mock.Mock()
    self._ssconf_mock.GetNodeList = mock.Mock()
    self._ssconf_mock.GetMasterNode = mock.Mock()
    self._ssconf_mock.GetClusterName = mock.Mock()

    # Records remote commands instead of executing them; see _GetCallsPerNode.
    self._run_cmd_mock = mock.Mock()

    self._ssh_remove_authorized_mock = \
      self._ssh_remove_authorized_patcher.start()
    self.noded_cert_file = testutils.TestDataFilename("cert1.pem")

  def tearDown(self):
    self._ssh_add_authorized_patcher.stop()
    self._ssh_remove_authorized_patcher.stop()
    # super(testutils.GanetiTestCase, self) would start the MRO lookup *after*
    # GanetiTestCase and skip its tearDown (and thus its cleanup), so call it
    # explicitly.
    testutils.GanetiTestCase.tearDown(self)

  def _SetupTestData(self, number_of_nodes=15, number_of_pot_mcs=5,
                     number_of_mcs=5):
    """Sets up consistent test data for a cluster with a couple of nodes.

    Node C{i} is named "node_name_<i>" and has UUID "node_uuid_<i>". The
    first C{number_of_mcs} nodes are master candidates; the first
    C{number_of_mcs + number_of_pot_mcs} nodes are potential master
    candidates and get the public key "key<i>" in a fresh public key file.
    All remaining nodes are normal nodes.

    """
    self._pub_key_file = self._CreateTempFile()
    self._all_nodes = []
    self._potential_master_candidates = []
    self._master_candidate_uuids = []
    self._ssh_port_map = {}

    self._ssconf_mock.reset_mock()
    self._ssconf_mock.GetNodeList.reset_mock()
    self._ssconf_mock.GetMasterNode.reset_mock()
    self._ssconf_mock.GetClusterName.reset_mock()
    self._run_cmd_mock.reset_mock()

    for i in range(number_of_nodes):
      node_name = "node_name_%s" % i
      node_uuid = "node_uuid_%s" % i
      self._ssh_port_map[node_name] = self._SSH_PORT
      self._all_nodes.append(node_name)

      if i < number_of_mcs + number_of_pot_mcs:
        ssh.AddPublicKey(node_uuid, "key%s" % i,
                         key_file=self._pub_key_file)
        self._potential_master_candidates.append(node_name)

      if i < number_of_mcs:
        self._master_candidate_uuids.append(node_uuid)

    # Floor division keeps the node name integral on Python 2 and 3.
    # NOTE(review): GetMasterNode is never given this return value, so the
    # backend sees a Mock as master node -- confirm whether that is intended.
    self._master_node = "node_name_%s" % (number_of_mcs // 2)
    self._ssconf_mock.GetNodeList.return_value = self._all_nodes

  def _TearDownTestData(self):
    """Removes the public key file created by L{_SetupTestData}."""
    os.remove(self._pub_key_file)

  def _KeyOperationExecuted(self, key_data, node_name, expected_type,
                            expected_key, action_types):
    """Checks whether a node was asked to apply an action to a given key.

    @param key_data: per-node lists of update data, see L{_GetCallsPerNode}
    @param node_name: name of the node to inspect
    @param expected_type: key type entry to look for in each update
    @param expected_key: key that must be part of the action
    @param action_types: list of acceptable actions
    @return: C{True} if some update sent to C{node_name} contains an entry
      C{expected_type: (action, key_dict)} with C{action} in C{action_types}
      and C{expected_key} in one of the lists in C{key_dict}

    """
    for data in key_data.get(node_name, []):
      if expected_type not in data:
        continue
      (action, key_dict) = data[expected_type]
      if action not in action_types:
        continue
      if any(expected_key in key_list for key_list in key_dict.values()):
        return True
    return False

  def _KeyReceived(self, key_data, node_name, expected_type,
                   expected_key):
    """Checks whether a node was asked to add/override/replace a key."""
    return self._KeyOperationExecuted(
      key_data, node_name, expected_type, expected_key,
      [constants.SSHS_ADD, constants.SSHS_OVERRIDE,
       constants.SSHS_REPLACE_OR_ADD])

  def _KeyRemoved(self, key_data, node_name, expected_type,
                  expected_key):
    """Checks whether a node was asked to remove a key.

    An explicit C{SSHS_REMOVE} of C{expected_key} counts, and so does any
    C{SSHS_CLEAR} action for C{expected_type}, regardless of its keys.

    """
    if self._KeyOperationExecuted(
        key_data, node_name, expected_type, expected_key,
        [constants.SSHS_REMOVE]):
      return True
    for data in key_data.get(node_name, []):
      if expected_type in data:
        (action, _) = data[expected_type]
        if action == constants.SSHS_CLEAR:
          return True
    return False

  def _GetCallsPerNode(self):
    """Groups the update data of all recorded remote commands by node.

    @rtype: dict
    @return: maps each node name passed to C{run_cmd_fn} to the list of
      update data it was sent, in call order

    """
    calls_per_node = {}
    for (pos, _) in self._run_cmd_mock.call_args_list:
      # Calls are expected to be purely positional: the node name is the
      # second argument and the update data the tenth.
      (_, node, _, _, _, _, _, _, _, data, _) = pos
      calls_per_node.setdefault(node, []).append(data)
    return calls_per_node

  def testGenerateKey(self):
    test_node_name = "node_name_7"
    test_node_uuid = "node_uuid_7"

    self._SetupTestData()

    backend._GenerateNodeSshKey(test_node_uuid, test_node_name,
                                self._ssh_port_map,
                                pub_key_file=self._pub_key_file,
                                ssconf_store=self._ssconf_mock,
                                noded_cert_file=self.noded_cert_file,
                                run_cmd_fn=self._run_cmd_mock)

    calls_per_node = self._GetCallsPerNode()
    # Without this check the loop below would pass if nothing was run at all.
    self.assertTrue(test_node_name in calls_per_node,
                    "Node '%s' was not asked to generate a key." %
                    test_node_name)
    for node, calls in calls_per_node.items():
      self.assertEqual(node, test_node_name)
      for call in calls:
        self.assertTrue(constants.SSHS_GENERATE in call)

    self._TearDownTestData()

  def testAddNodeSshKeyValid(self):
    new_node_name = "new_node_name"
    new_node_uuid = "new_node_uuid"
    new_node_key1 = "new_node_key1"
    new_node_key2 = "new_node_key2"

    for (to_authorized_keys, to_public_keys, get_public_keys) in \
        [(True, True, False), (False, True, False),
         (True, True, True), (False, True, True)]:

      self._SetupTestData()

      # set up public key file, ssconf store, and node lists
      if to_public_keys:
        for key in [new_node_key1, new_node_key2]:
          ssh.AddPublicKey(new_node_name, key, key_file=self._pub_key_file)
        self._potential_master_candidates.append(new_node_name)

      self._ssh_port_map[new_node_name] = self._SSH_PORT

      backend.AddNodeSshKey(new_node_uuid, new_node_name,
                            to_authorized_keys,
                            to_public_keys,
                            get_public_keys,
                            self._ssh_port_map,
                            self._potential_master_candidates,
                            pub_key_file=self._pub_key_file,
                            ssconf_store=self._ssconf_mock,
                            noded_cert_file=self.noded_cert_file,
                            run_cmd_fn=self._run_cmd_mock)

      calls_per_node = self._GetCallsPerNode()

      # one sample node per type (master candidate, potential master
      # candidate, normal node)
      mc_idx = 3
      pot_mc_idx = 7
      normal_idx = 12
      sample_nodes = [mc_idx, pot_mc_idx, normal_idx]
      pot_sample_nodes = [mc_idx, pot_mc_idx]

      if to_authorized_keys:
        for node_idx in sample_nodes:
          self.assertTrue(self._KeyReceived(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_AUTHORIZED_KEYS, new_node_key1),
            "Node %i did not receive authorized key '%s' although it should"
            " have." % (node_idx, new_node_key1))
      else:
        for node_idx in sample_nodes:
          self.assertFalse(self._KeyReceived(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_AUTHORIZED_KEYS, new_node_key1),
            "Node %i received authorized key '%s', although it should not have."
            % (node_idx, new_node_key1))

      if to_public_keys:
        for node_idx in pot_sample_nodes:
          self.assertTrue(self._KeyReceived(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_PUBLIC_KEYS, new_node_key1),
            "Node %i did not receive public key '%s', although it should have."
            % (node_idx, new_node_key1))
      else:
        for node_idx in sample_nodes:
          self.assertFalse(self._KeyReceived(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_PUBLIC_KEYS, new_node_key1),
            "Node %i did receive public key '%s', although it should not"
            " have." % (node_idx, new_node_key1))

      if get_public_keys:
        for node_idx in sample_nodes:
          if node_idx in pot_sample_nodes:
            self.assertTrue(self._KeyReceived(
              calls_per_node, new_node_name,
              constants.SSHS_SSH_PUBLIC_KEYS, "key%s" % node_idx),
              "The new node '%s' did not receive public key of node %i,"
              " although it should have." %
              (new_node_name, node_idx))
          else:
            self.assertFalse(self._KeyReceived(
              calls_per_node, new_node_name,
              constants.SSHS_SSH_PUBLIC_KEYS, "key%s" % node_idx),
              "The new node '%s' did receive public key of node %i,"
              " although it should not have." %
              (new_node_name, node_idx))
      else:
        # Previously a bare expression that asserted nothing.
        self.assertFalse(new_node_name in calls_per_node,
                         "The new node '%s' was contacted, although it"
                         " should not have been." % new_node_name)

      self._TearDownTestData()

  def testRemoveNodeSshKeyValid(self):
    node_name = "node_name"
    node_uuid = "node_uuid"
    node_key1 = "node_key1"
    node_key2 = "node_key2"

    for (from_authorized_keys, from_public_keys,
         clear_authorized_keys) in \
        [(True, True, False),
         (True, False, False),
         (False, True, False),
         (False, True, True),
         (False, False, True),
         (True, True, True),
        ]:

      self._SetupTestData()

      # set up public key file, ssconf store, and node lists
      if from_public_keys or from_authorized_keys:
        for key in [node_key1, node_key2]:
          ssh.AddPublicKey(node_uuid, key, key_file=self._pub_key_file)
        self._potential_master_candidates.append(node_name)
      if from_authorized_keys:
        ssh.AddAuthorizedKeys(self._pub_key_file, [node_key1, node_key2])

      self._ssh_port_map[node_name] = self._SSH_PORT

      if from_authorized_keys:
        self._master_candidate_uuids.append(node_uuid)

      backend.RemoveNodeSshKey(node_uuid, node_name,
                               from_authorized_keys,
                               from_public_keys,
                               clear_authorized_keys,
                               self._ssh_port_map,
                               self._master_candidate_uuids,
                               self._potential_master_candidates,
                               pub_key_file=self._pub_key_file,
                               ssconf_store=self._ssconf_mock,
                               noded_cert_file=self.noded_cert_file,
                               run_cmd_fn=self._run_cmd_mock)

      calls_per_node = self._GetCallsPerNode()

      # one sample node per type (master candidate, potential master
      # candidate, normal node)
      mc_idx = 3
      pot_mc_idx = 7
      normal_idx = 12
      sample_nodes = [mc_idx, pot_mc_idx, normal_idx]
      pot_sample_nodes = [mc_idx, pot_mc_idx]

      if from_authorized_keys:
        for node_idx in sample_nodes:
          self.assertTrue(self._KeyRemoved(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_AUTHORIZED_KEYS, node_key1),
            "Node %i did not get request to remove authorized key '%s'"
            " although it should have." % (node_idx, node_key1))
      else:
        for node_idx in sample_nodes:
          self.assertFalse(self._KeyRemoved(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_AUTHORIZED_KEYS, node_key1),
            "Node %i got requested to remove authorized key '%s', although it"
            " should not have." % (node_idx, node_key1))

      if from_public_keys:
        for node_idx in pot_sample_nodes:
          self.assertTrue(self._KeyRemoved(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_PUBLIC_KEYS, node_key1),
            "Node %i did not receive request to remove public key '%s',"
            " although it should have." % (node_idx, node_key1))
        self.assertTrue(self._KeyRemoved(
          calls_per_node, node_name,
          constants.SSHS_SSH_PUBLIC_KEYS, node_key1),
          "Node %s did not receive request to remove its own public key '%s',"
          " although it should have." % (node_name, node_key1))
        for node_idx in sorted(set(sample_nodes) - set(pot_sample_nodes)):
          self.assertFalse(self._KeyRemoved(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_PUBLIC_KEYS, node_key1),
            "Node %i received a request to remove public key '%s',"
            " although it should not have." % (node_idx, node_key1))
      else:
        for node_idx in sample_nodes:
          self.assertFalse(self._KeyRemoved(
            calls_per_node, "node_name_%i" % node_idx,
            constants.SSHS_SSH_PUBLIC_KEYS, node_key1),
            "Node %i received a request to remove public key '%s',"
            " although it should not have." % (node_idx, node_key1))

      if clear_authorized_keys:
        for node_idx in sorted(set(sample_nodes) - set([mc_idx])):
          key = "key%s" % node_idx
          self.assertFalse(self._KeyRemoved(
            calls_per_node, node_name,
            constants.SSHS_SSH_AUTHORIZED_KEYS, key),
            "Node %s did receive request to remove authorized key '%s',"
            " although it should not have." % (node_name, key))
        mc_key = "key%s" % mc_idx
        self.assertTrue(self._KeyRemoved(
          calls_per_node, node_name,
          constants.SSHS_SSH_AUTHORIZED_KEYS, mc_key),
          "Node %s did not receive request to remove authorized key '%s',"
          " although it should have." % (node_name, mc_key))
      else:
        for node_idx in sample_nodes:
          key = "key%s" % node_idx
          self.assertFalse(self._KeyRemoved(
            calls_per_node, node_name,
            constants.SSHS_SSH_AUTHORIZED_KEYS, key),
            "Node %s did receive request to remove authorized key '%s',"
            " although it should not have." % (node_name, key))

      self._TearDownTestData()
1291
1292
class TestVerifySshSetup(testutils.GanetiTestCase):
  """Tests for backend._VerifySshSetup with mocked SSH key helpers."""

  _NODE1_UUID = "uuid1"
  _NODE2_UUID = "uuid2"
  _NODE3_UUID = "uuid3"
  _NODE1_NAME = "name1"
  _NODE2_NAME = "name2"
  _NODE3_NAME = "name3"
  _NODE1_KEYS = ["key11", "key12"]
  _NODE2_KEYS = ["key21"]
  _NODE3_KEYS = ["key31"]

  # (uuid, name, flag, flag); presumably the flags are
  # (is_master_candidate, is_potential_master_candidate) -- confirm against
  # backend._VerifySshSetup.
  _NODE_STATUS_LIST = [
    (_NODE1_UUID, _NODE1_NAME, True, True),
    (_NODE2_UUID, _NODE2_NAME, False, True),
    (_NODE3_UUID, _NODE3_NAME, False, False),
    ]

  # Returned by the mocked ssh.QueryPubKeyFile.
  _PUB_KEY_RESULT = {
    _NODE1_UUID: _NODE1_KEYS,
    _NODE2_UUID: _NODE2_KEYS,
    _NODE3_UUID: _NODE3_KEYS,
    }

  # Answers of the mocked ssh.HasAuthorizedKey, per key.
  _AUTH_RESULT = {
    _NODE1_KEYS[0]: True,
    _NODE1_KEYS[1]: True,
    _NODE2_KEYS[0]: False,
    _NODE3_KEYS[0]: False,
    }

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self._has_authorized_patcher = testutils \
      .patch_object(ssh, "HasAuthorizedKey")
    self._has_authorized_mock = self._has_authorized_patcher.start()
    self._query_patcher = testutils \
      .patch_object(ssh, "QueryPubKeyFile")
    self._query_mock = self._query_patcher.start()
    self._read_file_patcher = testutils \
      .patch_object(utils, "ReadFile")
    self._read_file_mock = self._read_file_patcher.start()
    self._read_file_mock.return_value = self._NODE1_KEYS[0]

  def tearDown(self):
    self._has_authorized_patcher.stop()
    self._query_patcher.stop()
    self._read_file_patcher.stop()
    # super(testutils.GanetiTestCase, self) would skip GanetiTestCase's own
    # tearDown (the MRO lookup starts after it), so call it explicitly.
    testutils.GanetiTestCase.tearDown(self)

  def _MockAuthorizedKeys(self, auth_result):
    """Makes the mocked ssh.HasAuthorizedKey answer from C{auth_result}."""
    self._has_authorized_mock.side_effect = \
      lambda _, key: auth_result[key]

  def testValidData(self):
    self._MockAuthorizedKeys(self._AUTH_RESULT)
    self._query_mock.return_value = self._PUB_KEY_RESULT
    result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                     self._NODE1_NAME)
    self.assertEqual(result, [])

  def testMissingKey(self):
    self._MockAuthorizedKeys(self._AUTH_RESULT)
    pub_key_missing = copy.deepcopy(self._PUB_KEY_RESULT)
    del pub_key_missing[self._NODE2_UUID]
    self._query_mock.return_value = pub_key_missing
    result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                     self._NODE1_NAME)
    self.assertIn(self._NODE2_UUID, result[0])

  def testUnknownKey(self):
    self._MockAuthorizedKeys(self._AUTH_RESULT)
    pub_key_missing = copy.deepcopy(self._PUB_KEY_RESULT)
    pub_key_missing["unkownnodeuuid"] = "pinkbunny"
    self._query_mock.return_value = pub_key_missing
    result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                     self._NODE1_NAME)
    self.assertIn("unkownnodeuuid", result[0])

  def testMissingMasterCandidate(self):
    auth_result = copy.deepcopy(self._AUTH_RESULT)
    auth_result["key12"] = False
    self._MockAuthorizedKeys(auth_result)
    self._query_mock.return_value = self._PUB_KEY_RESULT
    result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                     self._NODE1_NAME)
    self.assertIn(self._NODE1_UUID, result[0])

  def testSuperfluousNormalNode(self):
    auth_result = copy.deepcopy(self._AUTH_RESULT)
    auth_result["key31"] = True
    self._MockAuthorizedKeys(auth_result)
    self._query_mock.return_value = self._PUB_KEY_RESULT
    result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                     self._NODE1_NAME)
    self.assertIn(self._NODE3_UUID, result[0])
1390
1391
# Run the tests in this module via Ganeti's test program wrapper when the
# file is executed directly.
if __name__ == "__main__":
  testutils.GanetiTestProgram()