diff --git a/bagit.py b/bagit.py index 69ea9ab..28c006d 100755 --- a/bagit.py +++ b/bagit.py @@ -138,13 +138,27 @@ def find_locale_dir(): UNICODE_BYTE_ORDER_MARK = "\ufeff" +def is_bag(bag_dir): + """ + Return a boolean whether the given directory is already a bag. + """ + try: + Bag(bag_dir) + return True + except BagError: + return False + + def make_bag( - bag_dir, bag_info=None, processes=1, checksums=None, checksum=None, encoding="utf-8" + bag_dir, bag_info=None, processes=1, checksums=None, checksum=None, encoding="utf-8", allow_nested_bag=False ): """ Convert a given directory into a bag. You can pass in arbitrary key/value pairs to put into the bag-info.txt metadata file as the bag_info dictionary. + + By default creating a bag of directory that is already a bag will raise an error. + Set allow_nested_bag to allow creation of nested bags. """ if checksum is not None: @@ -167,6 +181,13 @@ def make_bag( _("Bagging a parent of the current directory is not supported") ) + if not allow_nested_bag: + if is_bag(bag_dir): + raise RuntimeError( + _(f"The directory '{bag_dir}' is already a bag. " + "Use allow_nested_bag=True to allow creation of a nested bag.") + ) + LOGGER.info(_("Creating bag for directory %s"), bag_dir) if not os.path.isdir(bag_dir): diff --git a/test.py b/test.py index 735bd73..0ed122d 100644 --- a/test.py +++ b/test.py @@ -97,6 +97,31 @@ def test_make_bag_md5_sha1_sha256_manifest(self): # check valid with three manifests self.assertTrue(self.validate(bag, fast=True)) + def test_is_bag_false_initially(self): + self.assertFalse(bagit.is_bag(self.tmpdir)) + + def test_is_bag_true_after_make_bag(self): + bagit.make_bag(self.tmpdir) + self.assertTrue(bagit.is_bag(self.tmpdir)) + + def test_make_nested_bag_without_flag(self): + bagit.make_bag(self.tmpdir) + + with self.assertRaises(RuntimeError) as ctx: + tmpdir = self.tmpdir + bagit.make_bag(tmpdir, allow_nested_bag=False) + + expected_msg = (f"The directory '{tmpdir}' is already a bag. " + "Use allow_nested_bag=True to allow creation of a nested bag.") + self.assertEqual(str(ctx.exception), expected_msg) + + def test_make_nested_bag_with_flag(self): + bagit.make_bag(self.tmpdir) + + bag = bagit.make_bag(self.tmpdir, allow_nested_bag=True) + + self.assertIsInstance(bag, bagit.Bag) + def test_validate_flipped_bit(self): bag = bagit.make_bag(self.tmpdir) readme = j(self.tmpdir, "data", "README")