Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/gauxc/xc_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,20 @@ class XCIntegrator {
XCIntegrator( XCIntegrator&& ) noexcept;

value_type integrate_den( const MatrixType& );

value_type eval_exc( const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} );
value_type eval_exc( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} );
value_type eval_exc( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} );

exc_vxc_type_rks eval_exc_vxc ( const MatrixType&,
const IntegratorSettingsXC& = IntegratorSettingsXC{} );
exc_vxc_type_uks eval_exc_vxc ( const MatrixType&, const MatrixType&,
const IntegratorSettingsXC& = IntegratorSettingsXC{} );
exc_vxc_type_gks eval_exc_vxc ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&,
const IntegratorSettingsXC& = IntegratorSettingsXC{});

exc_grad_type eval_exc_grad( const MatrixType& );

exx_type eval_exx ( const MatrixType&,
const IntegratorSettingsEXX& = IntegratorSettingsEXX{} );

Expand Down
21 changes: 21 additions & 0 deletions include/gauxc/xc_integrator/impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ typename XCIntegrator<MatrixType>::value_type
return pimpl_->integrate_den(P);
};

template <typename MatrixType>
typename XCIntegrator<MatrixType>::value_type
XCIntegrator<MatrixType>::eval_exc( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
return pimpl_->eval_exc(P, ks_settings);
}

template <typename MatrixType>
typename XCIntegrator<MatrixType>::value_type
XCIntegrator<MatrixType>::eval_exc( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
return pimpl_->eval_exc(Ps, Pz, ks_settings);
}

template <typename MatrixType>
typename XCIntegrator<MatrixType>::value_type
XCIntegrator<MatrixType>::eval_exc( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px, const IntegratorSettingsXC& ks_settings ) {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
return pimpl_->eval_exc(Ps, Pz, Py, Px, ks_settings);
}

