1818use clap:: Parser ;
1919use datafusion:: error:: { DataFusionError , Result } ;
2020use datafusion:: execution:: context:: SessionConfig ;
21+ use datafusion:: execution:: memory_pool:: { FairSpillPool , GreedyMemoryPool } ;
2122use datafusion:: execution:: runtime_env:: { RuntimeConfig , RuntimeEnv } ;
2223use datafusion:: prelude:: SessionContext ;
2324use datafusion_cli:: catalog:: DynamicFileCatalog ;
@@ -27,11 +28,30 @@ use datafusion_cli::{
2728use mimalloc:: MiMalloc ;
2829use std:: env;
2930use std:: path:: Path ;
31+ use std:: str:: FromStr ;
3032use std:: sync:: Arc ;
3133
3234#[ global_allocator]
3335static GLOBAL : MiMalloc = MiMalloc ;
3436
/// Memory pool implementation that can be selected from the command line.
///
/// The value is parsed from its textual name through the `FromStr`
/// implementation for this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PoolType {
    /// Back the runtime with a `GreedyMemoryPool`.
    Greedy,
    /// Back the runtime with a `FairSpillPool`.
    Fair,
}
42+
43+ impl FromStr for PoolType {
44+ type Err = String ;
45+
46+ fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
47+ match s {
48+ "Greedy" | "greedy" => Ok ( PoolType :: Greedy ) ,
49+ "Fair" | "fair" => Ok ( PoolType :: Fair ) ,
50+ _ => Err ( format ! ( "Invalid memory pool type '{}'" , s) ) ,
51+ }
52+ }
53+ }
54+
3555#[ derive( Debug , Parser , PartialEq ) ]
3656#[ clap( author, version, about, long_about= None ) ]
3757struct Args {
@@ -59,6 +79,14 @@ struct Args {
5979 ) ]
6080 command : Vec < String > ,
6181
82+ #[ clap(
83+ short = 'm' ,
84+ long,
85+ help = "The memory pool limitation (e.g. '10g'), default to None (no limit)" ,
86+ validator( is_valid_memory_pool_size)
87+ ) ]
88+ memory_limit : Option < String > ,
89+
6290 #[ clap(
6391 short,
6492 long,
@@ -87,6 +115,12 @@ struct Args {
87115 help = "Reduce printing other than the results and work quietly"
88116 ) ]
89117 quiet : bool ,
118+
119+ #[ clap(
120+ long,
121+ help = "Specify the memory pool type 'greedy' or 'fair', default to 'greedy'"
122+ ) ]
123+ mem_pool_type : Option < PoolType > ,
90124}
91125
92126#[ tokio:: main]
@@ -109,7 +143,29 @@ pub async fn main() -> Result<()> {
109143 session_config = session_config. with_batch_size ( batch_size) ;
110144 } ;
111145
112- let runtime_env = create_runtime_env ( ) ?;
146+ let rn_config = RuntimeConfig :: new ( ) ;
147+ let rn_config =
148+ // set memory pool size
149+ if let Some ( memory_limit) = args. memory_limit {
150+ let memory_limit = extract_memory_pool_size ( & memory_limit) . unwrap ( ) ;
151+ // set memory pool type
152+ if let Some ( mem_pool_type) = args. mem_pool_type {
153+ match mem_pool_type {
154+ PoolType :: Greedy => rn_config
155+ . with_memory_pool ( Arc :: new ( GreedyMemoryPool :: new ( memory_limit) ) ) ,
156+ PoolType :: Fair => rn_config
157+ . with_memory_pool ( Arc :: new ( FairSpillPool :: new ( memory_limit) ) ) ,
158+ }
159+ } else {
160+ rn_config
161+ . with_memory_pool ( Arc :: new ( GreedyMemoryPool :: new ( memory_limit) ) )
162+ }
163+ } else {
164+ rn_config
165+ } ;
166+
167+ let runtime_env = create_runtime_env ( rn_config. clone ( ) ) ?;
168+
113169 let mut ctx =
114170 SessionContext :: with_config_rt ( session_config. clone ( ) , Arc :: new ( runtime_env) ) ;
115171 ctx. refresh_catalogs ( ) . await ?;
@@ -162,8 +218,7 @@ pub async fn main() -> Result<()> {
162218 Ok ( ( ) )
163219}
164220
165- fn create_runtime_env ( ) -> Result < RuntimeEnv > {
166- let rn_config = RuntimeConfig :: new ( ) ;
221+ fn create_runtime_env ( rn_config : RuntimeConfig ) -> Result < RuntimeEnv > {
167222 RuntimeEnv :: new ( rn_config)
168223}
169224
@@ -189,3 +244,34 @@ fn is_valid_batch_size(size: &str) -> Result<(), String> {
189244 _ => Err ( format ! ( "Invalid batch size '{}'" , size) ) ,
190245 }
191246}
247+
248+ fn is_valid_memory_pool_size ( size : & str ) -> Result < ( ) , String > {
249+ match extract_memory_pool_size ( size) {
250+ Ok ( _) => Ok ( ( ) ) ,
251+ Err ( e) => Err ( e) ,
252+ }
253+ }
254+
/// Parses a memory pool size such as `512`, `64m` or `10G` into a byte count.
///
/// A trailing `m`/`M` multiplies the number by 1024 * 1024 and a trailing
/// `g`/`G` by 1024 * 1024 * 1024; without a suffix the number is taken as-is.
///
/// # Errors
///
/// Returns a message quoting the original input when it is empty, when the
/// numeric part is not a valid positive `usize`, or when the resulting byte
/// count does not fit in `usize`.
fn extract_memory_pool_size(size: &str) -> Result<usize, String> {
    // Always report the text the user actually typed, suffix included.
    let invalid = || format!("Invalid memory pool size '{}'", size);

    let (digits, factor): (&str, usize) =
        if let Some(rest) = size.strip_suffix(|c: char| c == 'm' || c == 'M') {
            (rest, 1024 * 1024)
        } else if let Some(rest) = size.strip_suffix(|c: char| c == 'g' || c == 'G') {
            (rest, 1024 * 1024 * 1024)
        } else {
            // No unit suffix (or empty input, which fails to parse below).
            (size, 1)
        };

    match digits.parse::<usize>() {
        // Reject overflow instead of panicking (debug) or wrapping (release).
        Ok(count) if count > 0 => count.checked_mul(factor).ok_or_else(invalid),
        _ => Err(invalid()),
    }
}
0 commit comments