Skip to content

Commit e1a926f

Browse files
authored
Optimized computing powers in CompiledExpression (openmm#3520)
* Optimized computing powers in CompiledExpression * Fixed compilation error * Attempt at fixing compilation error
1 parent cc7018e commit e1a926f

File tree

3 files changed

+160
-3
lines changed

3 files changed

+160
-3
lines changed

devtools/ci/gh-actions/conda-envs/build-ubuntu-latest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ dependencies:
77
- cmake
88
- make
99
- ccache
10+
- sysroot_linux-64 2.17
1011
# host
1112
- python
1213
- cython

libraries/lepton/include/lepton/CompiledExpression.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class LEPTON_EXPORT CompiledExpression {
105105
std::map<std::string, double> dummyVariables;
106106
double (*jitCode)();
107107
#ifdef LEPTON_USE_JIT
108+
// Identify the steps that raise an argument to an integer power (SQUARE, CUBE, or a
// POWER_CONSTANT whose exponent is a whole number) and collect them into groups that
// share the same argument and the same exponent sign.
//   groups      - receives, for each group, the indices of the steps it contains
//   groupPowers - receives, for each group, the integer exponent of each of those steps
//   stepGroup   - receives, for each step, the index of its group, or -1 if it is in none
void findPowerGroups(std::vector<std::vector<int> >& groups, std::vector<std::vector<int> >& groupPowers, std::vector<int>& stepGroup);
108109
void generateJitCode();
109110
#if defined(__ARM__) || defined(__ARM64__)
110111
void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double));

libraries/lepton/src/CompiledExpression.cpp

Lines changed: 158 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,48 @@ static double evaluateOperation(Operation* op, double* args) {
192192
return op->evaluate(args, dummyVariables);
193193
}
194194

195+
/**
 * Find sets of steps that raise the same argument to integer powers, so the JIT
 * code generator can evaluate each set together by repeated squaring.
 *
 * A step qualifies if it is a SQUARE, a CUBE, or a POWER_CONSTANT whose exponent is a
 * nonzero whole number representable as an int.  Two qualifying steps belong to the
 * same group when they read the same argument index and their exponents have the
 * same sign.
 *
 * NOTE(review): steps are matched purely by argument index.  This presumes that index
 * refers to the same value at every step that reads it - confirm against how
 * `arguments` and `target` are assigned.
 *
 * @param groups       on exit, groups[g] lists the step indices in group g (ascending)
 * @param groupPowers  on exit, groupPowers[g][k] is the exponent used by step groups[g][k]
 * @param stepGroup    on exit, stepGroup[s] is the group containing step s, or -1 if
 *                     step s is not an integer power
 *
 * Any previous contents of the three output vectors are discarded.
 */
