11from .. import DataProviderConverter
22import random
33
4- __all__ = ['DataProvider' , 'NaiveDataProvider' ]
4+ __all__ = ['DataProvider' , 'NaiveMemPooledDataProvider' , ' NaiveDataProvider' ]
55
66
77class DataProvider (object ):
88 __slots__ = [
9- '__init__' , 'reset' , 'next' , '__provider__ ' , '__converter__' ,
9+ '__init__' , 'reset' , 'next' , '__method__ ' , '__converter__' ,
1010 '__batch_size__' , '__should_shuffle__'
1111 ]
1212
13- def __init__ (self , provider , input_types , batch_size , should_shuffle = True ):
14- self .__provider__ = provider
13+ def __init__ (self , method , input_types , batch_size , should_shuffle = True ):
14+ self .__method__ = method
1515 self .__converter__ = DataProviderConverter (input_types )
1616 self .__batch_size__ = batch_size
17- if self .__provider__ .should_shuffle is None :
18- self .__provider__ .should_shuffle = should_shuffle
17+ self .__should_shuffle__ = should_shuffle
1918
2019 def reset (self ):
2120 raise NotImplemented ()
2221
2322 def next (self ):
2423 raise NotImplemented ()
2524
26- def __should_shuffle__ (self ):
27- return self .__provider__ .should_shuffle
28-
2925
30- class NaiveDataProvider (DataProvider ):
31- def __init__ (self , provider , input_types , batch_size , should_shuffle = True ):
32- super (NaiveDataProvider , self ).__init__ (
33- provider = provider ,
26+ class NaiveMemPooledDataProvider (DataProvider ):
27+ def __init__ (self , method , input_types , batch_size , should_shuffle ):
28+ super (NaiveMemPooledDataProvider , self ).__init__ (
29+ method = method ,
3430 input_types = input_types ,
3531 batch_size = batch_size ,
3632 should_shuffle = should_shuffle )
3733 self .__pool__ = []
3834 self .__idx__ = 0
3935
4036 def reset (self ):
41- def __to_pool__ ():
42- for filename in self .__provider__ .file_list :
43- for item in self .__provider__ .generator (self .__provider__ ,
44- filename ):
45- yield item
46-
47- self .__pool__ = list (__to_pool__ ())
48- if self .__should_shuffle__ ():
37+ self .__pool__ = list (self .__method__ ())
38+ if self .__should_shuffle__ :
4939 random .shuffle (self .__pool__ )
5040
5141 self .__idx__ = 0
@@ -58,3 +48,17 @@ def next(self):
5848 return self .__converter__ (self .__pool__ [begin :end ]), end - begin
5949 else :
6050 raise StopIteration
51+
52+
53+ class NaiveDataProvider (NaiveMemPooledDataProvider ):
54+ def __init__ (self , provider , input_types , batch_size , should_shuffle = True ):
55+ def __to_pool__ ():
56+ for filename in provider .file_list :
57+ for item in provider .generator (provider , filename ):
58+ yield item
59+
60+ super (NaiveDataProvider , self ).__init__ (
61+ method = __to_pool__ ,
62+ input_types = input_types ,
63+ batch_size = batch_size ,
64+ should_shuffle = should_shuffle )
0 commit comments