1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
class CifarData:
    """Load CIFAR batch files, keep a subset of labels, and serve mini-batches.

    Only examples whose label is in ``keep_labels`` are retained; the default
    ``(0, 1)`` reproduces the original two-class filter.

    Args:
        filenames: Batch file paths, each passed to ``load_data``.
        need_shuffle: If true, the data is shuffled once at construction and
            reshuffled every time ``next_batch`` runs out of unread examples.
            If false, ``next_batch`` raises once the data is used up.
        keep_labels: Labels to keep. Defaults to ``(0, 1)``.

    Raises:
        ValueError: If none of the files contains an example with a kept label.
    """

    def __init__(self, filenames, need_shuffle, keep_labels=(0, 1)):
        # NOTE(review): assumes load_data(filename) returns (rows, labels) of
        # equal length, one row per example — confirm against its definition.
        wanted = frozenset(keep_labels)
        all_data = []
        all_labels = []
        for filename in filenames:
            data, labels = load_data(filename)
            for item, label in zip(data, labels):
                if label in wanted:
                    all_data.append(item)
                    all_labels.append(label)
        if not all_data:
            # np.vstack([]) would fail with an opaque "need at least one array"
            # message; say what actually went wrong instead.
            raise ValueError(
                "no examples with labels %s were found" % (tuple(wanted),)
            )
        self._data = np.vstack(all_data)
        self._labels = np.hstack(all_labels)
        print(self._data.shape)
        print(self._labels.shape)

        self._num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0  # index of the next unread example
        if self._need_shuffle:
            self._shuffle_data()

    def _shuffle_data(self):
        """Reorder data and labels with one shared random permutation."""
        # e.g. [0, 1, 2, 3, 4, 5] -> [5, 3, 2, 4, 0, 1]; the same index array
        # is applied to both so every row stays paired with its label.
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples as ``(data, labels)``.

        If fewer than ``batch_size`` unread examples remain, a shuffling
        instance reshuffles and restarts from the beginning. The leftover
        tail of the previous pass is skipped. A non-shuffling instance
        raises instead.

        Raises:
            Exception: If the data is exhausted and shuffling is off, or if
                ``batch_size`` exceeds the total number of examples.
        """
        end_indicator = self._indicator + batch_size
        if end_indicator > self._num_examples:
            if self._need_shuffle:
                self._shuffle_data()
                self._indicator = 0
                end_indicator = batch_size
            else:
                raise Exception("have no more examples")
        if end_indicator > self._num_examples:
            raise Exception("batch size is larger than all examples")
        batch_data = self._data[self._indicator:end_indicator]
        batch_labels = self._labels[self._indicator:end_indicator]
        # Move the cursor past this batch; without this every call would
        # return the same slice.
        self._indicator = end_indicator
        return batch_data, batch_labels
# The five training batch files and the single test batch file.
train_filenames = [
    os.path.join(CIFAR_DIR, "data_batch_%d" % batch_no)
    for batch_no in range(1, 6)
]
test_filenames = [os.path.join(CIFAR_DIR, "test_batch")]

# Training set, shuffled up front and reshuffled whenever an epoch runs out.
train_data = CifarData(train_filenames, need_shuffle=True)
|