SSH utility functions for key manipulation
[ganeti-github.git] / test / py / ganeti.ssh_unittest.py
index 3510992..661245b 100755 (executable)
@@ -521,5 +521,171 @@ class TestDetermineKeyBits(testutils.GanetiTestCase):
       self.assertEquals(b, ssh.DetermineKeyBits("rsa", b, None, None))
 
 
+class TestManageLocalSshPubKeys(testutils.GanetiTestCase):
+  """Test class for several methods handling local SSH keys.
+
+  Methods covered are:
+  - GetSshKeyFilenames
+  - GetSshPubKeyFilename
+  - ReplaceSshKeys
+  - ReadLocalSshPubKeys
+
+  These methods are covered in one test, because the preparations for
+  their tests is identical and thus can be reused.
+
+  """
+  VISIBILITY_PRIVATE = "private"
+  VISIBILITY_PUBLIC = "public"
+  VISIBILITIES = frozenset([VISIBILITY_PRIVATE, VISIBILITY_PUBLIC])
+
+  def _GenerateKey(self, key_id, visibility):
+    assert visibility in self.VISIBILITIES
+    return "I am the %s %s SSH key." % (visibility, key_id)
+
+  def _GetKeyPath(self, key_file_basename):
+     return os.path.join(self.tmpdir, key_file_basename)
+
+  def _SetUpKeys(self):
+    """Creates a fake SSH key for each type and with/without suffix."""
+    self._key_file_dict = {}
+    for key_type in constants.SSHK_ALL:
+      for suffix in ["", self._suffix]:
+        pub_key_filename = "id_%s%s.pub" % (key_type, suffix)
+        priv_key_filename = "id_%s%s" % (key_type, suffix)
+
+        pub_key_path = self._GetKeyPath(pub_key_filename)
+        priv_key_path = self._GetKeyPath(priv_key_filename)
+
+        utils.WriteFile(
+            priv_key_path,
+            data=self._GenerateKey(key_type + suffix, self.VISIBILITY_PRIVATE))
+
+        utils.WriteFile(
+            pub_key_path,
+            data=self._GenerateKey(key_type + suffix, self.VISIBILITY_PUBLIC))
+
+        # Fill key dict only for non-suffix keys
+        # (as this is how it will be in the code)
+        if not suffix:
+          self._key_file_dict[key_type] = \
+            (priv_key_path, pub_key_path)
+
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+    self._suffix = "_suffix"
+    self._SetUpKeys()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReadAllPublicKeyFiles(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    keys = ssh.ReadLocalSshPubKeys([], suffix="")
+
+    self.assertEqual(len(constants.SSHK_ALL), len(keys))
+    for key_type in constants.SSHK_ALL:
+      self.assertTrue(
+          self._GenerateKey(key_type, self.VISIBILITY_PUBLIC) in keys)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReadOnePublicKeyFile(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    keys = ssh.ReadLocalSshPubKeys([constants.SSHK_DSA], suffix="")
+
+    self.assertEqual(1, len(keys))
+    self.assertEqual(
+        self._GenerateKey(constants.SSHK_DSA, self.VISIBILITY_PUBLIC),
+        keys[0])
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReadPublicKeyFilesWithSuffix(self, mock_getalluserfiles):
+    key_types = [constants.SSHK_DSA, constants.SSHK_ECDSA]
+
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    keys = ssh.ReadLocalSshPubKeys(key_types, suffix=self._suffix)
+
+    self.assertEqual(2, len(keys))
+    for key_id in [key_type + self._suffix for key_type in key_types]:
+      self.assertTrue(
+          self._GenerateKey(key_id, self.VISIBILITY_PUBLIC) in keys)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testGetSshKeyFilenames(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    priv, pub = ssh.GetSshKeyFilenames(constants.SSHK_DSA)
+
+    self.assertEqual("id_dsa", os.path.basename(priv))
+    self.assertNotEqual("id_dsa", priv)
+    self.assertEqual("id_dsa.pub", os.path.basename(pub))
+    self.assertNotEqual("id_dsa.pub", pub)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testGetSshKeyFilenamesWithSuffix(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    priv, pub = ssh.GetSshKeyFilenames(constants.SSHK_RSA, suffix=self._suffix)
+
+    self.assertEqual("id_rsa_suffix", os.path.basename(priv))
+    self.assertNotEqual("id_rsa_suffix", priv)
+    self.assertEqual("id_rsa_suffix.pub", os.path.basename(pub))
+    self.assertNotEqual("id_rsa_suffix.pub", pub)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testGetPubSshKeyFilename(self, mock_getalluserfiles):
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    pub = ssh.GetSshPubKeyFilename(constants.SSHK_DSA)
+    pub_suffix = ssh.GetSshPubKeyFilename(
+        constants.SSHK_DSA, suffix=self._suffix)
+
+    self.assertEqual("id_dsa.pub", os.path.basename(pub))
+    self.assertNotEqual("id_dsa.pub", pub)
+    self.assertEqual("id_dsa_suffix.pub", os.path.basename(pub_suffix))
+    self.assertNotEqual("id_dsa_suffix.pub", pub_suffix)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReplaceSshKeys(self, mock_getalluserfiles):
+    """Replace SSH keys without suffixes.
+
+    Note: usually it does not really make sense to replace the DSA key
+    by the RSA key. This is just to test the function without suffixes.
+
+    """
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    ssh.ReplaceSshKeys(constants.SSHK_RSA, constants.SSHK_DSA)
+
+    priv_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][0])
+    pub_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][1])
+
+    self.assertEqual("I am the private rsa SSH key.", priv_key)
+    self.assertEqual("I am the public rsa SSH key.", pub_key)
+
+  @testutils.patch_object(ssh, "GetAllUserFiles")
+  def testReplaceSshKeysBySuffixedKeys(self, mock_getalluserfiles):
+    """Replace SSH keys with keys from suffixed files.
+
+    Note: usually it does not really make sense to replace the DSA key
+    by the RSA key. This is just to test the function without suffixes.
+
+    """
+    mock_getalluserfiles.return_value = (None, self._key_file_dict)
+
+    ssh.ReplaceSshKeys(constants.SSHK_DSA, constants.SSHK_DSA,
+                       src_key_suffix=self._suffix)
+
+    priv_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][0])
+    pub_key = utils.ReadFile(self._key_file_dict[constants.SSHK_DSA][1])
+
+    self.assertEqual("I am the private dsa_suffix SSH key.", priv_key)
+    self.assertEqual("I am the public dsa_suffix SSH key.", pub_key)
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()