1+ import { AsyncLocalStorage } from 'node:async_hooks' ;
12import { promisify } from 'node:util' ;
23import mysql from 'mysql' ;
3- import type { PoolConfig , Pool } from 'mysql' ;
4- import type { PoolConnectionPromisify } from './types' ;
4+ import type { Pool } from 'mysql' ;
5+ import type { PoolConnectionPromisify , RDSClientOptions , TransactionContext , TransactionScope } from './types' ;
56import { Operator } from './operator' ;
67import { RDSConnection } from './connection' ;
78import { RDSTransaction } from './transaction' ;
@@ -24,17 +25,26 @@ export class RDSClient extends Operator {
2425 static get format ( ) { return mysql . format ; }
2526 static get raw ( ) { return mysql . raw ; }
2627
28+ static #DEFAULT_STORAGE_KEY = Symbol ( 'RDSClient#storage#default' ) ;
29+ static #TRANSACTION_NEST_COUNT = Symbol ( 'RDSClient#transaction#nestCount' ) ;
30+
2731 #pool: PoolPromisify ;
28- constructor ( options : PoolConfig ) {
32+ #connectionStorage: AsyncLocalStorage < TransactionContext > ;
33+ #connectionStorageKey: string | symbol ;
34+
35+ constructor ( options : RDSClientOptions ) {
2936 super ( ) ;
30- this . #pool = mysql . createPool ( options ) as unknown as PoolPromisify ;
37+ const { connectionStorage, connectionStorageKey, ...mysqlOptions } = options ;
38+ this . #pool = mysql . createPool ( mysqlOptions ) as unknown as PoolPromisify ;
3139 [
3240 'query' ,
3341 'getConnection' ,
3442 'end' ,
3543 ] . forEach ( method => {
3644 this . #pool[ method ] = promisify ( this . #pool[ method ] ) ;
3745 } ) ;
46+ this . #connectionStorage = connectionStorage || new AsyncLocalStorage ( ) ;
47+ this . #connectionStorageKey = connectionStorageKey || RDSClient . #DEFAULT_STORAGE_KEY;
3848 }
3949
4050 // impl Operator._query
@@ -92,6 +102,7 @@ export class RDSClient extends Operator {
92102 throw err ;
93103 }
94104 const tran = new RDSTransaction ( conn ) ;
105+ tran [ RDSClient . #TRANSACTION_NEST_COUNT] = 1 ;
95106 if ( this . beforeQueryHandlers . length > 0 ) {
96107 for ( const handler of this . beforeQueryHandlers ) {
97108 tran . beforeQuery ( handler ) ;
@@ -109,75 +120,137 @@ export class RDSClient extends Operator {
109120 * Auto commit or rollback on a transaction scope
110121 *
111122 * @param {Function } scope - scope with code
112- * @param {Object } [ctx] - transaction env context, like koa's ctx.
113- * To make sure only one active transaction on this ctx.
123+ * @param {Object } [ctx] - transaction context
114124 * @return {Object } - scope return result
115125 */
116- async beginTransactionScope ( scope : ( transaction : RDSTransaction ) => Promise < any > , ctx ?: any ) : Promise < any > {
117- ctx = ctx || { } ;
118- if ( ! ctx . _transactionConnection ) {
119- // Create only one conn if concurrent call `beginTransactionScope`
120- ctx . _transactionConnection = this . beginTransaction ( ) ;
121- }
122- const tran = await ctx . _transactionConnection ;
123-
124- if ( ! ctx . _transactionScopeCount ) {
125- ctx . _transactionScopeCount = 1 ;
126+ async #beginTransactionScope( scope : TransactionScope , ctx : TransactionContext ) : Promise < any > {
127+ let tran : RDSTransaction ;
128+ let shouldRelease = false ;
129+ if ( ! ctx [ this . #connectionStorageKey] ) {
130+ // there is no transaction in ctx, create a new one
131+ tran = await this . beginTransaction ( ) ;
132+ ctx [ this . #connectionStorageKey] = tran ;
133+ shouldRelease = true ;
126134 } else {
127- ctx . _transactionScopeCount ++ ;
135+ // use transaction in ctx
136+ tran = ctx [ this . #connectionStorageKey] ! ;
137+ tran [ RDSClient . #TRANSACTION_NEST_COUNT] ++ ;
128138 }
139+
140+ let result : any ;
141+ let scopeError : any ;
142+ let internalError : any ;
129143 try {
130- const result = await scope ( tran ) ;
131- ctx . _transactionScopeCount -- ;
132- if ( ctx . _transactionScopeCount === 0 ) {
133- ctx . _transactionConnection = null ;
134- await tran . commit ( ) ;
144+ result = await scope ( tran ) ;
145+ } catch ( err : any ) {
146+ scopeError = err ;
147+ }
148+ tran [ RDSClient . #TRANSACTION_NEST_COUNT] -- ;
149+
150+ // null connection means the nested scope has been rollback, we can do nothing here
151+ if ( tran . conn ) {
152+ try {
153+ // execution error, should rollback
154+ if ( scopeError ) {
155+ await tran . rollback ( ) ;
156+ } else if ( tran [ RDSClient . #TRANSACTION_NEST_COUNT] < 1 ) {
157+ // nestedCount smaller than 1 means all the nested scopes have executed successfully
158+ await tran . commit ( ) ;
159+ }
160+ } catch ( err ) {
161+ internalError = err ;
135162 }
136- return result ;
137- } catch ( err ) {
138- if ( ctx . _transactionConnection ) {
139- ctx . _transactionConnection = null ;
140- await tran . rollback ( ) ;
163+ }
164+
165+ // remove transaction in ctx
166+ if ( shouldRelease && tran [ RDSClient . #TRANSACTION_NEST_COUNT] < 1 ) {
167+ ctx [ this . #connectionStorageKey] = null ;
168+ }
169+
170+ if ( internalError ) {
171+ if ( scopeError ) {
172+ internalError . cause = scopeError ;
141173 }
142- throw err ;
174+ throw internalError ;
175+ }
176+ if ( scopeError ) {
177+ throw scopeError ;
143178 }
179+ return result ;
180+ }
181+
182+ /**
183+ * Auto commit or rollback on a transaction scope
184+ *
185+ * @param scope - scope with code
186+ * @return {Object } - scope return result
187+ */
188+ async beginTransactionScope ( scope : TransactionScope ) {
189+ let ctx = this . #connectionStorage. getStore ( ) ;
190+ if ( ctx ) {
191+ return await this . #beginTransactionScope( scope , ctx ) ;
192+ }
193+ ctx = { } ;
194+ return await this . #connectionStorage. run ( ctx , async ( ) => {
195+ return await this . #beginTransactionScope( scope , ctx ! ) ;
196+ } ) ;
144197 }
145198
146199 /**
147200 * doomed to be rollbacked after transaction scope
148201 * useful on writing tests which are related with database
149202 *
150- * @param {Function } scope - scope with code
151- * @param {Object } [ctx] - transaction env context, like koa's ctx.
152- * To make sure only one active transaction on this ctx.
203+ * @param scope - scope with code
204+ * @param ctx - transaction context
153205 * @return {Object } - scope return result
154206 */
155- async beginDoomedTransactionScope ( scope : ( transaction : RDSTransaction ) => Promise < any > , ctx ?: any ) : Promise < any > {
156- ctx = ctx || { } ;
157- if ( ! ctx . _transactionConnection ) {
158- ctx . _transactionConnection = await this . beginTransaction ( ) ;
159- ctx . _transactionScopeCount = 1 ;
207+ async #beginDoomedTransactionScope( scope : TransactionScope , ctx : TransactionContext ) : Promise < any > {
208+ let tran : RDSTransaction ;
209+ if ( ! ctx [ this . #connectionStorageKey] ) {
210+ // there is no transaction in ctx, create a new one
211+ tran = await this . beginTransaction ( ) ;
212+ ctx [ this . #connectionStorageKey] = tran ;
160213 } else {
161- ctx . _transactionScopeCount ++ ;
214+ // use transaction in ctx
215+ tran = ctx [ this . #connectionStorageKey] ! ;
216+ tran [ RDSClient . #TRANSACTION_NEST_COUNT] ++ ;
162217 }
163- const tran = ctx . _transactionConnection ;
218+
164219 try {
165220 const result = await scope ( tran ) ;
166- ctx . _transactionScopeCount -- ;
167- if ( ctx . _transactionScopeCount === 0 ) {
168- ctx . _transactionConnection = null ;
221+ tran [ RDSClient . #TRANSACTION_NEST_COUNT] -- ;
222+ if ( tran [ RDSClient . #TRANSACTION_NEST_COUNT] === 0 ) {
223+ ctx [ this . #connectionStorageKey] = null ;
224+ await tran . rollback ( ) ;
169225 }
170226 return result ;
171227 } catch ( err ) {
172- if ( ctx . _transactionConnection ) {
173- ctx . _transactionConnection = null ;
228+ if ( ctx [ this . #connectionStorageKey] ) {
229+ ctx [ this . #connectionStorageKey] = null ;
230+ await tran . rollback ( ) ;
174231 }
175232 throw err ;
176- } finally {
177- await tran . rollback ( ) ;
178233 }
179234 }
180235
236+ /**
237+ * doomed to be rollbacked after transaction scope
238+ * useful on writing tests which are related with database
239+ *
240+ * @param scope - scope with code
241+ * @return {Object } - scope return result
242+ */
243+ async beginDoomedTransactionScope ( scope : TransactionScope ) : Promise < any > {
244+ let ctx = this . #connectionStorage. getStore ( ) ;
245+ if ( ctx ) {
246+ return await this . #beginDoomedTransactionScope( scope , ctx ) ;
247+ }
248+ ctx = { } ;
249+ return await this . #connectionStorage. run ( ctx , async ( ) => {
250+ return await this . #beginDoomedTransactionScope( scope , ctx ! ) ;
251+ } ) ;
252+ }
253+
181254 async end ( ) {
182255 await this . #pool. end ( ) ;
183256 }
0 commit comments