Commit c5e803db authored by mova's avatar mova
Browse files

Add the chunk type; fix skipping and shuffling for the preprocessed sequence

parent 41b4f9ab
......@@ -6,6 +6,7 @@ loaded depending on `conf.loader.name`.
import importlib
import os
from pathlib import Path
from typing import List, Tuple
import numpy as np
import torch
......@@ -24,6 +25,11 @@ process_seq = sel_seq.process_seq
files = sel_seq.files
len_dict = sel_seq.len_dict
chunksize = conf.loader.chunksize
batch_size = conf.loader.batch_size
ChunkType = List[Tuple[Path, int, int]]
class QueuedDataLoader:
"""
......@@ -33,41 +39,7 @@ must queue an epoch via `queue_epoch()` and iterate over the instance of the cla
"""
def __init__(self):
chunksize = conf.loader.chunksize
batch_size = conf.loader.batch_size
chunk_coords = [[]]
ifile = 0
ielement = 0
current_chunck_elements = 0
while ifile < len(files):
elem_left_in_cur_file = len_dict[files[ifile]] - ielement
elem_to_add = chunksize - current_chunck_elements
if elem_left_in_cur_file > elem_to_add:
chunk_coords[-1].append(
(files[ifile], ielement, ielement + elem_to_add)
)
ielement += elem_to_add
current_chunck_elements += elem_to_add
else:
chunk_coords[-1].append(
(files[ifile], ielement, ielement + elem_left_in_cur_file)
)
ielement = 0
current_chunck_elements += elem_left_in_cur_file
ifile += 1
if current_chunck_elements == chunksize:
current_chunck_elements = 0
chunk_coords.append([])
# remove the last, uneven chunk
chunk_coords = list(
filter(
lambda chunk: sum([part[2] - part[1] for part in chunk])
== chunksize,
chunk_coords,
)
)
chunk_coords = self._compute_chucks()
np.random.shuffle(chunk_coords)
......@@ -124,6 +96,41 @@ must queue an epoch via `queue_epoch()` and iterate over the instance of the cla
self.qfseq = qf.Sequence(*preprocessed_seq())
def _compute_chucks(self) -> List[ChunkType]:
    """Partition the elements of all `files` into chunks of `chunksize` elements.

    The files are walked in order. Every chunk is a list of
    ``(file, start, stop)`` parts, so a chunk may span several files and a
    single file may be spread over several chunks. The number of elements
    of a file is looked up in `len_dict` under ``str(file)``.

    Returns:
        The chunks whose parts add up to exactly `chunksize` elements; a
        trailing chunk that could not be filled completely is dropped.
    """
    chunks: List[ChunkType] = [[]]
    file_idx = 0
    offset = 0  # position of the next unassigned element in the current file
    filled = 0  # number of elements already placed in the chunk being built

    while file_idx < len(files):
        current_file = files[file_idx]
        remaining = len_dict[str(current_file)] - offset
        capacity = chunksize - filled

        if remaining <= capacity:
            # The rest of this file fits into the open chunk: take all of it
            # and continue with the next file.
            chunks[-1].append((current_file, offset, offset + remaining))
            filled += remaining
            offset = 0
            file_idx += 1
        else:
            # The file holds more than the open chunk can take: fill the
            # chunk and stay in the same file.
            chunks[-1].append((current_file, offset, offset + capacity))
            offset += capacity
            filled += capacity

        if filled == chunksize:
            # The open chunk is complete, start a new one.
            filled = 0
            chunks.append([])

    # Keep only the complete chunks; this removes the last, uneven chunk.
    return [
        chunk
        for chunk in chunks
        if sum(stop - start for _, start, stop in chunk) == chunksize
    ]
@property
def validation_batches(self) -> DataSetType:
if not hasattr(self, "_validation_batches"):
......@@ -165,6 +172,7 @@ must queue an epoch via `queue_epoch()` and iterate over the instance of the cla
# Only queue the chunks that are still left
epoch_chunks = self.training_chunks[n_skip_chunks:]
self.qfseq.queue_iterable(epoch_chunks)
np.random.shuffle(self.training_chunks)
# Now calculate the number of batches that we still have to skip,
# because a chunk may be multiple batches and we need to skip
......@@ -177,11 +185,9 @@ must queue an epoch via `queue_epoch()` and iterate over the instance of the cla
f"""\
Skipping {n_skip_events} events => {n_skip_chunks} chunks and {n_skip_batches} batches."""
)
if n_skip_batches != 0:
for _ in range(n_skip_batches):
_ = next(self.qfseq)
logger.info(f"Skipped {n_skip_batches} batches.")
np.random.shuffle(self.training_chunks)
for _ in range(n_skip_batches):
_ = next(self.qfseq)
# Load the preprocessed batches
else:
......@@ -198,6 +204,7 @@ Skipping {n_skip_events} events => {n_skip_chunks} chunks and {n_skip_batches} b
)
epoch_files = self.preprocessed_files[n_skip_files:]
self.qfseq.queue_iterable(epoch_files)
np.random.shuffle(self.preprocessed_files)
# Now calculate the number of batches that we still have to skip
n_skip_batches = (
......@@ -208,11 +215,12 @@ Skipping {n_skip_events} events => {n_skip_chunks} chunks and {n_skip_batches} b
f"""\
Skipping {n_skip_events} events => {n_skip_files} files and {n_skip_batches} batches."""
)
if n_skip_batches != 0:
for _ in range(n_skip_batches):
_ = next(self.qfseq)
logger.info(f"Skipped {n_skip_batches} batches.")
np.random.shuffle(self.preprocessed_files)
# Skip the correct number of batches.
for ibatch in range(n_skip_batches):
logger.info(f"Skipping batch({ibatch}).")
_ = next(self.qfseq)
logger.info(f"Skipped batch({ibatch}).")
def __iter__(self) -> qfseq:
"""Return an iterator over `self.qfseq`, so the loader instance itself can be iterated."""
return iter(self.qfseq)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment