SHOGUN  v1.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CrossValidation.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011 Heiko Strathmann
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
12 #include <shogun/machine/Machine.h>
15 #include <shogun/base/Parameter.h>
17 
18 using namespace shogun;
19 
21 {
22  init();
23 }
24 
26  CLabels* labels, CSplittingStrategy* splitting_strategy,
27  CEvaluation* evaluation_criterium)
28 {
29  init();
30 
31  m_machine=machine;
32  m_features=features;
33  m_labels=labels;
34  m_splitting_strategy=splitting_strategy;
35  m_evaluation_criterium=evaluation_criterium;
36 
37  SG_REF(m_machine);
38  SG_REF(m_features);
39  SG_REF(m_labels);
40  SG_REF(m_splitting_strategy);
41  SG_REF(m_evaluation_criterium);
42 }
43 
45 {
46  SG_UNREF(m_machine);
47  SG_UNREF(m_features);
48  SG_UNREF(m_labels);
49  SG_UNREF(m_splitting_strategy);
50  SG_UNREF(m_evaluation_criterium);
51 }
52 
54 {
55  return m_evaluation_criterium->get_evaluation_direction();
56 }
57 
58 void CCrossValidation::init()
59 {
60  m_machine=NULL;
61  m_features=NULL;
62  m_labels=NULL;
63  m_splitting_strategy=NULL;
64  m_evaluation_criterium=NULL;
65  m_num_runs=1;
66  m_conf_int_alpha=0;
67 
68  m_parameters->add((CSGObject**) &m_machine, "machine",
69  "Used learning machine");
70  m_parameters->add((CSGObject**) &m_features, "features", "Used features");
71  m_parameters->add((CSGObject**) &m_labels, "labels", "Used labels");
72  m_parameters->add((CSGObject**) &m_splitting_strategy,
73  "splitting_strategy", "Used splitting strategy");
74  m_parameters->add((CSGObject**) &m_evaluation_criterium,
75  "evaluation_criterium", "Used evaluation criterium");
76  m_parameters->add(&m_num_runs, "num_runs", "Number of repetitions");
77  m_parameters->add(&m_conf_int_alpha, "conf_int_alpha", "alpha-value of confidence "
78  "interval");
79 }
80 
82 {
83  SG_REF(m_machine);
84  return m_machine;
85 }
86 
88 {
89  SGVector<float64_t> results(m_num_runs);
90 
91  for (index_t i=0; i<m_num_runs; ++i)
92  results.vector[i]=evaluate_one_run();
93 
94  /* construct evaluation result */
95  CrossValidationResult result;
96  result.has_conf_int=m_conf_int_alpha!=0;
97  result.conf_int_alpha=m_conf_int_alpha;
98 
99  if (result.has_conf_int)
100  {
101  result.conf_int_alpha=m_conf_int_alpha;
103  result.conf_int_alpha, result.conf_int_low, result.conf_int_up);
104  }
105  else
106  {
107  result.mean=CStatistics::mean(results);
108  result.conf_int_low=0;
109  result.conf_int_up=0;
110  }
111 
112  SG_FREE(results.vector);
113 
114  return result;
115 }
116 
118 {
119  if (conf_int_alpha<0 || conf_int_alpha>=1)
120  {
121  SG_ERROR("%f is an illegal alpha-value for confidence interval of "
122  "cross-validation\n", conf_int_alpha);
123  }
124 
125  m_conf_int_alpha=conf_int_alpha;
126 }
127 
128 void CCrossValidation::set_num_runs(int32_t num_runs)
129 {
130  if (num_runs<1)
131  SG_ERROR("%d is an illegal number of repetitions\n", num_runs);
132 
133  m_num_runs=num_runs;
134 }
135 
137 {
138  index_t num_subsets=m_splitting_strategy->get_num_subsets();
139  float64_t* results=SG_MALLOC(float64_t, num_subsets);
140 
141  /* set labels to machine */
142  m_machine->set_labels(m_labels);
143 
144  /* tell machine to store model internally
145  * (otherwise changing subset of features will kaboom the classifier) */
146  m_machine->set_store_model_features(true);
147 
148  /* do actual cross-validation */
149  for (index_t i=0; i<num_subsets; ++i)
150  {
151  /* set feature subset for training */
152  SGVector<index_t> inverse_subset_indices=
153  m_splitting_strategy->generate_subset_inverse(i);
154  m_features->set_subset(new CSubset(inverse_subset_indices));
155 
156  /* set label subset for training (copy data before) */
157  SGVector<index_t> inverse_subset_indices_copy(
158  inverse_subset_indices.vlen);
159  memcpy(inverse_subset_indices_copy.vector,
160  inverse_subset_indices.vector,
161  inverse_subset_indices.vlen*sizeof(index_t));
162  m_labels->set_subset(new CSubset(inverse_subset_indices_copy));
163 
164  /* train machine on training features */
165  m_machine->train(m_features);
166 
167  /* set feature subset for testing (subset method that stores pointer) */
168  SGVector<index_t> subset_indices=
169  m_splitting_strategy->generate_subset_indices(i);
170  m_features->set_subset(new CSubset(subset_indices));
171 
172  /* apply machine to test features */
173  CLabels* result_labels=m_machine->apply(m_features);
174  SG_REF(result_labels);
175 
176  /* set label subset for testing (copy data before) */
177  SGVector<index_t> subset_indices_copy(subset_indices.vlen);
178  memcpy(subset_indices_copy.vector, subset_indices.vector,
179  subset_indices.vlen*sizeof(index_t));
180  m_labels->set_subset(new CSubset(subset_indices_copy));
181 
182  /* evaluate */
183  results[i]=m_evaluation_criterium->evaluate(result_labels, m_labels);
184 
185  /* clean up, reset subsets */
186  SG_UNREF(result_labels);
187  m_features->remove_subset();
188  m_labels->remove_subset();
189  }
190 
191  /* build arithmetic mean of results */
192  float64_t mean=CStatistics::mean(SGVector<float64_t>(results, num_subsets));
193 
194  /* clean up */
195  SG_FREE(results);
196 
197  return mean;
198 }
class for adding subset support to a class. Provides an interface for getting/setting subset_matrices...
Definition: Subset.h:24
index_t get_num_subsets() const
virtual EEvaluationDirection get_evaluation_direction()=0
virtual CLabels * apply()=0
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:35
void set_conf_int_alpha(float64_t m_conf_int_alpha)
static float64_t confidence_intervals_mean(SGVector< float64_t > values, float64_t alpha, float64_t &conf_int_low, float64_t &conf_int_up)
Definition: Statistics.cpp:52
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)=0
CrossValidationResult evaluate()
Abstract base class for all splitting types. Takes a CLabels instance and generates a desired number ...
#define SG_ERROR(...)
Definition: SGIO.h:75
Parameter * m_parameters
Definition: SGObject.h:297
#define SG_REF(x)
Definition: SGObject.h:44
void set_num_runs(int32_t num_runs)
A generic learning machine interface.
Definition: Machine.h:96
type to encapsulate the results of an evaluation run. May contain confidence interval (if conf_int_al...
CMachine * get_machine() const
void add(bool *param, const char *name, const char *description="")
Definition: Parameter.cpp:23
virtual void set_subset(CSubset *subset)
Definition: Labels.cpp:245
virtual void set_store_model_features(bool store_model)
Definition: Machine.cpp:109
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:76
EEvaluationDirection
Definition: Evaluation.h:24
double float64_t
Definition: common.h:56
virtual void remove_subset()
Definition: Labels.cpp:252
SGVector< index_t > generate_subset_inverse(index_t subset_idx)
#define SG_FREE(ptr)
Definition: memory.h:39
virtual float64_t evaluate_one_run()
#define SG_UNREF(x)
Definition: SGObject.h:45
SGVector< index_t > generate_subset_indices(index_t subset_idx)
virtual void remove_subset()
Definition: Features.cpp:370
EEvaluationDirection get_evaluation_direction()
The class Features is the base class of all feature objects.
Definition: Features.h:56
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:35
static float64_t mean(SGVector< float64_t > values)
Definition: Statistics.cpp:21
int32_t index_t
Definition: DataType.h:25
virtual void set_subset(CSubset *subset)
Definition: Features.cpp:352
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:63
#define SG_MALLOC(type, len)
Definition: memory.h:36
Class Evaluation, a base class for other classes used to evaluate labels, e.g. accuracy of classifica...
Definition: Evaluation.h:36
index_t vlen
Definition: DataType.h:248

SHOGUN Machine Learning Toolbox - Documentation