SSH utility functions for key manipulation
authorHelga Velroyen <helgav@google.com>
Fri, 15 Jan 2016 10:18:24 +0000 (11:18 +0100)
committerHelga Velroyen <helgav@google.com>
Fri, 22 Jan 2016 09:39:04 +0000 (10:39 +0100)
So far, the backend code contains a lot of (repetitive)
code to manipulate SSH keys on the local disk. This
patch adds utility functions for those basic operations
and also includes unit tests for those.

In the later patches of this series, those functions
will be used to simplify the code and increase the
code reusage.

Signed-off-by: Helga Velroyen <helgav@google.com>
Reviewed-by: Klaus Aehlig <aehlig@google.com>

lib/ssh.py
test/py/ganeti.ssh_unittest.py

index a8fe86d..2c92264 100644 (file)
@@ -35,6 +35,7 @@
 
 import logging
 import os
+import shutil
 import tempfile
 
 from collections import namedtuple
@@ -1100,6 +1101,153 @@ def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
   return result.stdout
 
 
+def GetSshKeyFilenames(key_type, suffix=""):
+  """Get filenames of the SSH key pair of the given type.
+
+  @type key_type: string
+  @param key_type: type of SSH key, must be element of C{constants.SSHK_ALL}
+  @type suffix: string
+  @param suffix: optional suffix for the key filenames
+  @rtype: tuple of (string, string)
+  @returns: a tuple containing the name of the private key file and the
+       public key file.
+
+  """
+  if key_type not in constants.SSHK_ALL:
+    raise errors.SshUpdateError("Unsupported key type '%s'. Supported key types"
+                                " are: %s." % (key_type, constants.SSHK_ALL))
+  (_, root_keyfiles) = \
+      GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
+  if not key_type in root_keyfiles.keys():
+    raise errors.SshUpdateError("No keyfile for key type '%s' available."
+                                % key_type)
+
+  key_filenames = root_keyfiles[key_type]
+  if suffix:
+    key_filenames = [_ComputeKeyFilePathWithSuffix(key_filename, suffix)
+                     for key_filename in key_filenames]
+
+  return key_filenames
+
+
+def GetSshPubKeyFilename(key_type, suffix=""):
+  """Get filename of the public SSH key of the given type.
+
+  @type key_type: string
+  @param key_type: type of SSH key, must be element of C{constants.SSHK_ALL}
+  @type suffix: string
+  @param suffix: optional suffix for the key filenames
+  @rtype: string
+  @returns: file name of the public key file
+
+  """
+  return GetSshKeyFilenames(key_type, suffix=suffix)[1]
+
+
+def _ComputeKeyFilePathWithSuffix(key_filepath, suffix):
+  """Converts the given key filename to a key filename with a suffix.
+
+  @type key_filepath: string
+  @param key_filepath: path of the key file
+  @type suffix: string
+  @param suffix: suffix to be appended to the basename of the file
+
+  """
+  path = os.path.dirname(key_filepath)
+  ext = os.path.splitext(os.path.basename(key_filepath))[1]
+  basename = os.path.splitext(os.path.basename(key_filepath))[0]
+  return os.path.join(path, basename + suffix + ext)
+
+
+def ReplaceSshKeys(src_key_type, dest_key_type,
+                   src_key_suffix="", dest_key_suffix=""):
+  """Replaces an SSH key pair by another SSH key pair.
+
+  Note that both parts, the private and the public key, are replaced.
+
+  @type src_key_type: string
+  @param src_key_type: key type of key pair that is replacing the other
+      key pair
+  @type dest_key_type: string
+  @param dest_key_type: key type of the key pair that is being replaced
+      by the source key pair
+  @type src_key_suffix: string
+  @param src_key_suffix: optional suffix of the key files of the source
+      key pair
+  @type dest_key_suffix: string
+  @param dest_key_suffix: optional suffix of the keey files of the
+      destination key pair
+
+  """
+  (src_priv_filename, src_pub_filename) = GetSshKeyFilenames(
+      src_key_type, suffix=src_key_suffix)
+  (dest_priv_filename, dest_pub_filename) = GetSshKeyFilenames(
+      dest_key_type, suffix=dest_key_suffix)
+
+  if not (os.path.exists(src_priv_filename) and
+          os.path.exists(src_pub_filename)):
+    raise errors.SshUpdateError(
+        "At least one of the source key files is missing: %s",
+        ", ".join([src_priv_filename, src_pub_filename]))
+
+  for dest_file in [dest_priv_filename, dest_pub_filename]:
+    if os.path.exists(dest_file):
+      utils.CreateBackup(dest_file)
+      utils.RemoveFile(dest_file)
+
+  shutil.move(src_priv_filename, dest_priv_filename)
+  shutil.move(src_pub_filename, dest_pub_filename)
+
+
+def ReadLocalSshPubKeys(key_types, suffix=""):
+  """Reads the local root user SSH key.
+
+  @type key_types: list of string
+  @param key_types: types of SSH keys. Must be subset of constants.SSHK_ALL. If
+      'None' or [], all available keys are returned.
+  @type suffix: string
+  @param suffix: optional suffix to be attached to key names when reading
+      them. Used for temporary key files.
+  @rtype: list of string
+  @return: list of public keys
+
+  """
+  fetch_key_types = []
+  if key_types:
+    fetch_key_types += key_types
+  else:
+    fetch_key_types = constants.SSHK_ALL
+
+  (_, root_keyfiles) = \
+      GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
+
+  result_keys = []
+  for (public_key_type, (_, public_key_file)) in root_keyfiles.items():
+
+    if public_key_type not in fetch_key_types:
+      continue
+
+    public_key_dir = os.path.dirname(public_key_file)
+    public_key_filename = ""
+    if suffix:
+      public_key_filename = \
+          os.path.splitext(os.path.basename(public_key_file))[0] \
+          + suffix + ".pub"
+    else:
+      public_key_filename = public_key_file
+    public_key_path = os.path.join(public_key_dir,
+                                   public_key_filename)
+
+    if not os.path.exists(public_key_path):
+      raise errors.SshUpdateError("Cannot find SSH public key of type '%s'."
+                                  % public_key_type)
+    else:
+      key = utils.ReadFile(public_key_path)
+      result_keys.append(key)
+
+  return result_keys
+
+
 # Update gnt-cluster.rst when changing which combinations are valid.
 KeyBitInfo = namedtuple('KeyBitInfo', ['default', 'validation_fn'])
 SSH_KEY_VALID_BITS = {
index 3510992..661245b 100755 (executable)
@@ -521,5 +521,171 @@ class TestDetermineKeyBits(testutils.GanetiTestCase):
       self.assertEquals(b, ssh.DetermineKeyBits("rsa", b, None, None))
 
 
+class TestManageLocalSshPubKeys(testutils.GanetiTestCase):
+  """Test class for several methods handling local SSH keys.
+
+  Methods covered are:
+  - GetSshKeyFilenames
+  - GetSshPubKeyFilename
+  - ReplaceSshKeys
+  - ReadLocalSshPubKeys
+
+  These methods are covered in one test, because the preparations for
+  their tests is identical and thus can be reused.
+
+  """
+  VISIBILITY_PRIVATE = "private"
+  VISIBILITY_PUBLIC = "public"
+  VISIBILITIES = frozenset([VISIBILITY_PRIVATE, VISIBILITY_PUBLIC])
+
+  def _GenerateKey(self, key_id, visibility):
+    assert visibility in self.VISIBILITIES
+    return "I am the %s %s SSH key." % (visibility, key_id)
+
+  def _GetKeyPath(self, key_file_basename):
+     return os.path.join(self.tmpdir, key_file_basename)
+
+  def _SetUpKeys(self):
+    """Creates a fake SSH key for each type and with/without suffix."""
+    self._key_file_dict = {}
+    for key_type in constants.SSHK_ALL:
+      for suffix in ["", self._suffix]:
+        pub_key_filename = "id_%s%s.pub" % (key_type, suffix)
+        priv_key_filename = "id_%s%s" % (key_type, suffix)
+
+        pub_key_path = self._GetKeyPath(pub_key_filename)
+        priv_key_path = self._GetKeyPath(priv_key_filename)
+
+        utils.WriteFile(
+            priv_key_path,
+            data=self._GenerateKey(key_type + suffix, self.VISIBILITY_PRIVATE))
+
+        utils.WriteFile(
+            pub_key_path,
+            data=self._GenerateKey(key_type + suffix, self.VISIBILITY_PUBLIC))
+
+        # Fill key dict only for non-suffix keys
+        # (as this is how it will be in the code)
+        if not suffix:
+          self._key_file_dict[key_type] = \
+            (priv_key_path, pub_key_path)
+
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+    self._suffix = "_suffix"
+    self._SetUpKeys()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReadAllPublicKeyFiles(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    keys = ssh.ReadLocalSshPubKeys([], suffix="")
+
+    self.assertEqual(len(constants.SSHK_ALL), len(keys))
+    for key_type in constants.SSHK_ALL:
+      self.assertTrue(
+          self._GenerateKey(key_type, self.VISIBILITY_PUBLIC) in keys)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReadOnePublicKeyFile(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    keys = ssh.ReadLocalSshPubKeys([constants.SSHK_DSA], suffix="")
+
+    self.assertEqual(1, len(keys))
+    self.assertEqual(
+        self._GenerateKey(constants.SSHK_DSA, self.VISIBILITY_PUBLIC),
+        keys[0])
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReadPublicKeyFilesWithSuffix(self, mock_getalluserfiles):
+    key_types = [constants.SSHK_DSA, constants.SSHK_ECDSA]
+
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    keys = ssh.ReadLocalSshPubKeys(key_types, suffix=self._suffix)
+
+    self.assertEqual(2, len(keys))
+    for key_id in [key_type + self._suffix for key_type in key_types]:
+      self.assertTrue(
+          self._GenerateKey(key_id, self.VISIBILITY_PUBLIC) in keys)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testGetSshKeyFilenames(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    priv, pub = ssh.GetSshKeyFilenames(constants.SSHK_DSA)
+
+    self.assertEqual("id_dsa", os.path.basename(priv))
+    self.assertNotEqual("id_dsa", priv)
+    self.assertEqual("id_dsa.pub", os.path.basename(pub))
+    self.assertNotEqual("id_dsa.pub", pub)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testGetSshKeyFilenamesWithSuffix(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    priv, pub = ssh.GetSshKeyFilenames(constants.SSHK_RSA, suffix=self._suffix)
+
+    self.assertEqual("id_rsa_suffix", os.path.basename(priv))
+    self.assertNotEqual("id_rsa_suffix", priv)
+    self.assertEqual("id_rsa_suffix.pub", os.path.basename(pub))
+    self.assertNotEqual("id_rsa_suffix.pub", pub)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testGetPubSshKeyFilename(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    pub = ssh.GetSshPubKeyFilename(constants.SSHK_DSA)
+    pub_suffix = ssh.GetSshPubKeyFilename(
+        constants.SSHK_DSA, suffix=self._suffix)
+
+    self.assertEqual("id_dsa.pub", os.path.basename(pub))
+    self.assertNotEqual("id_dsa.pub", pub)
+    self.assertEqual("id_dsa_suffix.pub", os.path.basename(pub_suffix))
+    self.assertNotEqual("id_dsa_suffix.pub", pub_suffix)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReplaceSshKeys(self, mock_getalluserfiles):
+    """Replace SSH keys without suffixes.
+
+    Note: usually it does not really make sense to replace the DSA key
+    by the RSA key. This is just to test the function without suffixes.
+
+    """
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    ssh.ReplaceSshKeys(constants.SSHK_RSA, constants.SSHK_DSA)
+
+    priv_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][0])
+    pub_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][1])
+
+    self.assertEqual("I am the private rsa SSH key.", priv_key)
+    self.assertEqual("I am the public rsa SSH key.", pub_key)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReplaceSshKeysBySuffixedKeys(self, mock_getalluserfiles):
+    """Replace SSH keys with keys from suffixed files.
+
+    Note: usually it does not really make sense to replace the DSA key
+    by the RSA key. This is just to test the function without suffixes.
+
+    """
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    ssh.ReplaceSshKeys(constants.SSHK_DSA, constants.SSHK_DSA,
+                       src_key_suffix=self._suffix)
+
+    priv_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][0])
+    pub_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][1])
+
+    self.assertEqual("I am the private dsa_suffix SSH key.", priv_key)
+    self.assertEqual("I am the public dsa_suffix SSH key.", pub_key)
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()