Skip to content
Snippets Groups Projects

Resolve "release v1.2.0"

Merged Ghost User requested to merge release_v1.2.0 into master
2 files
+ 75
35
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -37,7 +37,7 @@ class DataCollection(Iterable):
@@ -37,7 +37,7 @@ class DataCollection(Iterable):
if collection is None:
if collection is None:
collection = []
collection = []
assert isinstance(collection, list)
assert isinstance(collection, list)
self._collection = collection
self._collection = collection.copy()
self._mapping = {}
self._mapping = {}
self._set_mapping()
self._set_mapping()
self._name = name
self._name = name
@@ -119,9 +119,10 @@ class KerasIterator(keras.utils.Sequence):
@@ -119,9 +119,10 @@ class KerasIterator(keras.utils.Sequence):
def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]:
def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]:
"""Get batch according to batch size from data list."""
"""Get batch according to batch size from data list."""
return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list))
return list(map(lambda data: data[b * self.batch_size:(b + 1) * self.batch_size, ...], data_list))
def _permute_data(self, X, Y):
@staticmethod
 
def _permute_data(X, Y):
p = np.random.permutation(len(X[0])) # equiv to .shape[0]
p = np.random.permutation(len(X[0])) # equiv to .shape[0]
X = list(map(lambda x: x[p], X))
X = list(map(lambda x: x[p], X))
Y = list(map(lambda x: x[p], Y))
Y = list(map(lambda x: x[p], Y))
@@ -184,35 +185,3 @@ class KerasIterator(keras.utils.Sequence):
@@ -184,35 +185,3 @@ class KerasIterator(keras.utils.Sequence):
"""Randomly shuffle indexes if enabled."""
"""Randomly shuffle indexes if enabled."""
if self.shuffle is True:
if self.shuffle is True:
np.random.shuffle(self.indexes)
np.random.shuffle(self.indexes)
class DummyData: # pragma: no cover
def __init__(self, number_of_samples=np.random.randint(100, 150)):
self.number_of_samples = number_of_samples
def get_X(self):
X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables
X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2)) # samples, window, variables
X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2)) # samples, window, variables
return [X1, X2, X3]
def get_Y(self):
Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1)) # samples, window, variables
Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1)) # samples, window, variables
return [Y1, Y2]
if __name__ == "__main__":
collection = []
for _ in range(3):
collection.append(DummyData(50))
data_collection = DataCollection(collection=collection)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
iterator = KerasIterator(data_collection, 25, path, shuffle=True)
for data in data_collection:
print(data)
\ No newline at end of file
Loading