void CompiledExpression::findPowerGroups(vector<vector<int> >& groups, vector<vector<int> >& groupPowers, vector<int>& stepGroup) {
    const int numSteps = (int) operation.size();

    // Identify every step that raises an argument to an integer power.

    vector<int> stepPower(numSteps, 0);
    vector<int> stepArg(numSteps, -1);
    for (int step = 0; step < numSteps; step++) {
        Operation& op = *operation[step];
        int power = 0;
        if (op.getId() == Operation::SQUARE)
            power = 2;
        else if (op.getId() == Operation::CUBE)
            power = 3;
        else if (op.getId() == Operation::POWER_CONSTANT) {
            // Use the reference form of the cast so a mismatched type throws instead
            // of dereferencing a null pointer.
            double realPower = dynamic_cast<const Operation::PowerConstant&>(op).getValue();

            // Converting an out-of-range or NaN double to int is undefined behavior,
            // so check the range first (NaN fails both comparisons).  The bound is
            // symmetric so the exponent can later be negated without overflow; it
            // assumes int is at least 32 bits wide.
            if (realPower >= -2147483647.0 && realPower <= 2147483647.0 && realPower == (int) realPower)
                power = (int) realPower;
        }
        if (power != 0) {
            stepPower[step] = power;
            stepArg[step] = arguments[step][0];
        }
    }

    // Find groups that operate on the same argument and whose powers have the same sign.

    groups.clear();
    groupPowers.clear();
    stepGroup.assign(numSteps, -1);
    for (int i = 0; i < numSteps; i++) {
        // Skip steps that are not integer powers, and ones already placed in a group.
        if (stepPower[i] == 0 || stepGroup[i] != -1)
            continue;
        const bool positive = (stepPower[i] > 0);
        const int groupIndex = (int) groups.size();
        vector<int> group, power;
        for (int j = i; j < numSteps; j++) {
            // Compare signs directly; multiplying the two exponents could overflow.
            if (stepPower[j] == 0 || stepArg[j] != stepArg[i] || (stepPower[j] > 0) != positive)
                continue;
            stepGroup[j] = groupIndex;
            group.push_back(j);
            power.push_back(stepPower[j]);
        }
        groups.push_back(group);
        groupPowers.push_back(power);
    }
}
236+
195237
#if defined(__ARM__) || defined(__ARM64__)
196238
void CompiledExpression::generateJitCode() {
197239
CodeHolder code;
@@ -203,6 +245,9 @@ void CompiledExpression::generateJitCode() {
203245
workspaceVar[i] = c.newVecD();
204246
arm::Gp argsPointer = c.newIntPtr();
205247
c.mov(argsPointer, imm(&argValues[0]));
248+
vector<vector<int> > groups, groupPowers;
249+
vector<int> stepGroup;
250+
findPowerGroups(groups, groupPowers, stepGroup);
206251

207252
// Load the arguments into variables.
208253

@@ -233,6 +278,12 @@ void CompiledExpression::generateJitCode() {
233278
value = 1.0;
234279
else if (op.getId() == Operation::DELTA)
235280
value = 1.0;
281+
else if (op.getId() == Operation::POWER_CONSTANT) {
282+
if (stepGroup[step] == -1)
283+
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
284+
else
285+
value = 1.0;
286+
}
236287
else
237288
continue;
238289

@@ -260,10 +311,54 @@ void CompiledExpression::generateJitCode() {
260311
c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i));
261312
}
262313
}
263-
314+
264315
// Evaluate the operations.
265-
316+
317+
vector<bool> hasComputedPower(operation.size(), false);
266318
for (int step = 0; step < (int) operation.size(); step++) {
319+
if (hasComputedPower[step])
320+
continue;
321+
322+
// When one or more steps involve raising the same argument to multiple integer
323+
// powers, we can compute them all together for efficiency.
324+
325+
if (stepGroup[step] != -1) {
326+
vector<int>& group = groups[stepGroup[step]];
327+
vector<int>& powers = groupPowers[stepGroup[step]];
328+
arm::Vec multiplier = c.newVecD();
329+
if (powers[0] > 0)
330+
c.fmov(multiplier, workspaceVar[arguments[step][0]]);
331+
else {
332+
c.fdiv(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
333+
for (int i = 0; i < powers.size(); i++)
334+
powers[i] = -powers[i];
335+
}
336+
vector<bool> hasAssigned(group.size(), false);
337+
bool done = false;
338+
while (!done) {
339+
done = true;
340+
for (int i = 0; i < group.size(); i++) {
341+
if (powers[i]%2 == 1) {
342+
if (!hasAssigned[i])
343+
c.fmov(workspaceVar[target[group[i]]], multiplier);
344+
else
345+
c.fmul(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
346+
hasAssigned[i] = true;
347+
}
348+
powers[i] >>= 1;
349+
if (powers[i] != 0)
350+
done = false;
351+
}
352+
if (!done)
353+
c.fmul(multiplier, multiplier, multiplier);
354+
}
355+
for (int step : group)
356+
hasComputedPower[step] = true;
357+
continue;
358+
}
359+
360+
// Evaluate the step.
361+
267362
Operation& op = *operation[step];
268363
vector<int> args = arguments[step];
269364
if (args.size() == 1) {
@@ -360,6 +455,9 @@ void CompiledExpression::generateJitCode() {
360455
case Operation::MULTIPLY_CONSTANT:
361456
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
362457
break;
458+
case Operation::POWER_CONSTANT:
459+
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
460+
break;
363461
case Operation::ABS:
364462
c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]);
365463
break;
@@ -418,7 +516,10 @@ void CompiledExpression::generateJitCode() {
418516
workspaceVar[i] = c.newXmmSd();
419517
x86::Gp argsPointer = c.newIntPtr();
420518
c.mov(argsPointer, imm(&argValues[0]));
421-
519+
vector<vector<int> > groups, groupPowers;
520+
vector<int> stepGroup;
521+
findPowerGroups(groups, groupPowers, stepGroup);
522+
422523
// Load the arguments into variables.
423524

424525
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
@@ -448,6 +549,12 @@ void CompiledExpression::generateJitCode() {
448549
value = 1.0;
449550
else if (op.getId() == Operation::DELTA)
450551
value = 1.0;
552+
else if (op.getId() == Operation::POWER_CONSTANT) {
553+
if (stepGroup[step] == -1)
554+
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
555+
else
556+
value = 1.0;
557+
}
451558
else
452559
continue;
453560

@@ -478,7 +585,52 @@ void CompiledExpression::generateJitCode() {
478585

479586
// Evaluate the operations.
480587

588+
vector<bool> hasComputedPower(operation.size(), false);
481589
for (int step = 0; step < (int) operation.size(); step++) {
590+
if (hasComputedPower[step])
591+
continue;
592+
593+
// When one or more steps involve raising the same argument to multiple integer
594+
// powers, we can compute them all together for efficiency.
595+
596+
if (stepGroup[step] != -1) {
597+
vector<int>& group = groups[stepGroup[step]];
598+
vector<int>& powers = groupPowers[stepGroup[step]];
599+
x86::Xmm multiplier = c.newXmmSd();
600+
if (powers[0] > 0)
601+
c.movsd(multiplier, workspaceVar[arguments[step][0]]);
602+
else {
603+
c.movsd(multiplier, constantVar[operationConstantIndex[step]]);
604+
c.divsd(multiplier, workspaceVar[arguments[step][0]]);
605+
for (int i = 0; i < powers.size(); i++)
606+
powers[i] = -powers[i];
607+
}
608+
vector<bool> hasAssigned(group.size(), false);
609+
bool done = false;
610+
while (!done) {
611+
done = true;
612+
for (int i = 0; i < group.size(); i++) {
613+
if (powers[i]%2 == 1) {
614+
if (!hasAssigned[i])
615+
c.movsd(workspaceVar[target[group[i]]], multiplier);
616+
else
617+
c.mulsd(workspaceVar[target[group[i]]], multiplier);
618+
hasAssigned[i] = true;
619+
}
620+
powers[i] >>= 1;
621+
if (powers[i] != 0)
622+
done = false;
623+
}
624+
if (!done)
625+
c.mulsd(multiplier, multiplier);
626+
}
627+
for (int step : group)
628+
hasComputedPower[step] = true;
629+
continue;
630+
}
631+
632+
// Evaluate the step.
633+
482634
Operation& op = *operation[step];
483635
vector<int> args = arguments[step];
484636
if (args.size() == 1) {
@@ -587,6 +739,9 @@ void CompiledExpression::generateJitCode() {
587739
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
588740
c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
589741
break;
742+
case Operation::POWER_CONSTANT:
743+
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
744+
break;
590745
case Operation::ABS:
591746
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
592747
break;

0 commit comments

Comments
 (0)