Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion tmva/tmva/inc/TMVA/MethodSVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
#ifndef ROOT_TVectorD
#include "TVectorD.h"
#endif
#ifndef ROOT_TMVA_SVKernelFunction
#include "TMVA/SVKernelFunction.h"
#endif
#endif

namespace TMVA
Expand All @@ -74,9 +77,18 @@ namespace TMVA

virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );

// optimise tuning parameters
virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="Minuit");
virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
std::vector<TMVA::SVKernelFunction::EKernelType> MakeKernelList(std::string multiKernels, TString kernel);
std::map< TString,std::vector<Double_t> > GetTuningOptions();

// training method
void Train( void );

// revoke training (required for optimise tuning parameters)
void Reset( void );

using MethodBase::ReadWeightsFromStream;

// write weights to file
Expand All @@ -97,6 +109,17 @@ namespace TMVA
// ranking of input variables
const Ranking* CreateRanking() { return 0; }

// for SVM optimisation
void SetGamma(Double_t g){fGamma = g;}
void SetCost(Double_t c){fCost = c;}
void SetMGamma(std::string & mg);
void SetOrder(Double_t o){fOrder = o;}
void SetTheta(Double_t t){fTheta = t;}
void SetKappa(Double_t k){fKappa = k;}
void SetMult(Double_t m){fMult = m;}

void GetMGamma(const std::vector<float> & gammas);

protected:

// make ROOT-independent C++ class for classifier response (classifier-specific implementation)
Expand All @@ -111,6 +134,7 @@ namespace TMVA
void DeclareOptions();
void DeclareCompatibilityOptions();
void ProcessOptions();
Double_t getLoss( TString lossFunction );

Float_t fCost; // cost value
Float_t fTolerance; // tolerance parameter
Expand All @@ -126,12 +150,23 @@ namespace TMVA
TVectorD* fMinVars; // for normalization //is it still needed??
TVectorD* fMaxVars; // for normalization //is it still needed??

// for backward compatibility
// for kernel functions
TString fTheKernel; // kernel name
Float_t fDoubleSigmaSquared; // for RBF Kernel
Int_t fOrder; // for Polynomial Kernel ( polynomial order )
Float_t fTheta; // for Sigmoidal Kernel
Float_t fKappa; // for Sigmoidal Kernel
Float_t fMult;
std::vector<Float_t> fmGamma; // vector of gammas for multi-gaussian kernel
Float_t fNumVars; // number of input variables for multi-gaussian
std::vector<TString> fVarNames;
std::string fGammas;
std::string fGammaList;
std::string fTune; // Specify parameters to be tuned
std::string fMultiKernels;

Int_t fDataSize;
TString fLoss;

ClassDef(MethodSVM,0) // Support Vector Machine
};
Expand Down
12 changes: 10 additions & 2 deletions tmva/tmva/inc/TMVA/SVKernelFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,33 @@ namespace TMVA {

public:

enum EKernelType { kLinear , kRBF, kPolynomial, kSigmoidal, kMultiGauss, kProd, kSum};

SVKernelFunction();
SVKernelFunction( Float_t );
SVKernelFunction( EKernelType, Float_t, Float_t=0);
SVKernelFunction( std::vector<float> params );
SVKernelFunction(EKernelType k, std::vector<EKernelType> kernels, std::vector<Float_t> gammas, Float_t gamma, Float_t order, Float_t theta);
~SVKernelFunction();

Float_t Evaluate( SVEvent* ev1, SVEvent* ev2 );

enum EKernelType { kLinear , kRBF, kPolynomial, kSigmoidal };

void setCompatibilityParams(EKernelType k, UInt_t order, Float_t theta, Float_t kappa);

private:

Float_t fGamma; // documentation

// vector of gammas for multidimensional gaussian
std::vector<Float_t> fmGamma;

// kernel, order, theta, and kappa are for backward compatibility
EKernelType fKernel;
UInt_t fOrder;
Float_t fTheta;
Float_t fKappa;

std::vector<EKernelType> fKernelsList;
};
}

Expand Down
Loading