template <typename MatrixType>
typename XCIntegrator<MatrixType>::exc_vxc_type_rks
XCIntegrator<MatrixType>::eval_exc_vxc( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {
Expand Down
38 changes: 38 additions & 0 deletions include/gauxc/xc_integrator/replicated/impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,44 @@ typename ReplicatedXCIntegrator<MatrixType>::value_type
return N_EL;
}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::value_type
ReplicatedXCIntegrator<MatrixType>::eval_exc_( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {

if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
value_type EXC;

pimpl_->eval_exc( P.rows(), P.cols(), P.data(), P.rows(), &EXC, ks_settings );

return EXC;
}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::value_type
ReplicatedXCIntegrator<MatrixType>::eval_exc_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {

if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
value_type EXC;

const size_t n = Ps.rows();
pimpl_->eval_exc( n, n, Ps.data(), n, Pz.data(), n, &EXC, ks_settings );

return EXC;
}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::value_type
ReplicatedXCIntegrator<MatrixType>::eval_exc_( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px, const IntegratorSettingsXC& ks_settings ) {

if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
value_type EXC;

const size_t n = Ps.rows();
pimpl_->eval_exc( n, n, Ps.data(), n, Pz.data(), n, Py.data(), n, Px.data(), n, &EXC, ks_settings );

return EXC;
}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::exc_vxc_type_rks
ReplicatedXCIntegrator<MatrixType>::eval_exc_vxc_( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ class ReplicatedXCIntegratorImpl {

virtual void integrate_den_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* N_EL ) = 0;

virtual void eval_exc_( int64_t m, int64_t n, const value_type* P, int64_t ldp,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) = 0;
virtual void eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) = 0;
virtual void eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) = 0;

virtual void eval_exc_vxc_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* VXC, int64_t ldvxc,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) = 0;
Expand Down Expand Up @@ -81,6 +93,17 @@ class ReplicatedXCIntegratorImpl {
void integrate_den( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* N_EL );

void eval_exc( int64_t m, int64_t n, const value_type* P, int64_t ldp,
value_type* EXC, const IntegratorSettingsXC& ks_settings );
void eval_exc( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
value_type* EXC, const IntegratorSettingsXC& ks_settings );
void eval_exc( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& ks_settings );

void eval_exc_vxc( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* VXC, int64_t ldvxc,
value_type* EXC, const IntegratorSettingsXC& ks_settings );
Expand Down
3 changes: 3 additions & 0 deletions include/gauxc/xc_integrator/replicated_xc_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class ReplicatedXCIntegrator : public XCIntegratorImpl<MatrixType> {
std::unique_ptr< pimpl_type > pimpl_;

value_type integrate_den_( const MatrixType& ) override;
value_type eval_exc_ ( const MatrixType&, const IntegratorSettingsXC& ) override;
value_type eval_exc_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
value_type eval_exc_ ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
exc_vxc_type_rks eval_exc_vxc_ ( const MatrixType&, const IntegratorSettingsXC& ) override;
exc_vxc_type_uks eval_exc_vxc_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC&) override;
exc_vxc_type_gks eval_exc_vxc_ ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
Expand Down
42 changes: 35 additions & 7 deletions include/gauxc/xc_integrator/xc_integrator_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class XCIntegratorImpl {
protected:

virtual value_type integrate_den_( const MatrixType& P ) = 0;

virtual value_type eval_exc_ ( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) = 0;
virtual value_type eval_exc_ ( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0;
virtual value_type eval_exc_ ( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px, const IntegratorSettingsXC& ks_settings ) = 0;

virtual exc_vxc_type_rks eval_exc_vxc_ ( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) = 0;
virtual exc_vxc_type_uks eval_exc_vxc_ ( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0;
virtual exc_vxc_type_gks eval_exc_vxc_ ( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px,
Expand All @@ -54,13 +59,38 @@ class XCIntegratorImpl {
* @param[in] P The density matrix
* @returns Approx Tr[P*S]
*/
value_type integrate_den( const MatrixType& P ) {
return integrate_den_(P);
}
value_type integrate_den( const MatrixType& P ) {
return integrate_den_(P);
}

/** Integrate EXC for RKS
*
* @param[in] P The alpha density matrix
* @returns Integrated EXC
*/
value_type eval_exc( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {
return eval_exc_(P, ks_settings);
}

/** Integrate EXC for UKS
*
* @param[in] P The alpha density matrix
* @returns Integrated EXC
*/
value_type eval_exc( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {
return eval_exc_(Ps, Pz, ks_settings);
}

/** Integrate EXC for GKS
*
* @param[in] P The alpha density matrix
* @returns Integrated EXC
*/
value_type eval_exc( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px, const IntegratorSettingsXC& ks_settings ) {
return eval_exc_(Ps, Pz, Py, Px, ks_settings);
}

/** Integrate EXC / VXC (Mean field terms) for RKS
*
* TODO: add API for UKS/GKS
*
* @param[in] P The alpha density matrix
* @returns EXC / VXC in a combined structure
Expand Down Expand Up @@ -89,8 +119,6 @@ class XCIntegratorImpl {
}

/** Integrate Exact Exchange for RHF
*
* TODO: add API for UHF/GHF
*
* @param[in] P The alpha density matrix
* @returns Excact Exchange Matrix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* See LICENSE.txt for details
*/
#include "incore_replicated_xc_device_integrator_integrate_den.hpp"
#include "incore_replicated_xc_device_integrator_exc.hpp"
#include "incore_replicated_xc_device_integrator_exc_vxc.hpp"
#include "incore_replicated_xc_device_integrator_exc_grad.hpp"
#include "incore_replicated_xc_device_integrator_exx.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ class IncoreReplicatedXCDeviceIntegrator :
void integrate_den_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* N_EL ) override;

void eval_exc_( int64_t m, int64_t n, const value_type* P, int64_t ldp,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) override;
void eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) override;
void eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) override;

void eval_exc_vxc_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* VXC, int64_t ldvxc,
value_type* EXC, const IntegratorSettingsXC& settings) override;
Expand Down Expand Up @@ -74,7 +85,7 @@ class IncoreReplicatedXCDeviceIntegrator :

void exc_vxc_local_work_( const basis_type& basis, const value_type* P, int64_t ldp,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data );
XCDeviceData& device_data, bool do_vxc = true );

void exc_vxc_local_work_( const basis_type& basis, const value_type* P, int64_t ldp,
value_type* VXC, int64_t ldvxc, value_type* EXC, value_type *N_EL,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/**
* GauXC Copyright (c) 2020-2024, The Regents of the University of California,
* through Lawrence Berkeley National Laboratory (subject to receipt of
* any required approvals from the U.S. Dept. of Energy). All rights reserved.
*
* See LICENSE.txt for details
*/
#include "incore_replicated_xc_device_integrator.hpp"
#include "device/local_device_work_driver.hpp"
#include "device/xc_device_aos_data.hpp"
#include <fstream>
#include <gauxc/exceptions.hpp>
#include <gauxc/util/unused.hpp>

namespace GauXC {
namespace detail {

template <typename ValueType>
void IncoreReplicatedXCDeviceIntegrator<ValueType>::
eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& settings ) {


if(Pz) GAUXC_GENERIC_EXCEPTION("UKS/GKS + EXC Only Device NYI");
const auto& basis = this->load_balancer_->basis();

// Check that P / VXC are sane
const int64_t nbf = basis.nbf();
if( m != n )
GAUXC_GENERIC_EXCEPTION("P/VXC Must Be Square");
if( m != nbf )
GAUXC_GENERIC_EXCEPTION("P/VXC Must Have Same Dimension as Basis");
if( ldps < nbf )
GAUXC_GENERIC_EXCEPTION("Invalid LDP");


// Get Tasks
auto& tasks = this->load_balancer_->get_tasks();

// Allocate Device memory
auto* lwd = dynamic_cast<LocalDeviceWorkDriver*>(this->local_work_driver_.get() );
auto rt = detail::as_device_runtime(this->load_balancer_->runtime());
auto device_data_ptr = lwd->create_device_data(rt);

GAUXC_MPI_CODE( MPI_Barrier(rt.comm());)

// Temporary electron count to judge integrator accuracy
value_type N_EL;

// Compute local contributions to EXC/VXC and retrieve
// data from device
this->timer_.time_op("XCIntegrator.LocalWork_EXC", [&](){
exc_vxc_local_work_( basis, Ps, ldps, nullptr, 0, EXC,
&N_EL, tasks.begin(), tasks.end(), *device_data_ptr);
});

GAUXC_MPI_CODE(
this->timer_.time_op("XCIntegrator.ImbalanceWait_EXC",[&](){
MPI_Barrier(this->load_balancer_->runtime().comm());
});
)

// Reduce Results in host mem
this->timer_.time_op("XCIntegrator.Allreduce_EXC", [&](){
this->reduction_driver_->allreduce_inplace( EXC, 1 , ReductionOp::Sum );
this->reduction_driver_->allreduce_inplace( &N_EL, 1 , ReductionOp::Sum );
});

}



template <typename ValueType>
void IncoreReplicatedXCDeviceIntegrator<ValueType>::
eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
value_type* EXC, const IntegratorSettingsXC& settings ) {

eval_exc_(m, n, Ps, ldps, Pz, ldpz, nullptr, 0, nullptr, 0, EXC, settings);

}

template <typename ValueType>
void IncoreReplicatedXCDeviceIntegrator<ValueType>::
eval_exc_( int64_t m, int64_t n, const value_type* P, int64_t ldp,
value_type* EXC, const IntegratorSettingsXC& settings ) {

eval_exc_(m, n, P, ldp, nullptr, 0, nullptr, 0, nullptr, 0, EXC, settings);

}


}
}

Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ template <typename ValueType>
void IncoreReplicatedXCDeviceIntegrator<ValueType>::
exc_vxc_local_work_( const basis_type& basis, const value_type* P, int64_t ldp,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data ) {
XCDeviceData& device_data, bool do_vxc ) {


auto* lwd = dynamic_cast<LocalDeviceWorkDriver*>(this->local_work_driver_.get() );
Expand Down Expand Up @@ -195,7 +195,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
const auto nbf = basis.nbf();
const auto nshells = basis.nshells();
device_data.reset_allocations();
device_data.allocate_static_data_exc_vxc( nbf, nshells );
device_data.allocate_static_data_exc_vxc( nbf, nshells, do_vxc );
device_data.send_static_data_density_basis( P, ldp, basis );

// Zero integrands
Expand Down Expand Up @@ -257,6 +257,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
// Do scalar EXC/N_EL integrations
lwd->inc_exc( &device_data );
lwd->inc_nel( &device_data );
if( not do_vxc ) continue;

// Evaluate Z (+ M) matrix
if( func.is_mgga() ) {
Expand All @@ -272,7 +273,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
} // Loop over batches of batches

// Symmetrize VXC in device memory
lwd->symmetrize_vxc( &device_data );
if(do_vxc) lwd->symmetrize_vxc( &device_data );

}

Expand All @@ -287,7 +288,8 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
XCDeviceData& device_data ) {

// Get integrate and keep data on device
exc_vxc_local_work_( basis, P, ldp, task_begin, task_end, device_data );
const bool do_vxc = VXC;
exc_vxc_local_work_( basis, P, ldp, task_begin, task_end, device_data, do_vxc );
auto rt = detail::as_device_runtime(this->load_balancer_->runtime());
rt.device_backend()->master_queue_synchronize();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/
#include "shell_batched_replicated_xc_device_integrator.hpp"
#include "shell_batched_replicated_xc_integrator_integrate_den.hpp"
#include "shell_batched_replicated_xc_integrator_exc.hpp"
#include "shell_batched_replicated_xc_integrator_exc_vxc.hpp"
#include "shell_batched_replicated_xc_integrator_exc_grad.hpp"
#include "shell_batched_replicated_xc_integrator_exx.hpp"
Expand Down
Loading