diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py index 49569405a587920da795820d48f8d968a8142cc7..39e20020f4f80a872428681d53e2ec9f1a3dd3f7 100644 --- a/mlair/data_handler/iterator.py +++ b/mlair/data_handler/iterator.py @@ -55,7 +55,7 @@ class DataCollection(Iterable): def add(self, element): self._collection.append(element) - self._mapping[str(element)] = len(self._collection) + self._mapping[str(element)] = len(self._collection) - 1 def _set_mapping(self): for i, e in enumerate(self._collection): diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py index ec224c06e358297972097f2cc75cea86f768784f..678f3d369d4b6424f94557d7d739fc65a995aacc 100644 --- a/test/test_data_handler/test_iterator.py +++ b/test/test_data_handler/test_iterator.py @@ -52,6 +52,13 @@ class TestDataCollection: for e, i in enumerate(data_collection): assert i == e + def test_add(self): + data_collection = DataCollection() + data_collection.add("first_element") + assert len(data_collection) == 1 + assert data_collection["first_element"] == "first_element" + assert data_collection[0] == "first_element" + class DummyData: