SSH utility functions for key manipulation
[ganeti-github.git] / lib / ssh.py
index a8fe86d..2c92264 100644 (file)
@@ -35,6 +35,7 @@
 
 import logging
 import os
+import shutil
 import tempfile
 
 from collections import namedtuple
@@ -1100,6 +1101,153 @@ def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
   return result.stdout
 
 
+def GetSshKeyFilenames(key_type, suffix=""):
+  """Get filenames of the SSH key pair of the given type.
+
+  @type key_type: string
+  @param key_type: type of SSH key, must be element of C{constants.SSHK_ALL}
+  @type suffix: string
+  @param suffix: optional suffix for the key filenames
+  @rtype: tuple of (string, string)
+  @returns: a tuple containing the name of the private key file and the
+       public key file.
+
+  """
+  if key_type not in constants.SSHK_ALL:
+    raise errors.SshUpdateError("Unsupported key type '%s'. Supported key types"
+                                " are: %s." % (key_type, constants.SSHK_ALL))
+  (_, root_keyfiles) = \
+      GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
+  if not key_type in root_keyfiles.keys():
+    raise errors.SshUpdateError("No keyfile for key type '%s' available."
+                                % key_type)
+
+  key_filenames = root_keyfiles[key_type]
+  if suffix:
+    key_filenames = [_ComputeKeyFilePathWithSuffix(key_filename, suffix)
+                     for key_filename in key_filenames]
+
+  return key_filenames
+
+
+def GetSshPubKeyFilename(key_type, suffix=""):
+  """Get filename of the public SSH key of the given type.
+
+  @type key_type: string
+  @param key_type: type of SSH key, must be element of C{constants.SSHK_ALL}
+  @type suffix: string
+  @param suffix: optional suffix for the key filenames
+  @rtype: string
+  @returns: file name of the public key file
+
+  """
+  return GetSshKeyFilenames(key_type, suffix=suffix)[1]
+
+
+def _ComputeKeyFilePathWithSuffix(key_filepath, suffix):
+  """Converts the given key filename to a key filename with a suffix.
+
+  @type key_filepath: string
+  @param key_filepath: path of the key file
+  @type suffix: string
+  @param suffix: suffix to be appended to the basename of the file
+
+  """
+  path = os.path.dirname(key_filepath)
+  ext = os.path.splitext(os.path.basename(key_filepath))[1]
+  basename = os.path.splitext(os.path.basename(key_filepath))[0]
+  return os.path.join(path, basename + suffix + ext)
+
+
+def ReplaceSshKeys(src_key_type, dest_key_type,
+                   src_key_suffix="", dest_key_suffix=""):
+  """Replaces an SSH key pair by another SSH key pair.
+
+  Note that both parts, the private and the public key, are replaced.
+
+  @type src_key_type: string
+  @param src_key_type: key type of key pair that is replacing the other
+      key pair
+  @type dest_key_type: string
+  @param dest_key_type: key type of the key pair that is being replaced
+      by the source key pair
+  @type src_key_suffix: string
+  @param src_key_suffix: optional suffix of the key files of the source
+      key pair
+  @type dest_key_suffix: string
+  @param dest_key_suffix: optional suffix of the keey files of the
+      destination key pair
+
+  """
+  (src_priv_filename, src_pub_filename) = GetSshKeyFilenames(
+      src_key_type, suffix=src_key_suffix)
+  (dest_priv_filename, dest_pub_filename) = GetSshKeyFilenames(
+      dest_key_type, suffix=dest_key_suffix)
+
+  if not (os.path.exists(src_priv_filename) and
+          os.path.exists(src_pub_filename)):
+    raise errors.SshUpdateError(
+        "At least one of the source key files is missing: %s",
+        ", ".join([src_priv_filename, src_pub_filename]))
+
+  for dest_file in [dest_priv_filename, dest_pub_filename]:
+    if os.path.exists(dest_file):
+      utils.CreateBackup(dest_file)
+      utils.RemoveFile(dest_file)
+
+  shutil.move(src_priv_filename, dest_priv_filename)
+  shutil.move(src_pub_filename, dest_pub_filename)
+
+
+def ReadLocalSshPubKeys(key_types, suffix=""):
+  """Reads the local root user SSH key.
+
+  @type key_types: list of string
+  @param key_types: types of SSH keys. Must be subset of constants.SSHK_ALL. If
+      'None' or [], all available keys are returned.
+  @type suffix: string
+  @param suffix: optional suffix to be attached to key names when reading
+      them. Used for temporary key files.
+  @rtype: list of string
+  @return: list of public keys
+
+  """
+  fetch_key_types = []
+  if key_types:
+    fetch_key_types += key_types
+  else:
+    fetch_key_types = constants.SSHK_ALL
+
+  (_, root_keyfiles) = \
+      GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
+
+  result_keys = []
+  for (public_key_type, (_, public_key_file)) in root_keyfiles.items():
+
+    if public_key_type not in fetch_key_types:
+      continue
+
+    public_key_dir = os.path.dirname(public_key_file)
+    public_key_filename = ""
+    if suffix:
+      public_key_filename = \
+          os.path.splitext(os.path.basename(public_key_file))[0] \
+          + suffix + ".pub"
+    else:
+      public_key_filename = public_key_file
+    public_key_path = os.path.join(public_key_dir,
+                                   public_key_filename)
+
+    if not os.path.exists(public_key_path):
+      raise errors.SshUpdateError("Cannot find SSH public key of type '%s'."
+                                  % public_key_type)
+    else:
+      key = utils.ReadFile(public_key_path)
+      result_keys.append(key)
+
+  return result_keys
+
+
 # Update gnt-cluster.rst when changing which combinations are valid.
 KeyBitInfo = namedtuple('KeyBitInfo', ['default', 'validation_fn'])
 SSH_KEY_VALID_BITS = {