SHOGUN  v1.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
SVMLin.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2006-2009 Soeren Sonnenburg
8  * Copyright (C) 2006-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
12 #include <shogun/features/Labels.h>
17 #include <shogun/features/Labels.h>
18 
19 using namespace shogun;
20 
// Initializer list and (empty) body of a constructor that takes no values from
// its caller: C1 and C2 both start at 1, epsilon at 1e-5 and use_bias at true.
// NOTE(review): the line carrying the constructor's name and parameter list
// (listing line 21) is not visible here -- presumably the default constructor
// CSVMLin::CSVMLin(); confirm against the header.
22 : CLinearMachine(), C1(1), C2(1), epsilon(1e-5), use_bias(true)
23 {
24 }
25 
// Parameter list, initializer list and body of the constructor that takes a
// regularization constant plus training data.
// NOTE(review): the line that opens this declaration (listing line 26) is not
// visible here -- presumably "CSVMLin::CSVMLin("; confirm against the header.
//
// C        is copied into both C1 and C2.
// traindat is registered through set_features().
// trainlab is registered through set_labels().
// epsilon starts at 1e-5 and use_bias at true, as in the other constructor.
27  float64_t C, CDotFeatures* traindat, CLabels* trainlab)
28 : CLinearMachine(), C1(C), C2(C), epsilon(1e-5), use_bias(true)
29 {
// Attach the supplied training features and labels to this machine.
30  set_features(traindat);
31  set_labels(trainlab);
32 }
33 
34 
// Empty function body: nothing is released here. The cross-reference index of
// this listing places "virtual ~CSVMLin()" at SVMLin.cpp:35, i.e. the line
// directly above this body, so this is the destructor's body.
36 {
37 }
38 
40 {
41  ASSERT(labels);
42 
43  if (data)
44  {
45  if (!data->has_property(FP_DOT))
46  SG_ERROR("Specified features are not of type CDotFeatures\n");
47  set_features((CDotFeatures*) data);
48  }
49 
51 
52  SGVector<float64_t> train_labels=labels->get_labels();
53  int32_t num_feat=features->get_dim_feature_space();
54  int32_t num_vec=features->get_num_vectors();
55 
56  ASSERT(num_vec==train_labels.vlen);
57  SG_FREE(w);
58 
59  struct options Options;
60  struct data Data;
61  struct vector_double Weights;
62  struct vector_double Outputs;
63 
64  Data.l=num_vec;
65  Data.m=num_vec;
66  Data.u=0;
67  Data.n=num_feat+1;
68  Data.nz=num_feat+1;
69  Data.Y=train_labels.vector;
70  Data.features=features;
71  Data.C = SG_MALLOC(float64_t, Data.l);
72 
73  Options.algo = SVM;
74  Options.lambda=1/(2*get_C1());
75  Options.lambda_u=1/(2*get_C1());
76  Options.S=10000;
77  Options.R=0.5;
78  Options.epsilon = get_epsilon();
79  Options.cgitermax=10000;
80  Options.mfnitermax=50;
81  Options.Cp = get_C2()/get_C1();
82  Options.Cn = 1;
83 
84  if (use_bias)
85  Options.bias=1.0;
86  else
87  Options.bias=0.0;
88 
89  for (int32_t i=0;i<num_vec;i++)
90  {
91  if(train_labels.vector[i]>0)
92  Data.C[i]=Options.Cp;
93  else
94  Data.C[i]=Options.Cn;
95  }
96  ssl_train(&Data, &Options, &Weights, &Outputs);
97  ASSERT(Weights.vec && Weights.d==num_feat+1);
98 
99  float64_t sgn=train_labels.vector[0];
100  for (int32_t i=0; i<num_feat+1; i++)
101  Weights.vec[i]*=sgn;
102 
103  set_w(SGVector<float64_t>(Weights.vec, num_feat));
104  set_bias(Weights.vec[num_feat]);
105 
106  SG_FREE(Data.C);
107  SG_FREE(Outputs.vec);
108  train_labels.free_vector();
109  return true;
110 }
bool has_property(EFeatureProperty p)
Definition: Features.cpp:337
SGVector< float64_t > get_labels()
Definition: Labels.cpp:144
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:35
bool use_bias
Definition: SVMLin.h:112
virtual int32_t get_num_vectors() const =0
#define SG_ERROR(...)
Definition: SGIO.h:75
void set_w(SGVector< float64_t > src_w)
Definition: LinearMachine.h:93
float64_t get_C2()
Definition: SVMLin.h:63
Features that support dot products among other operations.
Definition: DotFeatures.h:41
virtual int32_t get_dim_feature_space() const =0
CLabels * labels
Definition: Machine.h:251
float64_t get_epsilon()
Definition: SVMLin.h:87
virtual void free_vector()
Definition: DataType.h:212
#define ASSERT(x)
Definition: SGIO.h:102
virtual ~CSVMLin()
Definition: SVMLin.cpp:35
float64_t get_C1()
Definition: SVMLin.h:57
double float64_t
Definition: common.h:56
Class LinearMachine is a generic interface for all kinds of linear machines like classifiers.
Definition: LinearMachine.h:61
#define SG_FREE(ptr)
Definition: memory.h:39
CDotFeatures * features
virtual bool train_machine(CFeatures *data=NULL)
Definition: SVMLin.cpp:39
The class Features is the base class of all feature objects.
Definition: Features.h:56
virtual void set_features(CDotFeatures *feat)
void set_bias(float64_t b)
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:63
#define SG_MALLOC(type, len)
Definition: memory.h:36
index_t vlen
Definition: DataType.h:248
void ssl_train(struct data *Data, struct options *Options, struct vector_double *Weights, struct vector_double *Outputs)
Definition: ssl.cpp:33

SHOGUN Machine Learning Toolbox - Documentation