Visual Servoing Platform  version 3.0.1
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
vpGEMM.h
1 /****************************************************************************
2  *
3  * This file is part of the ViSP software.
4  * Copyright (C) 2005 - 2017 by Inria. All rights reserved.
5  *
6  * This software is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * ("GPL") version 2 as published by the Free Software Foundation.
9  * See the file LICENSE.txt at the root directory of this source
10  * distribution for additional information about the GNU GPL.
11  *
12  * For using ViSP with software that can not be combined with the GNU
13  * GPL, please contact Inria about acquiring a ViSP Professional
14  * Edition License.
15  *
16  * See http://visp.inria.fr for more information.
17  *
18  * This software was developed at:
19  * Inria Rennes - Bretagne Atlantique
20  * Campus Universitaire de Beaulieu
21  * 35042 Rennes Cedex
22  * France
23  *
24  * If you have questions regarding the use of this file, please contact
25  * Inria at visp@inria.fr
26  *
27  * This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING THE
28  * WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
29  *
30  * Description:
31  * Matrix generalized multiplication.
32  *
33  * Authors:
34  * Laneurit Jean
35  *
36  *****************************************************************************/
37 
38 
39 #ifndef __VP_GEMM__
40 #define __VP_GEMM__
41 
42 #include <visp3/core/vpArray2D.h>
43 #include <visp3/core/vpException.h>
44 
45 const vpArray2D<double> null(0,0);
46 
/*!
  Bit flags that can be OR-ed together to build the \e ops argument of
  vpGEMM(). Each flag corresponds to one bit tested by the GEMM kernels:
  - VP_GEMM_A_T (bit 0): A is accessed transposed;
  - VP_GEMM_B_T (bit 1): B is accessed transposed;
  - VP_GEMM_C_T (bit 2): C is accessed transposed.
*/
typedef enum {
  VP_GEMM_A_T = 1 << 0, /*!< Use the transpose of A (value 1). */
  VP_GEMM_B_T = 1 << 1, /*!< Use the transpose of B (value 2). */
  VP_GEMM_C_T = 1 << 2  /*!< Use the transpose of C (value 4). */
} vpGEMMmethod;
62 
63 template<unsigned int>
64 inline void GEMMsize(const vpArray2D<double> & /*A*/,const vpArray2D<double> & /*B*/, unsigned int &/*Arows*/, unsigned int &/*Acols*/, unsigned int &/*Brows*/, unsigned int &/*Bcols*/)
65 {}
66 
67 template<>
68 void inline GEMMsize<0>(const vpArray2D<double> & A,const vpArray2D<double> & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
69 {
70  Arows= A.getRows();
71  Acols= A.getCols();
72  Brows= B.getRows();
73  Bcols= B.getCols();
74 }
75 
76 template<>
77 inline void GEMMsize<1>(const vpArray2D<double> & A,const vpArray2D<double> & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
78 {
79  Arows= A.getCols();
80  Acols= A.getRows();
81  Brows= B.getRows();
82  Bcols= B.getCols();
83 }
84 template<>
85 inline void GEMMsize<2>(const vpArray2D<double> & A,const vpArray2D<double> & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
86 {
87  Arows= A.getRows();
88  Acols= A.getCols();
89  Brows= B.getCols();
90  Bcols= B.getRows();
91 }
92 template<>
93 inline void GEMMsize<3>(const vpArray2D<double> & A,const vpArray2D<double> & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
94 {
95  Arows= A.getCols();
96  Acols= A.getRows();
97  Brows= B.getCols();
98  Bcols= B.getRows();
99 }
100 
101 template<>
102 inline void GEMMsize<4>(const vpArray2D<double> & A,const vpArray2D<double> & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
103 {
104  Arows= A.getRows();
105  Acols= A.getCols();
106  Brows= B.getRows();
107  Bcols= B.getCols();
108 }
109 
110 template<>
111 inline void GEMMsize<5>(const vpArray2D<double> & A,const vpArray2D<double> & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
112 {
113  Arows= A.getCols();
114  Acols= A.getRows();
115  Brows= B.getRows();
116  Bcols= B.getCols();
117 }
118 
119 template<>
120 inline void GEMMsize<6>(const vpArray2D<double> & A,const vpArray2D<double> & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
121 {
122  Arows= A.getRows();
123  Acols= A.getCols();
124  Brows= B.getCols();
125  Bcols= B.getRows();
126 }
127 
128 template<>
129 inline void GEMMsize<7>(const vpArray2D<double> & A,const vpArray2D<double> & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
130 {
131  Arows= A.getCols();
132  Acols= A.getRows();
133  Brows= B.getCols();
134  Bcols= B.getRows();
135 }
136 
137 template<unsigned int>
138 inline void GEMM1(const unsigned int &/*Arows*/,const unsigned int &/*Brows*/, const unsigned int &/*Bcols*/, const vpArray2D<double> & /*A*/, const vpArray2D<double> & /*B*/, const double & /*alpha*/,vpArray2D<double> &/*D*/){}
139 
140 template<>
141 inline void GEMM1<0>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A, const vpArray2D<double> & B, const double & alpha,vpArray2D<double> &D)
142 {
143  for(unsigned int r=0;r<Arows;r++)
144  for(unsigned int c=0;c<Bcols;c++){
145  double sum=0;
146  for(unsigned int n=0;n<Brows;n++)
147  sum+=A[r][n]*B[n][c]*alpha;
148  D[r][c]=sum;
149  }
150 }
151 
152 template<>
153 inline void GEMM1<1>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A, const vpArray2D<double> & B, const double & alpha,vpArray2D<double> &D)
154 {
155  for(unsigned int r=0;r<Arows;r++)
156  for(unsigned int c=0;c<Bcols;c++){
157  double sum=0;
158  for(unsigned int n=0;n<Brows;n++)
159  sum+=A[n][r]*B[n][c]*alpha;
160  D[r][c]=sum;
161  }
162 }
163 
164 template<>
165 inline void GEMM1<2>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha,vpArray2D<double> &D)
166 {
167  for(unsigned int r=0;r<Arows;r++)
168  for(unsigned int c=0;c<Bcols;c++){
169  double sum=0;
170  for(unsigned int n=0;n<Brows;n++)
171  sum+=A[r][n]*B[c][n]*alpha;
172  D[r][c]=sum;
173  }
174 }
175 
176 template<>
177 inline void GEMM1<3>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha,vpArray2D<double> &D)
178 {
179  for(unsigned int r=0;r<Arows;r++)
180  for(unsigned int c=0;c<Bcols;c++){
181  double sum=0;
182  for(unsigned int n=0;n<Brows;n++)
183  sum+=A[n][r]*B[c][n]*alpha;
184  D[r][c]=sum;
185  }
186 }
187 
188 template<unsigned int>
189 inline void GEMM2(const unsigned int &/*Arows*/,const unsigned int &/*Brows*/, const unsigned int &/*Bcols*/, const vpArray2D<double> & /*A*/,const vpArray2D<double> & /*B*/, const double & /*alpha*/, const vpArray2D<double> & /*C*/ , const double &/*beta*/, vpArray2D<double> &/*D*/)
190 {}
191 
192 template<>
193 inline void GEMM2<0>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha, const vpArray2D<double> & C , const double &beta, vpArray2D<double> &D)
194 {
195  for(unsigned int r=0;r<Arows;r++)
196  for(unsigned int c=0;c<Bcols;c++){
197  double sum=0;
198  for(unsigned int n=0;n<Brows;n++)
199  sum+=A[r][n]*B[n][c]*alpha;
200  D[r][c]=sum+C[r][c]*beta;
201  }
202 }
203 
204 template<>
205 inline void GEMM2<1>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha, const vpArray2D<double> & C , const double &beta, vpArray2D<double> &D)
206 {
207  for(unsigned int r=0;r<Arows;r++)
208  for(unsigned int c=0;c<Bcols;c++){
209  double sum=0;
210  for(unsigned int n=0;n<Brows;n++)
211  sum+=A[n][r]*B[n][c]*alpha;
212  D[r][c]=sum+C[r][c]*beta;
213  }
214 }
215 
216 template<>
217 inline void GEMM2<2>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha, const vpArray2D<double> & C , const double &beta, vpArray2D<double> &D)
218 {
219  for(unsigned int r=0;r<Arows;r++)
220  for(unsigned int c=0;c<Bcols;c++){
221  double sum=0;
222  for(unsigned int n=0;n<Brows;n++)
223  sum+=A[r][n]*B[c][n]*alpha;
224  D[r][c]=sum+C[r][c]*beta;
225  }
226 }
227 
228 template<>
229 inline void GEMM2<3>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha, const vpArray2D<double> & C , const double &beta, vpArray2D<double> &D)
230 {
231  for(unsigned int r=0;r<Arows;r++)
232  for(unsigned int c=0;c<Bcols;c++){
233  double sum=0;
234  for(unsigned int n=0;n<Brows;n++)
235  sum+=A[n][r]*B[c][n]*alpha;
236  D[r][c]=sum+C[r][c]*beta;
237  }
238 }
239 
240 
241 template<>
242 inline void GEMM2<4>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha, const vpArray2D<double> & C , const double &beta, vpArray2D<double> &D)
243 {
244  for(unsigned int r=0;r<Arows;r++)
245  for(unsigned int c=0;c<Bcols;c++){
246  double sum=0;
247  for(unsigned int n=0;n<Brows;n++)
248  sum+=A[r][n]*B[n][c]*alpha;
249  D[r][c]=sum+C[c][r]*beta;
250  }
251 }
252 
253 template<>
254 inline void GEMM2<5>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha, const vpArray2D<double> & C , const double &beta, vpArray2D<double> &D)
255 {
256  for(unsigned int r=0;r<Arows;r++)
257  for(unsigned int c=0;c<Bcols;c++){
258  double sum=0;
259  for(unsigned int n=0;n<Brows;n++)
260  sum+=A[n][r]*B[n][c]*alpha;
261  D[r][c]=sum+C[c][r]*beta;
262  }
263 
264 }
265 
266 template<>
267 inline void GEMM2<6>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha, const vpArray2D<double> & C , const double &beta, vpArray2D<double> &D)
268 {
269  for(unsigned int r=0;r<Arows;r++)
270  for(unsigned int c=0;c<Bcols;c++){
271  double sum=0;
272  for(unsigned int n=0;n<Brows;n++)
273  sum+=A[r][n]*B[c][n]*alpha;
274  D[r][c]=sum+C[c][r]*beta;
275  }
276 }
277 
278 template<>
279 inline void GEMM2<7>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols, const vpArray2D<double> & A,const vpArray2D<double> & B, const double & alpha, const vpArray2D<double> & C , const double &beta, vpArray2D<double> &D)
280 {
281  for(unsigned int r=0;r<Arows;r++)
282  for(unsigned int c=0;c<Bcols;c++){
283  double sum=0;
284  for(unsigned int n=0;n<Brows;n++)
285  sum+=A[n][r]*B[c][n]*alpha;
286  D[r][c]=sum+C[c][r]*beta;
287  }
288 }
289 
290 template<unsigned int T>
291 inline void vpTGEMM(const vpArray2D<double> & A, const vpArray2D<double> & B, const double & alpha ,const vpArray2D<double> & C, const double & beta, vpArray2D<double> & D)
292 {
293  unsigned int Arows;
294  unsigned int Acols;
295  unsigned int Brows;
296  unsigned int Bcols;
297 
298  GEMMsize<T>(A,B,Arows,Acols,Brows,Bcols);
299 
300  try {
301  if ((Arows != D.getRows()) || (Bcols != D.getCols())) D.resize(Arows,Bcols);
302  }
303  catch(...) {
304  throw ;
305  }
306 
307  if (Acols != Brows) {
309  "In vpGEMM, cannot multiply (%dx%d) matrix by (%dx%d) matrix",
310  Arows, Acols, Brows, Bcols)) ;
311  }
312 
313  if(C.getRows()!=0 && C.getCols()!=0){
314  if ((Arows != C.getRows()) || (Bcols != C.getCols())) {
316  "In vpGEMM, cannot add resulting (%dx%d) matrix to (%dx%d) matrix",
317  Arows, Bcols, C.getRows(), C.getCols())) ;
318  }
319 
320  GEMM2<T>(Arows,Brows,Bcols,A,B,alpha,C,beta,D);
321  }else{
322  GEMM1<T>(Arows,Brows,Bcols,A,B,alpha,D);
323  }
324 
325 }
326 
358 inline void vpGEMM(const vpArray2D<double> & A, const vpArray2D<double> & B,
359  const double & alpha, const vpArray2D<double> & C,
360  const double & beta, vpArray2D<double> & D, const unsigned int &ops=0)
361 {
362  switch(ops){
363  case 0 :
364  vpTGEMM<0>( A, B, alpha , C, beta, D);
365  break;
366  case 1 :
367  vpTGEMM<1>( A, B, alpha , C, beta, D);
368  break;
369  case 2 :
370  vpTGEMM<2>( A, B, alpha , C, beta, D);
371  break;
372  case 3 :
373  vpTGEMM<3>( A, B, alpha , C, beta, D);
374  break;
375  case 4 :
376  vpTGEMM<4>( A, B, alpha , C, beta, D);
377  break;
378  case 5 :
379  vpTGEMM<5>( A, B, alpha , C, beta, D);
380  break;
381  case 6 :
382  vpTGEMM<6>( A, B, alpha , C, beta, D);
383  break;
384  case 7 :
385  vpTGEMM<7>( A, B, alpha , C, beta, D);
386  break;
387  default:
389  "Operation on vpGEMM not implemented")) ;
390  break;
391  }
392 }
393 
394 #endif
void resize(const unsigned int nrows, const unsigned int ncols, const bool flagNullify=true)
Definition: vpArray2D.h:167
Error that can be emitted by ViSP classes.
Definition: vpException.h:73
unsigned int getCols() const
Return the number of columns of the 2D array.
Definition: vpArray2D.h:154
unsigned int getRows() const
Return the number of rows of the 2D array.
Definition: vpArray2D.h:152
vpGEMMmethod
Definition: vpGEMM.h:57
void vpGEMM(const vpArray2D< double > &A, const vpArray2D< double > &B, const double &alpha, const vpArray2D< double > &C, const double &beta, vpArray2D< double > &D, const unsigned int &ops=0)
Definition: vpGEMM.h:358