import logging
import os
+import shutil
import tempfile
from collections import namedtuple
return result.stdout
+def GetSshKeyFilenames(key_type, suffix=""):
+ """Get filenames of the SSH key pair of the given type.
+
+ @type key_type: string
+ @param key_type: type of SSH key, must be element of C{constants.SSHK_ALL}
+ @type suffix: string
+ @param suffix: optional suffix for the key filenames
+ @rtype: tuple of (string, string)
+ @returns: a tuple containing the name of the private key file and the
+ public key file.
+
+ """
+ if key_type not in constants.SSHK_ALL:
+ raise errors.SshUpdateError("Unsupported key type '%s'. Supported key types"
+ " are: %s." % (key_type, constants.SSHK_ALL))
+ (_, root_keyfiles) = \
+ GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
+ if not key_type in root_keyfiles.keys():
+ raise errors.SshUpdateError("No keyfile for key type '%s' available."
+ % key_type)
+
+ key_filenames = root_keyfiles[key_type]
+ if suffix:
+ key_filenames = [_ComputeKeyFilePathWithSuffix(key_filename, suffix)
+ for key_filename in key_filenames]
+
+ return key_filenames
+
+
+def GetSshPubKeyFilename(key_type, suffix=""):
+ """Get filename of the public SSH key of the given type.
+
+ @type key_type: string
+ @param key_type: type of SSH key, must be element of C{constants.SSHK_ALL}
+ @type suffix: string
+ @param suffix: optional suffix for the key filenames
+ @rtype: string
+ @returns: file name of the public key file
+
+ """
+ return GetSshKeyFilenames(key_type, suffix=suffix)[1]
+
+
+def _ComputeKeyFilePathWithSuffix(key_filepath, suffix):
+ """Converts the given key filename to a key filename with a suffix.
+
+ @type key_filepath: string
+ @param key_filepath: path of the key file
+ @type suffix: string
+ @param suffix: suffix to be appended to the basename of the file
+
+ """
+ path = os.path.dirname(key_filepath)
+ ext = os.path.splitext(os.path.basename(key_filepath))[1]
+ basename = os.path.splitext(os.path.basename(key_filepath))[0]
+ return os.path.join(path, basename + suffix + ext)
+
+
+def ReplaceSshKeys(src_key_type, dest_key_type,
+ src_key_suffix="", dest_key_suffix=""):
+ """Replaces an SSH key pair by another SSH key pair.
+
+ Note that both parts, the private and the public key, are replaced.
+
+ @type src_key_type: string
+ @param src_key_type: key type of key pair that is replacing the other
+ key pair
+ @type dest_key_type: string
+ @param dest_key_type: key type of the key pair that is being replaced
+ by the source key pair
+ @type src_key_suffix: string
+ @param src_key_suffix: optional suffix of the key files of the source
+ key pair
+ @type dest_key_suffix: string
+ @param dest_key_suffix: optional suffix of the keey files of the
+ destination key pair
+
+ """
+ (src_priv_filename, src_pub_filename) = GetSshKeyFilenames(
+ src_key_type, suffix=src_key_suffix)
+ (dest_priv_filename, dest_pub_filename) = GetSshKeyFilenames(
+ dest_key_type, suffix=dest_key_suffix)
+
+ if not (os.path.exists(src_priv_filename) and
+ os.path.exists(src_pub_filename)):
+ raise errors.SshUpdateError(
+ "At least one of the source key files is missing: %s",
+ ", ".join([src_priv_filename, src_pub_filename]))
+
+ for dest_file in [dest_priv_filename, dest_pub_filename]:
+ if os.path.exists(dest_file):
+ utils.CreateBackup(dest_file)
+ utils.RemoveFile(dest_file)
+
+ shutil.move(src_priv_filename, dest_priv_filename)
+ shutil.move(src_pub_filename, dest_pub_filename)
+
+
+def ReadLocalSshPubKeys(key_types, suffix=""):
+ """Reads the local root user SSH key.
+
+ @type key_types: list of string
+ @param key_types: types of SSH keys. Must be subset of constants.SSHK_ALL. If
+ 'None' or [], all available keys are returned.
+ @type suffix: string
+ @param suffix: optional suffix to be attached to key names when reading
+ them. Used for temporary key files.
+ @rtype: list of string
+ @return: list of public keys
+
+ """
+ fetch_key_types = []
+ if key_types:
+ fetch_key_types += key_types
+ else:
+ fetch_key_types = constants.SSHK_ALL
+
+ (_, root_keyfiles) = \
+ GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
+
+ result_keys = []
+ for (public_key_type, (_, public_key_file)) in root_keyfiles.items():
+
+ if public_key_type not in fetch_key_types:
+ continue
+
+ public_key_dir = os.path.dirname(public_key_file)
+ public_key_filename = ""
+ if suffix:
+ public_key_filename = \
+ os.path.splitext(os.path.basename(public_key_file))[0] \
+ + suffix + ".pub"
+ else:
+ public_key_filename = public_key_file
+ public_key_path = os.path.join(public_key_dir,
+ public_key_filename)
+
+ if not os.path.exists(public_key_path):
+ raise errors.SshUpdateError("Cannot find SSH public key of type '%s'."
+ % public_key_type)
+ else:
+ key = utils.ReadFile(public_key_path)
+ result_keys.append(key)
+
+ return result_keys
+
+
# Update gnt-cluster.rst when changing which combinations are valid.
KeyBitInfo = namedtuple('KeyBitInfo', ['default', 'validation_fn'])
SSH_KEY_VALID_BITS = {
self.assertEquals(b, ssh.DetermineKeyBits("rsa", b, None, None))
+class TestManageLocalSshPubKeys(testutils.GanetiTestCase):
+ """Test class for several methods handling local SSH keys.
+
+ Methods covered are:
+ - GetSshKeyFilenames
+ - GetSshPubKeyFilename
+ - ReplaceSshKeys
+ - ReadLocalSshPubKeys
+
+ These methods are covered in one test, because the preparations for
+ their tests is identical and thus can be reused.
+
+ """
+ VISIBILITY_PRIVATE = "private"
+ VISIBILITY_PUBLIC = "public"
+ VISIBILITIES = frozenset([VISIBILITY_PRIVATE, VISIBILITY_PUBLIC])
+
+ def _GenerateKey(self, key_id, visibility):
+ assert visibility in self.VISIBILITIES
+ return "I am the %s %s SSH key." % (visibility, key_id)
+
+ def _GetKeyPath(self, key_file_basename):
+ return os.path.join(self.tmpdir, key_file_basename)
+
+ def _SetUpKeys(self):
+ """Creates a fake SSH key for each type and with/without suffix."""
+ self._key_file_dict = {}
+ for key_type in constants.SSHK_ALL:
+ for suffix in ["", self._suffix]:
+ pub_key_filename = "id_%s%s.pub" % (key_type, suffix)
+ priv_key_filename = "id_%s%s" % (key_type, suffix)
+
+ pub_key_path = self._GetKeyPath(pub_key_filename)
+ priv_key_path = self._GetKeyPath(priv_key_filename)
+
+ utils.WriteFile(
+ priv_key_path,
+ data=self._GenerateKey(key_type + suffix, self.VISIBILITY_PRIVATE))
+
+ utils.WriteFile(
+ pub_key_path,
+ data=self._GenerateKey(key_type + suffix, self.VISIBILITY_PUBLIC))
+
+ # Fill key dict only for non-suffix keys
+ # (as this is how it will be in the code)
+ if not suffix:
+ self._key_file_dict[key_type] = \
+ (priv_key_path, pub_key_path)
+
+ def setUp(self):
+ testutils.GanetiTestCase.setUp(self)
+ self.tmpdir = tempfile.mkdtemp()
+ self._suffix = "_suffix"
+ self._SetUpKeys()
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdir)
+
+ @testutils.patch_object(ssh, "GetAllUserFiles")
+ def testReadAllPublicKeyFiles(self, mock_getalluserfiles):
+ mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+ keys = ssh.ReadLocalSshPubKeys([], suffix="")
+
+ self.assertEqual(len(constants.SSHK_ALL), len(keys))
+ for key_type in constants.SSHK_ALL:
+ self.assertTrue(
+ self._GenerateKey(key_type, self.VISIBILITY_PUBLIC) in keys)
+
+ @testutils.patch_object(ssh, "GetAllUserFiles")
+ def testReadOnePublicKeyFile(self, mock_getalluserfiles):
+ mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+ keys = ssh.ReadLocalSshPubKeys([constants.SSHK_DSA], suffix="")
+
+ self.assertEqual(1, len(keys))
+ self.assertEqual(
+ self._GenerateKey(constants.SSHK_DSA, self.VISIBILITY_PUBLIC),
+ keys[0])
+
+ @testutils.patch_object(ssh, "GetAllUserFiles")
+ def testReadPublicKeyFilesWithSuffix(self, mock_getalluserfiles):
+ key_types = [constants.SSHK_DSA, constants.SSHK_ECDSA]
+
+ mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+ keys = ssh.ReadLocalSshPubKeys(key_types, suffix=self._suffix)
+
+ self.assertEqual(2, len(keys))
+ for key_id in [key_type + self._suffix for key_type in key_types]:
+ self.assertTrue(
+ self._GenerateKey(key_id, self.VISIBILITY_PUBLIC) in keys)
+
+ @testutils.patch_object(ssh, "GetAllUserFiles")
+ def testGetSshKeyFilenames(self, mock_getalluserfiles):
+ mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+ priv, pub = ssh.GetSshKeyFilenames(constants.SSHK_DSA)
+
+ self.assertEqual("id_dsa", os.path.basename(priv))
+ self.assertNotEqual("id_dsa", priv)
+ self.assertEqual("id_dsa.pub", os.path.basename(pub))
+ self.assertNotEqual("id_dsa.pub", pub)
+
+ @testutils.patch_object(ssh, "GetAllUserFiles")
+ def testGetSshKeyFilenamesWithSuffix(self, mock_getalluserfiles):
+ mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+ priv, pub = ssh.GetSshKeyFilenames(constants.SSHK_RSA, suffix=self._suffix)
+
+ self.assertEqual("id_rsa_suffix", os.path.basename(priv))
+ self.assertNotEqual("id_rsa_suffix", priv)
+ self.assertEqual("id_rsa_suffix.pub", os.path.basename(pub))
+ self.assertNotEqual("id_rsa_suffix.pub", pub)
+
+ @testutils.patch_object(ssh, "GetAllUserFiles")
+ def testGetPubSshKeyFilename(self, mock_getalluserfiles):
+ mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+ pub = ssh.GetSshPubKeyFilename(constants.SSHK_DSA)
+ pub_suffix = ssh.GetSshPubKeyFilename(
+ constants.SSHK_DSA, suffix=self._suffix)
+
+ self.assertEqual("id_dsa.pub", os.path.basename(pub))
+ self.assertNotEqual("id_dsa.pub", pub)
+ self.assertEqual("id_dsa_suffix.pub", os.path.basename(pub_suffix))
+ self.assertNotEqual("id_dsa_suffix.pub", pub_suffix)
+
+ @testutils.patch_object(ssh, "GetAllUserFiles")
+ def testReplaceSshKeys(self, mock_getalluserfiles):
+ """Replace SSH keys without suffixes.
+
+ Note: usually it does not really make sense to replace the DSA key
+ by the RSA key. This is just to test the function without suffixes.
+
+ """
+ mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+ ssh.ReplaceSshKeys(constants.SSHK_RSA, constants.SSHK_DSA)
+
+ priv_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][0])
+ pub_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][1])
+
+ self.assertEqual("I am the private rsa SSH key.", priv_key)
+ self.assertEqual("I am the public rsa SSH key.", pub_key)
+
+ @testutils.patch_object(ssh, "GetAllUserFiles")
+ def testReplaceSshKeysBySuffixedKeys(self, mock_getalluserfiles):
+ """Replace SSH keys with keys from suffixed files.
+
+ Note: usually it does not really make sense to replace the DSA key
+ by the RSA key. This is just to test the function without suffixes.
+
+ """
+ mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+ ssh.ReplaceSshKeys(constants.SSHK_DSA, constants.SSHK_DSA,
+ src_key_suffix=self._suffix)
+
+ priv_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][0])
+ pub_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][1])
+
+ self.assertEqual("I am the private dsa_suffix SSH key.", priv_key)
+ self.assertEqual("I am the public dsa_suffix SSH key.", pub_key)
+
+
if __name__ == "__main__":
testutils.GanetiTestProgram()