Prepare-node-join: use common functions
authorHelga Velroyen <helgav@google.com>
Tue, 16 Jun 2015 15:46:04 +0000 (17:46 +0200)
committerHelga Velroyen <helgav@google.com>
Mon, 6 Jul 2015 10:46:19 +0000 (12:46 +0200)
This patch makes prepare_node_join use some of the functions
that were moved to tools/common.py. The respective unittests
are removed, because they are already tested in
common_unittest.py.

Signed-off-by: Helga Velroyen <helgav@google.com>
Reviewed-by: Klaus Aehlig <aehlig@google.com>

lib/tools/prepare_node_join.py
test/py/ganeti.tools.common_unittest.py
test/py/ganeti.tools.prepare_node_join_unittest.py

index 7eb1e5a..4db335f 100644 (file)
@@ -43,10 +43,9 @@ from ganeti import constants
 from ganeti import errors
 from ganeti import pathutils
 from ganeti import utils
-from ganeti import serializer
 from ganeti import ht
 from ganeti import ssh
-from ganeti import ssconf
+from ganeti.tools import common
 
 
 _SSH_KEY_LIST_ITEM = \
@@ -89,17 +88,7 @@ def ParseOptions():
 
   (opts, args) = parser.parse_args()
 
-  return VerifyOptions(parser, opts, args)
-
-
-def VerifyOptions(parser, opts, args):
-  """Verifies options and arguments for correctness.
-
-  """
-  if args:
-    parser.error("No arguments are expected")
-
-  return opts
+  return common.VerifyOptions(parser, opts, args)
 
 
 def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate):
@@ -137,19 +126,6 @@ def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
     _verify_fn(cert)
 
 
-def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
-  """Verifies cluster name.
-
-  @type data: dict
-
-  """
-  name = data.get(constants.SSHS_CLUSTER_NAME)
-  if name:
-    _verify_fn(name)
-  else:
-    raise JoinError("Cluster name must be specified")
-
-
 def _UpdateKeyFiles(keys, dry_run, keyfiles):
   """Updates SSH key files.
 
@@ -238,15 +214,6 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None):
       utils.AddAuthorizedKey(auth_keys_file, public_key)
 
 
-def LoadData(raw):
-  """Parses and verifies input data.
-
-  @rtype: dict
-
-  """
-  return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
-
-
 def Main():
   """Main routine.
 
@@ -256,10 +223,10 @@ def Main():
   utils.SetupToolLogging(opts.debug, opts.verbose)
 
   try:
-    data = LoadData(sys.stdin.read())
+    data = common.LoadData(sys.stdin.read(), _DATA_CHECK)
 
     # Check if input data is correct
-    VerifyClusterName(data)
+    common.VerifyClusterName(data, JoinError)
     VerifyCertificate(data)
 
     # Update SSH files
index 1652088..427b851 100755 (executable)
@@ -38,6 +38,8 @@ import OpenSSL
 import time
 
 from ganeti import constants
+from ganeti import errors
+from ganeti import serializer
 from ganeti import utils
 from ganeti.tools import common
 
@@ -78,5 +80,56 @@ class TestGenerateClientCert(unittest.TestCase):
     self.assertEqual(client_cert.get_subject().CN, my_node_name)
 
 
+class TestLoadData(unittest.TestCase):
+
+  def testNoJson(self):
+    self.assertRaises(errors.ParseError, common.LoadData, Exception, "")
+    self.assertRaises(errors.ParseError, common.LoadData, Exception, "}")
+
+  def testInvalidDataStructure(self):
+    raw = serializer.DumpJson({
+      "some other thing": False,
+      })
+    self.assertRaises(errors.ParseError, common.LoadData, Exception, raw)
+
+    raw = serializer.DumpJson([])
+    self.assertRaises(errors.ParseError, common.LoadData, Exception, raw)
+
+  def testValidData(self):
+    raw = serializer.DumpJson({})
+    self.assertEqual(common.LoadData(raw, Exception), {})
+
+
+class TestVerifyClusterName(unittest.TestCase):
+
+  class MyException(Exception):
+    pass
+
+  def setUp(self):
+    unittest.TestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    unittest.TestCase.tearDown(self)
+    shutil.rmtree(self.tmpdir)
+
+  def testNoName(self):
+    self.assertRaises(self.MyException, common.VerifyClusterName,
+                      {}, self.MyException, _verify_fn=NotImplemented)
+
+  @staticmethod
+  def _FailingVerify(name):
+    assert name == "cluster.example.com"
+    raise errors.GenericError()
+
+  def testFailingVerification(self):
+    data = {
+      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
+      }
+
+    self.assertRaises(errors.GenericError, common.VerifyClusterName,
+                      data, Exception, _verify_fn=self._FailingVerify)
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
index 09f8750..82acce5 100755 (executable)
@@ -34,11 +34,9 @@ import unittest
 import shutil
 import tempfile
 import os.path
-import OpenSSL
 
 from ganeti import errors
 from ganeti import constants
-from ganeti import serializer
 from ganeti import pathutils
 from ganeti import compat
 from ganeti import utils
@@ -50,25 +48,6 @@ import testutils
 _JoinError = prepare_node_join.JoinError
 
 
-class TestLoadData(unittest.TestCase):
-  def testNoJson(self):
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
-
-  def testInvalidDataStructure(self):
-    raw = serializer.DumpJson({
-      "some other thing": False,
-      })
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
-
-    raw = serializer.DumpJson([])
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
-
-  def testValidData(self):
-    raw = serializer.DumpJson({})
-    self.assertEqual(prepare_node_join.LoadData(raw), {})
-
-
 class TestVerifyCertificate(testutils.GanetiTestCase):
   def setUp(self):
     testutils.GanetiTestCase.setUp(self)
@@ -104,33 +83,6 @@ class TestVerifyCertificate(testutils.GanetiTestCase):
     prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check)
 
 
-class TestVerifyClusterName(unittest.TestCase):
-  def setUp(self):
-    unittest.TestCase.setUp(self)
-    self.tmpdir = tempfile.mkdtemp()
-
-  def tearDown(self):
-    unittest.TestCase.tearDown(self)
-    shutil.rmtree(self.tmpdir)
-
-  def testNoName(self):
-    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
-                      {}, _verify_fn=NotImplemented)
-
-  @staticmethod
-  def _FailingVerify(name):
-    assert name == "cluster.example.com"
-    raise errors.GenericError()
-
-  def testFailingVerification(self):
-    data = {
-      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
-      }
-
-    self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
-                      data, _verify_fn=self._FailingVerify)
-
-
 class TestUpdateSshDaemon(unittest.TestCase):
   def setUp(self):
     unittest.TestCase.setUp(self)