@@ -192,6 +192,48 @@ static double evaluateOperation(Operation* op, double* args) {
192192 return op->evaluate (args, dummyVariables);
193193}
194194
195+ void CompiledExpression::findPowerGroups (vector<vector<int > >& groups, vector<vector<int > >& groupPowers, vector<int >& stepGroup) {
196+ // Identify every step that raises an argument to an integer power.
197+
198+ vector<int > stepPower (operation.size (), 0 );
199+ vector<int > stepArg (operation.size (), -1 );
200+ for (int step = 0 ; step < operation.size (); step++) {
201+ Operation& op = *operation[step];
202+ int power = 0 ;
203+ if (op.getId () == Operation::SQUARE)
204+ power = 2 ;
205+ else if (op.getId () == Operation::CUBE)
206+ power = 3 ;
207+ else if (op.getId () == Operation::POWER_CONSTANT) {
208+ double realPower = dynamic_cast <const Operation::PowerConstant*>(&op)->getValue ();
209+ if (realPower == (int ) realPower)
210+ power = (int ) realPower;
211+ }
212+ if (power != 0 ) {
213+ stepPower[step] = power;
214+ stepArg[step] = arguments[step][0 ];
215+ }
216+ }
217+
218+ // Find groups that operate on the same argument and whose powers have the same sign.
219+
220+ stepGroup.resize (operation.size (), -1 );
221+ for (int i = 0 ; i < operation.size (); i++) {
222+ if (stepGroup[i] != -1 )
223+ continue ;
224+ vector<int > group, power;
225+ for (int j = i; j < operation.size (); j++) {
226+ if (stepArg[i] == stepArg[j] && stepPower[i]*stepPower[j] > 0 ) {
227+ stepGroup[j] = groups.size ();
228+ group.push_back (j);
229+ power.push_back (stepPower[j]);
230+ }
231+ }
232+ groups.push_back (group);
233+ groupPowers.push_back (power);
234+ }
235+ }
236+
195237#if defined(__ARM__) || defined(__ARM64__)
196238void CompiledExpression::generateJitCode () {
197239 CodeHolder code;
@@ -203,6 +245,9 @@ void CompiledExpression::generateJitCode() {
203245 workspaceVar[i] = c.newVecD ();
204246 arm::Gp argsPointer = c.newIntPtr ();
205247 c.mov (argsPointer, imm (&argValues[0 ]));
248+ vector<vector<int > > groups, groupPowers;
249+ vector<int > stepGroup;
250+ findPowerGroups (groups, groupPowers, stepGroup);
206251
207252 // Load the arguments into variables.
208253
@@ -233,6 +278,12 @@ void CompiledExpression::generateJitCode() {
233278 value = 1.0 ;
234279 else if (op.getId () == Operation::DELTA)
235280 value = 1.0 ;
281+ else if (op.getId () == Operation::POWER_CONSTANT) {
282+ if (stepGroup[step] == -1 )
283+ value = dynamic_cast <Operation::PowerConstant&>(op).getValue ();
284+ else
285+ value = 1.0 ;
286+ }
236287 else
237288 continue ;
238289
@@ -260,10 +311,54 @@ void CompiledExpression::generateJitCode() {
260311 c.ldr (constantVar[i], arm::ptr (constantsPointer, 8 *i));
261312 }
262313 }
263-
314+
264315 // Evaluate the operations.
265-
316+
317+ vector<bool > hasComputedPower (operation.size (), false );
266318 for (int step = 0 ; step < (int ) operation.size (); step++) {
319+ if (hasComputedPower[step])
320+ continue ;
321+
322+ // When one or more steps involve raising the same argument to multiple integer
323+ // powers, we can compute them all together for efficiency.
324+
325+ if (stepGroup[step] != -1 ) {
326+ vector<int >& group = groups[stepGroup[step]];
327+ vector<int >& powers = groupPowers[stepGroup[step]];
328+ arm::Vec multiplier = c.newVecD ();
329+ if (powers[0 ] > 0 )
330+ c.fmov (multiplier, workspaceVar[arguments[step][0 ]]);
331+ else {
332+ c.fdiv (multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0 ]]);
333+ for (int i = 0 ; i < powers.size (); i++)
334+ powers[i] = -powers[i];
335+ }
336+ vector<bool > hasAssigned (group.size (), false );
337+ bool done = false ;
338+ while (!done) {
339+ done = true ;
340+ for (int i = 0 ; i < group.size (); i++) {
341+ if (powers[i]%2 == 1 ) {
342+ if (!hasAssigned[i])
343+ c.fmov (workspaceVar[target[group[i]]], multiplier);
344+ else
345+ c.fmul (workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
346+ hasAssigned[i] = true ;
347+ }
348+ powers[i] >>= 1 ;
349+ if (powers[i] != 0 )
350+ done = false ;
351+ }
352+ if (!done)
353+ c.fmul (multiplier, multiplier, multiplier);
354+ }
355+ for (int step : group)
356+ hasComputedPower[step] = true ;
357+ continue ;
358+ }
359+
360+ // Evaluate the step.
361+
267362 Operation& op = *operation[step];
268363 vector<int > args = arguments[step];
269364 if (args.size () == 1 ) {
@@ -360,6 +455,9 @@ void CompiledExpression::generateJitCode() {
360455 case Operation::MULTIPLY_CONSTANT:
361456 c.fmul (workspaceVar[target[step]], workspaceVar[args[0 ]], constantVar[operationConstantIndex[step]]);
362457 break ;
458+ case Operation::POWER_CONSTANT:
459+ generateTwoArgCall (c, workspaceVar[target[step]], workspaceVar[args[0 ]], constantVar[operationConstantIndex[step]], pow);
460+ break ;
363461 case Operation::ABS:
364462 c.fabs (workspaceVar[target[step]], workspaceVar[args[0 ]]);
365463 break ;
@@ -418,7 +516,10 @@ void CompiledExpression::generateJitCode() {
418516 workspaceVar[i] = c.newXmmSd ();
419517 x86::Gp argsPointer = c.newIntPtr ();
420518 c.mov (argsPointer, imm (&argValues[0 ]));
421-
519+ vector<vector<int > > groups, groupPowers;
520+ vector<int > stepGroup;
521+ findPowerGroups (groups, groupPowers, stepGroup);
522+
422523 // Load the arguments into variables.
423524
424525 for (set<string>::const_iterator iter = variableNames.begin (); iter != variableNames.end (); ++iter) {
@@ -448,6 +549,12 @@ void CompiledExpression::generateJitCode() {
448549 value = 1.0 ;
449550 else if (op.getId () == Operation::DELTA)
450551 value = 1.0 ;
552+ else if (op.getId () == Operation::POWER_CONSTANT) {
553+ if (stepGroup[step] == -1 )
554+ value = dynamic_cast <Operation::PowerConstant&>(op).getValue ();
555+ else
556+ value = 1.0 ;
557+ }
451558 else
452559 continue ;
453560
@@ -478,7 +585,52 @@ void CompiledExpression::generateJitCode() {
478585
479586 // Evaluate the operations.
480587
588+ vector<bool > hasComputedPower (operation.size (), false );
481589 for (int step = 0 ; step < (int ) operation.size (); step++) {
590+ if (hasComputedPower[step])
591+ continue ;
592+
593+ // When one or more steps involve raising the same argument to multiple integer
594+ // powers, we can compute them all together for efficiency.
595+
596+ if (stepGroup[step] != -1 ) {
597+ vector<int >& group = groups[stepGroup[step]];
598+ vector<int >& powers = groupPowers[stepGroup[step]];
599+ x86::Xmm multiplier = c.newXmmSd ();
600+ if (powers[0 ] > 0 )
601+ c.movsd (multiplier, workspaceVar[arguments[step][0 ]]);
602+ else {
603+ c.movsd (multiplier, constantVar[operationConstantIndex[step]]);
604+ c.divsd (multiplier, workspaceVar[arguments[step][0 ]]);
605+ for (int i = 0 ; i < powers.size (); i++)
606+ powers[i] = -powers[i];
607+ }
608+ vector<bool > hasAssigned (group.size (), false );
609+ bool done = false ;
610+ while (!done) {
611+ done = true ;
612+ for (int i = 0 ; i < group.size (); i++) {
613+ if (powers[i]%2 == 1 ) {
614+ if (!hasAssigned[i])
615+ c.movsd (workspaceVar[target[group[i]]], multiplier);
616+ else
617+ c.mulsd (workspaceVar[target[group[i]]], multiplier);
618+ hasAssigned[i] = true ;
619+ }
620+ powers[i] >>= 1 ;
621+ if (powers[i] != 0 )
622+ done = false ;
623+ }
624+ if (!done)
625+ c.mulsd (multiplier, multiplier);
626+ }
627+ for (int step : group)
628+ hasComputedPower[step] = true ;
629+ continue ;
630+ }
631+
632+ // Evaluate the step.
633+
482634 Operation& op = *operation[step];
483635 vector<int > args = arguments[step];
484636 if (args.size () == 1 ) {
@@ -587,6 +739,9 @@ void CompiledExpression::generateJitCode() {
587739 c.movsd (workspaceVar[target[step]], workspaceVar[args[0 ]]);
588740 c.mulsd (workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
589741 break ;
742+ case Operation::POWER_CONSTANT:
743+ generateTwoArgCall (c, workspaceVar[target[step]], workspaceVar[args[0 ]], constantVar[operationConstantIndex[step]], pow);
744+ break ;
590745 case Operation::ABS:
591746 generateSingleArgCall (c, workspaceVar[target[step]], workspaceVar[args[0 ]], fabs);
592747 break ;
0 commit comments