1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
   | class CifarData:     def __init__(self,filenames,need_shuffle):         all_data = []         all_labels = []         for filename in filenames:             data,labels = load_data(filename)             for item,label in zip(data,labels):                 if label in [0,1]:                     all_data.append(item)                     all_labels.append(label)         self._data = np.vstack(all_data)         self._labels = np.hstack(all_labels)         print(self._data.shape)         print(self._labels.shape)                  self._num_examples = self._data.shape[0]         self._need_shuffle = need_shuffle         self._indicator = 0         if self._need_shuffle:             self._shuffle_data()          def _shuffle_data(self):         # [0,1,2,3,4,5] - > [5,3,2,4,0,1]         p = np.random.permutation(self._num_examples)         self._data = self._data[p]         self._labels = self._labels[p]              def next_batch(self, batch_size):         """return batch_size examples as a batch."""         end_indicator = self._indicator + batch_size         if end_indicator > self._num_examples:             if self._need_shuffle:                 self._shuffle_data()                 self._indicator = 0                 end_indicator = batch_size             else:                 raise Exception("have no more examples")         if end_indicator > self._num_examples:             raise Exception("batch size is larger than all examples")         batch_data = self._data[self._indicator:end_indicator]         batch_labels = self._labels[self._indicator: end_indicator]
  train_filenames = [os.path.join(CIFAR_DIR,'data_batch_%d' % i ) for i in range(1,6)] test_filenames = [os.path.join(CIFAR_DIR,'test_batch')]
  train_data = CifarData(train_filenames,True)
   |