новите find_finite
и find_nonfinite
функции в Armadillo 4.300 са страхотни допълнения! В моите тестове с Rcpp
обаче те са около 2,5 пъти по-бавни в сравнение със стандартен цикъл. По-долу е даден код за изчисляване на сумата и средната стойност с изтриване на главни и малки букви, съответстващо на опцията na.rm=TRUE
на R. Сравнителните показатели за производителност от R показват, че първата версия (sum_arma
и mean_arma
) е около 3,5 пъти по-бърза в сравнение с цикъла. Правя ли всичко правилно? Някакъв начин за подобряване на производителността?
C++ код
#include <numeric>
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
double sum_arma1(arma::mat& X) {
double sum = 0;
for (int i = 0; i < X.size(); ++i) {
if (arma::is_finite(X(i)))
sum += X(i);
}
return sum;
}
// [[Rcpp::export]]
double sum_arma2(arma::mat& X) {
return arma::sum(X.elem(arma::find_finite(X)));
}
// [[Rcpp::export]]
double mean_arma1(arma::mat& X) {
double sum = 0;
int n = 0;
for (int i = 0; i < X.size(); ++i) {
if (arma::is_finite(X(i))) {
sum += X(i);
n += 1;
}
}
return sum/n;
}
// [[Rcpp::export]]
double mean_arma2(arma::mat& X) {
return arma::mean(X.elem(arma::find_finite(X)));
}
Сравнителни резултати от R
# data
X = matrix(rnorm(1e6),1000,1000)
X[sample(1:1000,100),sample(1:1000,100)] = NA
# equal?
all.equal(sum(X, na.rm=TRUE),sum_arma1(X))
all.equal(sum(X, na.rm=TRUE),sum_arma2(X))
all.equal(mean(X, na.rm=TRUE),mean_arma1(X))
all.equal(mean(X, na.rm=TRUE),mean_arma2(X))
# benchmark
benchmark(
sum(X, na.rm=TRUE),
sum_arma1(X),
sum_arma2(X),
replications=100)
# test replications elapsed relative user.self sys.self
# 2 sum_arma1(X) 100 0.259 1.000 0.259 0.001
# 3 sum_arma2(X) 100 1.035 3.996 0.750 0.293
# 1 sum(X, na.rm = TRUE) 100 0.491 1.896 0.492 0.003
benchmark(
mean(X, na.rm=TRUE),
mean_arma1(X),
mean_arma2(X),
replications=100)
# test replications elapsed relative user.self sys.self
# 2 mean_arma1(X) 100 0.252 1.00 0.253 0.001
# 3 mean_arma2(X) 100 0.819 3.25 0.620 0.206
# 1 mean(X, na.rm = TRUE) 100 7.440 29.52 7.120 0.373