Make node_daemon_setup use common functions
authorHelga Velroyen <helgav@google.com>
Fri, 17 Jul 2015 07:44:12 +0000 (09:44 +0200)
committerHelga Velroyen <helgav@google.com>
Fri, 17 Jul 2015 11:42:05 +0000 (13:42 +0200)
This patch makes the node_daemon_setup tool use some of
the recently introduced functions in the tools/common.py.
By doing that, this also cleans up the correct usage of
cluster name constants.

Signed-off-by: Helga Velroyen <helgav@google.com>
Reviewed-by: Klaus Aehlig <aehlig@google.com>

lib/tools/common.py
lib/tools/node_daemon_setup.py
lib/tools/prepare_node_join.py
lib/tools/ssh_update.py
lib/tools/ssl_update.py
test/py/ganeti.tools.common_unittest.py
test/py/ganeti.tools.node_daemon_setup_unittest.py

index 9478655..a9149f6 100644 (file)
@@ -166,19 +166,21 @@ def VerifyCertificateStrong(data, error_fn,
   return _verify_fn(cert, error_fn)
 
 
-def VerifyClusterName(data, error_fn,
+def VerifyClusterName(data, error_fn, cluster_name_constant,
                       _verify_fn=ssconf.VerifyClusterName):
   """Verifies cluster name.
 
   @type data: dict
 
   """
-  name = data.get(constants.SSHS_CLUSTER_NAME)
+  name = data.get(cluster_name_constant)
   if name:
     _verify_fn(name)
   else:
     raise error_fn("Cluster name must be specified")
 
+  return name
+
 
 def LoadData(raw, data_check):
   """Parses and verifies input data.
index 89e8a18..e45e2e0 100644 (file)
@@ -36,15 +36,12 @@ import os.path
 import optparse
 import sys
 import logging
-import OpenSSL
-from cStringIO import StringIO
 
 from ganeti import cli
 from ganeti import constants
 from ganeti import errors
 from ganeti import pathutils
 from ganeti import utils
-from ganeti import serializer
 from ganeti import runtime
 from ganeti import ht
 from ganeti import ssconf
@@ -93,87 +90,6 @@ def VerifyOptions(parser, opts, args):
   return opts
 
 
-def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate):
-  """Verifies a certificate against the local node daemon certificate.
-
-  @type cert_pem: string
-  @param cert_pem: Certificate and key in PEM format
-  @rtype: string
-  @return: Formatted key and certificate
-
-  """
-  try:
-    cert = \
-      OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
-  except Exception, err:
-    raise errors.X509CertError("(stdin)",
-                               "Unable to load certificate: %s" % err)
-
-  try:
-    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
-  except OpenSSL.crypto.Error, err:
-    raise errors.X509CertError("(stdin)",
-                               "Unable to load private key: %s" % err)
-
-  # Check certificate with given key; this detects cases where the key given on
-  # stdin doesn't match the certificate also given on stdin
-  try:
-    utils.X509CertKeyCheck(cert, key)
-  except OpenSSL.SSL.Error:
-    raise errors.X509CertError("(stdin)",
-                               "Certificate is not signed with given key")
-
-  # Standard checks, including check against an existing local certificate
-  # (no-op if that doesn't exist)
-  _check_fn(cert)
-
-  key_encoded = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
-  cert_encoded = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
-                                                 cert)
-  complete_cert_encoded = key_encoded + cert_encoded
-  if not cert_pem == complete_cert_encoded:
-    logging.error("The certificate differs after being reencoded. Please"
-                  " renew the certificates cluster-wide to prevent future"
-                  " inconsistencies.")
-
-  # Format for storing on disk
-  buf = StringIO()
-  buf.write(cert_pem)
-  return buf.getvalue()
-
-
-def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
-  """Verifies cluster certificate.
-
-  @type data: dict
-  @rtype: string
-  @return: Formatted key and certificate
-
-  """
-  cert = data.get(constants.NDS_NODE_DAEMON_CERTIFICATE)
-  if not cert:
-    raise SetupError("Node daemon certificate must be specified")
-
-  return _verify_fn(cert)
-
-
-def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
-  """Verifies cluster name.
-
-  @type data: dict
-  @rtype: string
-  @return: Cluster name
-
-  """
-  name = data.get(constants.NDS_CLUSTER_NAME)
-  if not name:
-    raise SetupError("Cluster name must be specified")
-
-  _verify_fn(name)
-
-  return name
-
-
 def VerifySsconf(data, cluster_name, _verify_fn=ssconf.VerifyKeys):
   """Verifies ssconf names.
 
@@ -195,15 +111,6 @@ def VerifySsconf(data, cluster_name, _verify_fn=ssconf.VerifyKeys):
   return items
 
 
-def LoadData(raw):
-  """Parses and verifies input data.
-
-  @rtype: dict
-
-  """
-  return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
-
-
 def Main():
   """Main routine.
 
@@ -215,10 +122,11 @@ def Main():
   try:
     getent = runtime.GetEnts()
 
-    data = LoadData(sys.stdin.read())
+    data = common.LoadData(sys.stdin.read(), SetupError)
 
-    cluster_name = VerifyClusterName(data)
-    cert_pem = VerifyCertificate(data)
+    cluster_name = common.VerifyClusterName(data, SetupError,
+                                            constants.NDS_CLUSTER_NAME)
+    cert_pem = common.VerifyCertificateStrong(data, SetupError)
     ssdata = VerifySsconf(data, cluster_name)
 
     logging.info("Writing ssconf files ...")
index 0902cf4..82a35dc 100644 (file)
@@ -196,7 +196,7 @@ def Main():
     data = common.LoadData(sys.stdin.read(), _DATA_CHECK)
 
     # Check if input data is correct
-    common.VerifyClusterName(data, JoinError)
+    common.VerifyClusterName(data, JoinError, constants.SSHS_CLUSTER_NAME)
     common.VerifyCertificateSoft(data, JoinError)
 
     # Update SSH files
index 904cbd3..f9d1b6d 100644 (file)
@@ -209,7 +209,7 @@ def Main():
     data = common.LoadData(sys.stdin.read(), _DATA_CHECK)
 
     # Check if input data is correct
-    common.VerifyClusterName(data, SshUpdateError)
+    common.VerifyClusterName(data, SshUpdateError, constants.SSHS_CLUSTER_NAME)
     common.VerifyCertificateSoft(data, SshUpdateError)
 
     # Update / Generate SSH files
index f9c5c19..56e8d6a 100644 (file)
@@ -119,7 +119,7 @@ def Main():
   try:
     data = common.LoadData(sys.stdin.read(), _DATA_CHECK)
 
-    common.VerifyClusterName(data, SslSetupError)
+    common.VerifyClusterName(data, SslSetupError, constants.NDS_CLUSTER_NAME)
 
     # Verifies whether the server certificate of the caller
     # is the same as on this node.
index 427b851..0eb7e45 100755 (executable)
@@ -115,7 +115,8 @@ class TestVerifyClusterName(unittest.TestCase):
 
   def testNoName(self):
     self.assertRaises(self.MyException, common.VerifyClusterName,
-                      {}, self.MyException, _verify_fn=NotImplemented)
+                      {}, self.MyException, "cluster_name",
+                      _verify_fn=NotImplemented)
 
   @staticmethod
   def _FailingVerify(name):
@@ -128,7 +129,87 @@ class TestVerifyClusterName(unittest.TestCase):
       }
 
     self.assertRaises(errors.GenericError, common.VerifyClusterName,
-                      data, Exception, _verify_fn=self._FailingVerify)
+                      data, self.MyException, "cluster_name",
+                      _verify_fn=self._FailingVerify)
+
+
+class TestVerifyCertificateStrong(testutils.GanetiTestCase):
+
+  class MyException(Exception):
+    pass
+
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    testutils.GanetiTestCase.tearDown(self)
+    shutil.rmtree(self.tmpdir)
+
+  def testNoCert(self):
+    self.assertRaises(self.MyException, common.VerifyCertificateStrong,
+                      {}, self.MyException, _verify_fn=NotImplemented)
+
+  def testVerificationSuccessWithCert(self):
+    common.VerifyCertificateStrong({
+      constants.NDS_NODE_DAEMON_CERTIFICATE: "something",
+      }, self.MyException, _verify_fn=lambda x,y: None)
+
+  def testNoPrivateKey(self):
+    cert_filename = testutils.TestDataFilename("cert1.pem")
+    cert_pem = utils.ReadFile(cert_filename)
+
+    self.assertRaises(self.MyException,
+                      common._VerifyCertificateStrong,
+                      cert_pem, self.MyException, _check_fn=NotImplemented)
+
+  def testInvalidCertificate(self):
+    self.assertRaises(self.MyException,
+                      common._VerifyCertificateStrong,
+                      "Something that's not a certificate",
+                      self.MyException,
+                      _check_fn=NotImplemented)
+
+  @staticmethod
+  def _Check(cert):
+    assert cert.get_subject()
+
+  def testSuccessfulCheck(self):
+    cert_filename = testutils.TestDataFilename("cert2.pem")
+    cert_pem = utils.ReadFile(cert_filename)
+    result = \
+      common._VerifyCertificateStrong(cert_pem, self.MyException,
+                                      _check_fn=self._Check)
+
+    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, result)
+    self.assertTrue(cert)
+
+    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, result)
+    self.assertTrue(key)
+
+  def testMismatchingKey(self):
+    cert1_path = testutils.TestDataFilename("cert1.pem")
+    cert2_path = testutils.TestDataFilename("cert2.pem")
+
+    # Extract certificate
+    cert1 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
+                                            utils.ReadFile(cert1_path))
+    cert1_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
+                                                cert1)
+
+    # Extract mismatching key
+    key2 = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
+                                          utils.ReadFile(cert2_path))
+    key2_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM,
+                                              key2)
+
+    try:
+      common._VerifyCertificateStrong(cert1_pem + key2_pem, self.MyException,
+                                      _check_fn=NotImplemented)
+    except self.MyException, err:
+      self.assertTrue("not signed with given key" in str(err))
+    else:
+      self.fail("Exception was not raised")
 
 
 if __name__ == "__main__":
index a9fd1a9..9a3abdf 100755 (executable)
 """Script for testing ganeti.tools.node_daemon_setup"""
 
 import unittest
-import shutil
-import tempfile
-import os.path
-import OpenSSL
 
 from ganeti import errors
 from ganeti import constants
-from ganeti import serializer
-from ganeti import pathutils
-from ganeti import compat
-from ganeti import utils
 from ganeti.tools import node_daemon_setup
 
 import testutils
@@ -50,136 +42,6 @@ import testutils
 _SetupError = node_daemon_setup.SetupError
 
 
-class TestLoadData(unittest.TestCase):
-  def testNoJson(self):
-    for data in ["", "{", "}"]:
-      self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, data)
-
-  def testInvalidDataStructure(self):
-    raw = serializer.DumpJson({
-      "some other thing": False,
-      })
-    self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, raw)
-
-    raw = serializer.DumpJson([])
-    self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, raw)
-
-  def testValidData(self):
-    raw = serializer.DumpJson({})
-    self.assertEqual(node_daemon_setup.LoadData(raw), {})
-
-
-class TestVerifyCertificate(testutils.GanetiTestCase):
-  def setUp(self):
-    testutils.GanetiTestCase.setUp(self)
-    self.tmpdir = tempfile.mkdtemp()
-
-  def tearDown(self):
-    testutils.GanetiTestCase.tearDown(self)
-    shutil.rmtree(self.tmpdir)
-
-  def testNoCert(self):
-    self.assertRaises(_SetupError, node_daemon_setup.VerifyCertificate,
-                      {}, _verify_fn=NotImplemented)
-
-  def testVerificationSuccessWithCert(self):
-    node_daemon_setup.VerifyCertificate({
-      constants.NDS_NODE_DAEMON_CERTIFICATE: "something",
-      }, _verify_fn=lambda _: None)
-
-  def testNoPrivateKey(self):
-    cert_filename = testutils.TestDataFilename("cert1.pem")
-    cert_pem = utils.ReadFile(cert_filename)
-
-    self.assertRaises(errors.X509CertError,
-                      node_daemon_setup._VerifyCertificate,
-                      cert_pem, _check_fn=NotImplemented)
-
-  def testInvalidCertificate(self):
-    self.assertRaises(errors.X509CertError,
-                      node_daemon_setup._VerifyCertificate,
-                      "Something that's not a certificate",
-                      _check_fn=NotImplemented)
-
-  @staticmethod
-  def _Check(cert):
-    assert cert.get_subject()
-
-  def testSuccessfulCheck(self):
-    cert_filename = testutils.TestDataFilename("cert2.pem")
-    cert_pem = utils.ReadFile(cert_filename)
-    result = \
-      node_daemon_setup._VerifyCertificate(cert_pem, _check_fn=self._Check)
-
-    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, result)
-    self.assertTrue(cert)
-
-    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, result)
-    self.assertTrue(key)
-
-  def testMismatchingKey(self):
-    cert1_path = testutils.TestDataFilename("cert1.pem")
-    cert2_path = testutils.TestDataFilename("cert2.pem")
-
-    # Extract certificate
-    cert1 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
-                                            utils.ReadFile(cert1_path))
-    cert1_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
-                                                cert1)
-
-    # Extract mismatching key
-    key2 = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
-                                          utils.ReadFile(cert2_path))
-    key2_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM,
-                                              key2)
-
-    try:
-      node_daemon_setup._VerifyCertificate(cert1_pem + key2_pem,
-                                           _check_fn=NotImplemented)
-    except errors.X509CertError, err:
-      self.assertEqual(err.args,
-                       ("(stdin)", "Certificate is not signed with given key"))
-    else:
-      self.fail("Exception was not raised")
-
-
-class TestVerifyClusterName(unittest.TestCase):
-  def setUp(self):
-    unittest.TestCase.setUp(self)
-    self.tmpdir = tempfile.mkdtemp()
-
-  def tearDown(self):
-    unittest.TestCase.tearDown(self)
-    shutil.rmtree(self.tmpdir)
-
-  def testNoName(self):
-    self.assertRaises(_SetupError, node_daemon_setup.VerifyClusterName,
-                      {}, _verify_fn=NotImplemented)
-
-  @staticmethod
-  def _FailingVerify(name):
-    assert name == "somecluster.example.com"
-    raise errors.GenericError()
-
-  def testFailingVerification(self):
-    data = {
-      constants.NDS_CLUSTER_NAME: "somecluster.example.com",
-      }
-
-    self.assertRaises(errors.GenericError, node_daemon_setup.VerifyClusterName,
-                      data, _verify_fn=self._FailingVerify)
-
-  def testSuccess(self):
-    data = {
-      constants.NDS_CLUSTER_NAME: "cluster.example.com",
-      }
-
-    result = \
-      node_daemon_setup.VerifyClusterName(data, _verify_fn=lambda _: None)
-
-    self.assertEqual(result, "cluster.example.com")
-
-
 class TestVerifySsconf(unittest.TestCase):
   def testNoSsconf(self):
     self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf,