SHOGUN  v3.2.0
MulticlassMachine.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2011 Soeren Sonnenburg
8  * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn
9  * Written (W) 2013 Shell Hu and Heiko Strathmann
10  * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia
11  */
12 
17 #include <shogun/base/Parameter.h>
21 
22 using namespace shogun;
23 
25 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()),
26  m_machine(NULL)
27 {
29  register_parameters();
30 }
31 
33  CMulticlassStrategy *strategy,
34  CMachine* machine, CLabels* labs)
36 {
37  SG_REF(strategy);
38  set_labels(labs);
39  SG_REF(machine);
40  m_machine = machine;
41  register_parameters();
42 
43  if (labs)
44  init_strategy();
45 }
46 
48 {
51 }
52 
54 {
56  if (lab)
57  init_strategy();
58 }
59 
60 void CMulticlassMachine::register_parameters()
61 {
62  SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
63  SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
64 }
65 
67 {
68  int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
70 }
71 
73 {
74  CMachine *machine = (CMachine*)m_machines->get_element(i);
75  ASSERT(machine)
76  CBinaryLabels* output = machine->apply_binary();
77  SG_UNREF(machine);
78  return output;
79 }
80 
82 {
83  CMachine *machine = get_machine(i);
84  float64_t output = 0.0;
85  // dirty hack
86  if (dynamic_cast<CLinearMachine*>(machine))
87  output = ((CLinearMachine*)machine)->apply_one(num);
88  if (dynamic_cast<CKernelMachine*>(machine))
89  output = ((CKernelMachine*)machine)->apply_one(num);
90  SG_UNREF(machine);
91  return output;
92 }
93 
95 {
96  SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n",
97  get_name(), data ? data->get_name() : "NULL", data);
98 
99  CMulticlassLabels* return_labels=NULL;
100 
101  if (data)
103  else
105 
106  if (is_ready())
107  {
108  /* num vectors depends on whether data is provided */
109  int32_t num_vectors=data ? data->get_num_vectors() :
111 
112  int32_t num_machines=m_machines->get_num_elements();
113  if (num_machines <= 0)
114  SG_ERROR("num_machines = %d, did you train your machine?", num_machines)
115 
116  CMulticlassLabels* result=new CMulticlassLabels(num_vectors);
117 
118  // if outputs are prob, only one confidence for each class
119  int32_t num_classes=m_multiclass_strategy->get_num_classes();
121 
122  if (heuris!=PROB_HEURIS_NONE)
123  result->allocate_confidences_for(num_classes);
124  else
125  result->allocate_confidences_for(num_machines);
126 
127  CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
128  SGVector<float64_t> As(num_machines);
129  SGVector<float64_t> Bs(num_machines);
130 
131  for (int32_t i=0; i<num_machines; ++i)
132  {
133  outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
134 
135  if (heuris==OVA_SOFTMAX)
136  {
137  CStatistics::SigmoidParamters params = CStatistics::fit_sigmoid(outputs[i]->get_values());
138  As[i] = params.a;
139  Bs[i] = params.b;
140  }
141 
142  if (heuris!=PROB_HEURIS_NONE && heuris!=OVA_SOFTMAX)
143  outputs[i]->scores_to_probabilities(0,0);
144  }
145 
146  SGVector<float64_t> output_for_i(num_machines);
147  SGVector<float64_t> r_output_for_i(num_machines);
148  if (heuris!=PROB_HEURIS_NONE)
149  r_output_for_i.resize_vector(num_classes);
150 
151  for (int32_t i=0; i<num_vectors; i++)
152  {
153  for (int32_t j=0; j<num_machines; j++)
154  output_for_i[j] = outputs[j]->get_value(i);
155 
156  if (heuris==PROB_HEURIS_NONE)
157  {
158  r_output_for_i = output_for_i;
159  }
160  else
161  {
162  if (heuris==OVA_SOFTMAX)
163  m_multiclass_strategy->rescale_outputs(output_for_i,As,Bs);
164  else
165  m_multiclass_strategy->rescale_outputs(output_for_i);
166 
167  // only first num_classes are returned
168  for (int32_t r=0; r<num_classes; r++)
169  r_output_for_i[r] = output_for_i[r];
170 
171  SG_DEBUG("%s::apply_multiclass(): sum(r_output_for_i) = %f\n",
172  get_name(), SGVector<float64_t>::sum(r_output_for_i.vector,num_classes));
173  }
174 
175  // use rescaled outputs for label decision
176  result->set_label(i, m_multiclass_strategy->decide_label(r_output_for_i));
177  result->set_multiclass_confidences(i, r_output_for_i);
178  }
179 
180  for (int32_t i=0; i < num_machines; ++i)
181  SG_UNREF(outputs[i]);
182 
183  SG_FREE(outputs);
184 
185  return_labels=result;
186  }
187  else
188  SG_ERROR("Not ready")
189 
190 
191  SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n",
192  get_name(), data ? data->get_name() : "NULL", data);
193  return return_labels;
194 }
195 
197 {
198  CMulticlassMultipleOutputLabels* return_labels=NULL;
199 
200  if (data)
202  else
204 
205  if (is_ready())
206  {
207  /* num vectors depends on whether data is provided */
208  int32_t num_vectors=data ? data->get_num_vectors() :
210 
211  int32_t num_machines=m_machines->get_num_elements();
212  if (num_machines <= 0)
213  SG_ERROR("num_machines = %d, did you train your machine?", num_machines)
214  REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available")
215 
217  CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
218 
219  for (int32_t i=0; i < num_machines; ++i)
220  outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
221 
222  SGVector<float64_t> output_for_i(num_machines);
223  for (int32_t i=0; i<num_vectors; i++)
224  {
225  for (int32_t j=0; j<num_machines; j++)
226  output_for_i[j] = outputs[j]->get_value(i);
227 
228  result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs));
229  }
230 
231  for (int32_t i=0; i < num_machines; ++i)
232  SG_UNREF(outputs[i]);
233 
234  SG_FREE(outputs);
235 
236  return_labels=result;
237  }
238  else
239  SG_ERROR("Not ready")
240 
241  return return_labels;
242 }
243 
245 {
247 
248  if ( !data && !is_ready() )
249  SG_ERROR("Please provide training data.\n")
250  else
252 
254  CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors());
255  SG_REF(train_labels);
256  m_machine->set_labels(train_labels);
257 
260  {
262  if (subset.vlen)
263  {
264  train_labels->add_subset(subset);
265  add_machine_subset(subset);
266  }
267 
268  m_machine->train();
270 
271  if (subset.vlen)
272  {
273  train_labels->remove_subset();
275  }
276  }
277 
279  SG_UNREF(train_labels);
280 
281  return true;
282 }
283 
285 {
287 
290 
291  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
292  outputs[i] = get_submachine_output(i, vec_idx);
293 
294  float64_t result = m_multiclass_strategy->decide_label(outputs);
295 
296  return result;
297 }
void allocate_confidences_for(int32_t n_classes)
virtual const char * get_name() const =0
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:35
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
#define SG_UNREF(x)
Definition: SGRefObject.h:35
Multiclass Labels for multi-class classification with multiple labels
virtual int32_t get_num_vectors() const =0
CSGObject * get_element(int32_t index) const
CLabels * m_labels
Definition: Machine.h:356
#define SG_ERROR(...)
Definition: SGIO.h:131
virtual void rescale_outputs(SGVector< float64_t > outputs)
#define REQUIRE(x,...)
Definition: SGIO.h:208
CMachine * get_machine(int32_t num) const
A generic KernelMachine interface.
Definition: KernelMachine.h:50
virtual CMachine * get_machine_from_trained(CMachine *machine)=0
virtual void add_machine_subset(SGVector< index_t > subset)=0
virtual float64_t get_submachine_output(int32_t i, int32_t num)
A generic learning machine interface.
Definition: Machine.h:138
bool set_label(int32_t idx, float64_t label)
Multiclass Labels for multi-class classification
virtual bool init_machine_for_train(CFeatures *data)=0
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
Definition: Machine.cpp:218
static SigmoidParamters fit_sigmoid(SGVector< float64_t > scores)
void set_num_classes(int32_t num_classes)
virtual bool is_ready()=0
#define ASSERT(x)
Definition: SGIO.h:203
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:102
CMulticlassStrategy * m_multiclass_strategy
double float64_t
Definition: common.h:48
#define SG_REF(x)
Definition: SGRefObject.h:34
virtual CBinaryLabels * get_submachine_outputs(int32_t i)
virtual void remove_subset()
Definition: Labels.cpp:45
virtual void add_subset(SGVector< index_t > subset)
Definition: Labels.cpp:40
virtual CMulticlassMultipleOutputLabels * apply_multiclass_multiple_output(CFeatures *data=NULL, int32_t n_outputs=5)
virtual bool init_machines_for_apply(CFeatures *data)=0
Class LinearMachine is a generic interface for all kinds of linear machines like classifiers.
Definition: LinearMachine.h:61
virtual const char * get_name() const
virtual float64_t apply_one(int32_t vec_idx)
virtual void remove_machine_subset()=0
void set_multiclass_confidences(int32_t i, SGVector< float64_t > confidences)
#define SG_DEBUG(...)
Definition: SGIO.h:109
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:16
The class Features is the base class of all feature objects.
Definition: Features.h:62
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:49
void scores_to_probabilities(float64_t a=0, float64_t b=0)
bool set_label(int32_t idx, SGVector< index_t > label)
Binary Labels for binary classification
Definition: BinaryLabels.h:36
class MulticlassStrategy used to construct generic multiclass classifiers with ensembles of binary cl...
EProbHeuristicType get_prob_heuris()
virtual bool train_has_more()=0
virtual SGVector< int32_t > train_prepare_next()
void resize_vector(int32_t n)
Definition: SGVector.cpp:307
multiclass one vs rest strategy used to train generic multiclass machines for K-class problems with b...
#define SG_ADD(...)
Definition: SGObject.h:71
virtual bool train_machine(CFeatures *data=NULL)
virtual int32_t get_num_rhs_vectors()=0
virtual SGVector< index_t > decide_label_multiple_output(SGVector< float64_t > outputs, int32_t n_outputs)
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:75
static CMulticlassLabels * to_multiclass(CLabels *base_labels)
virtual void set_labels(CLabels *lab)
index_t vlen
Definition: SGVector.h:706
virtual void train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
virtual int32_t decide_label(SGVector< float64_t > outputs)=0

SHOGUN Machine Learning Toolbox - Documentation