ViSP  2.10.0
vpGEMM.h
1 /****************************************************************************
2  *
3  * $Id: vpGEMM.h 4574 2014-01-09 08:48:51Z fspindle $
4  *
5  * This file is part of the ViSP software.
6  * Copyright (C) 2005 - 2014 by INRIA. All rights reserved.
7  *
8  * This software is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU General Public License
10  * ("GPL") version 2 as published by the Free Software Foundation.
11  * See the file LICENSE.txt at the root directory of this source
12  * distribution for additional information about the GNU GPL.
13  *
14  * For using ViSP with software that can not be combined with the GNU
15  * GPL, please contact INRIA about acquiring a ViSP Professional
16  * Edition License.
17  *
18  * See http://www.irisa.fr/lagadic/visp/visp.html for more information.
19  *
20  * This software was developed at:
21  * INRIA Rennes - Bretagne Atlantique
22  * Campus Universitaire de Beaulieu
23  * 35042 Rennes Cedex
24  * France
25  * http://www.irisa.fr/lagadic
26  *
27  * If you have questions regarding the use of this file, please contact
28  * INRIA at visp@inria.fr
29  *
30  * This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING THE
31  * WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
32  *
33  *
34  * Description:
35  * Matrix generalized multiplication.
36  *
37  * Authors:
38  * Laneurit Jean
39  *
40  *****************************************************************************/
41 
42 
43 #ifndef __VP_GEMM__
44 #define __VP_GEMM__
45 
46 #include <visp/vpMatrix.h>
47 #include <visp/vpException.h>
48 #include <visp/vpMatrixException.h>
49 #include <visp/vpDebug.h>
50 
// Empty (0 x 0) matrix. vpTGEMM() skips the beta*op(C) term when the C
// argument has zero rows or zero columns, so this object can be passed as C
// to request a plain alpha*op(A)*op(B) product.
// NOTE(review): the very generic name `null` lives in the global namespace of
// every file including this header — confirm it cannot clash with user code.
const vpMatrix null(0,0);
52 
/*!
  Operation flags understood by vpGEMM().

  Each enumerator is a distinct bit, so the values can be combined with a
  bitwise OR to build the \e ops argument (0 to 7) of vpGEMM(). A set bit
  means that the corresponding matrix is used transposed.
*/
typedef enum {
  VP_GEMM_A_T=1, //!< Use the transpose of A in the product.
  VP_GEMM_B_T=2, //!< Use the transpose of B in the product.
  VP_GEMM_C_T=4  //!< Use the transpose of C in the added term.
} vpGEMMmethod;
68 
69 
70 
71 
/*!
  Primary template of the size helper: intentionally does nothing.

  The explicit specialisations that follow fill \e Arows, \e Acols, \e Brows
  and \e Bcols with the dimensions of op(A) and op(B) for a given operation
  mask. For a mask without specialisation the four outputs are left untouched.
*/
template<unsigned int> inline void GEMMsize(const vpMatrix & /*A*/,const vpMatrix & /*B*/, unsigned int &/*Arows*/, unsigned int &/*Acols*/, unsigned int &/*Brows*/, unsigned int &/*Bcols*/){}
73 
74  template<> void inline GEMMsize<0>(const vpMatrix & A,const vpMatrix & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols){
75  Arows= A.getRows();
76  Acols= A.getCols();
77  Brows= B.getRows();
78  Bcols= B.getCols();
79 }
80 
81  template<> inline void GEMMsize<1>(const vpMatrix & A,const vpMatrix & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols){
82  Arows= A.getCols();
83  Acols= A.getRows();
84  Brows= B.getRows();
85  Bcols= B.getCols();
86 }
87  template<> inline void GEMMsize<2>(const vpMatrix & A,const vpMatrix & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols){
88  Arows= A.getRows();
89  Acols= A.getCols();
90  Brows= B.getCols();
91  Bcols= B.getRows();
92 }
93  template<> inline void GEMMsize<3>(const vpMatrix & A,const vpMatrix & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols){
94  Arows= A.getCols();
95  Acols= A.getRows();
96  Brows= B.getCols();
97  Bcols= B.getRows();
98 }
99 
100  template<> inline void GEMMsize<4>(const vpMatrix & A,const vpMatrix & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols){
101  Arows= A.getRows();
102  Acols= A.getCols();
103  Brows= B.getRows();
104  Bcols= B.getCols();
105 }
106 
107  template<> inline void GEMMsize<5>(const vpMatrix & A,const vpMatrix & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols){
108  Arows= A.getCols();
109  Acols= A.getRows();
110  Brows= B.getRows();
111  Bcols= B.getCols();
112 }
113 
114  template<> inline void GEMMsize<6>(const vpMatrix & A,const vpMatrix & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols){
115  Arows= A.getRows();
116  Acols= A.getCols();
117  Brows= B.getCols();
118  Bcols= B.getRows();
119 }
120 
121  template<> inline void GEMMsize<7>(const vpMatrix & A,const vpMatrix & B, unsigned int &Arows, unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols){
122  Arows= A.getCols();
123  Acols= A.getRows();
124  Brows= B.getCols();
125  Bcols= B.getRows();
126 }
127 
128 
129 
/*!
  Primary template of the product-only kernel (no C term): intentionally
  does nothing and leaves D untouched.

  The explicit specialisations that follow implement D = alpha*op(A)*op(B)
  for the operation masks 0 to 3.
  NOTE(review): no specialisation for masks 4 to 7 is visible in this file,
  so GEMM1<4> ... GEMM1<7> resolve to this empty body.
*/
template<unsigned int> inline void GEMM1(const unsigned int &/*Arows*/,const unsigned int &/*Brows*/, const unsigned int &/*Bcols*/, const vpMatrix & /*A*/, const vpMatrix & /*B*/, const double & /*alpha*/,vpMatrix &/*D*/){}
131 
132  template<> inline void GEMM1<0>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A, const vpMatrix & B, const double & alpha,vpMatrix &D){
133  for(unsigned int r=0;r<Arows;r++)
134  for(unsigned int c=0;c<Bcols;c++){
135  double sum=0;
136  for(unsigned int n=0;n<Brows;n++)
137  sum+=A[r][n]*B[n][c]*alpha;
138  D[r][c]=sum;
139  }
140 }
141 
142  template<> inline void GEMM1<1>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A, const vpMatrix & B, const double & alpha,vpMatrix &D){
143  for(unsigned int r=0;r<Arows;r++)
144  for(unsigned int c=0;c<Bcols;c++){
145  double sum=0;
146  for(unsigned int n=0;n<Brows;n++)
147  sum+=A[n][r]*B[n][c]*alpha;
148  D[r][c]=sum;
149  }
150 }
151 
152  template<> inline void GEMM1<2>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha,vpMatrix &D){
153  for(unsigned int r=0;r<Arows;r++)
154  for(unsigned int c=0;c<Bcols;c++){
155  double sum=0;
156  for(unsigned int n=0;n<Brows;n++)
157  sum+=A[r][n]*B[c][n]*alpha;
158  D[r][c]=sum;
159  }
160 }
161 
162  template<> inline void GEMM1<3>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha,vpMatrix &D){
163  for(unsigned int r=0;r<Arows;r++)
164  for(unsigned int c=0;c<Bcols;c++){
165  double sum=0;
166  for(unsigned int n=0;n<Brows;n++)
167  sum+=A[n][r]*B[c][n]*alpha;
168  D[r][c]=sum;
169  }
170 }
171 
/*!
  Primary template of the product-plus-C kernel: intentionally does nothing
  and leaves D untouched.

  The explicit specialisations that follow implement
  D = alpha*op(A)*op(B) + beta*op(C) for the operation masks 0 to 7.
*/
template<unsigned int> inline void GEMM2(const unsigned int &/*Arows*/,const unsigned int &/*Brows*/, const unsigned int &/*Bcols*/, const vpMatrix & /*A*/,const vpMatrix & /*B*/, const double & /*alpha*/, const vpMatrix & /*C*/ , const double &/*beta*/, vpMatrix &/*D*/){}
173 
174  template<> inline void GEMM2<0>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha, const vpMatrix & C , const double &beta, vpMatrix &D){
175 
176  for(unsigned int r=0;r<Arows;r++)
177  for(unsigned int c=0;c<Bcols;c++){
178  double sum=0;
179  for(unsigned int n=0;n<Brows;n++)
180  sum+=A[r][n]*B[n][c]*alpha;
181  D[r][c]=sum+C[r][c]*beta;
182  }
183 }
184 
185  template<> inline void GEMM2<1>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha, const vpMatrix & C , const double &beta, vpMatrix &D){
186  for(unsigned int r=0;r<Arows;r++)
187  for(unsigned int c=0;c<Bcols;c++){
188  double sum=0;
189  for(unsigned int n=0;n<Brows;n++)
190  sum+=A[n][r]*B[n][c]*alpha;
191  D[r][c]=sum+C[r][c]*beta;
192  }
193 }
194 
195  template<> inline void GEMM2<2>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha, const vpMatrix & C , const double &beta, vpMatrix &D){
196  for(unsigned int r=0;r<Arows;r++)
197  for(unsigned int c=0;c<Bcols;c++){
198  double sum=0;
199  for(unsigned int n=0;n<Brows;n++)
200  sum+=A[r][n]*B[c][n]*alpha;
201  D[r][c]=sum+C[r][c]*beta;
202  }
203 }
204 
205  template<> inline void GEMM2<3>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha, const vpMatrix & C , const double &beta, vpMatrix &D){
206  for(unsigned int r=0;r<Arows;r++)
207  for(unsigned int c=0;c<Bcols;c++){
208  double sum=0;
209  for(unsigned int n=0;n<Brows;n++)
210  sum+=A[n][r]*B[c][n]*alpha;
211  D[r][c]=sum+C[r][c]*beta;
212  }
213 }
214 
215 
216  template<> inline void GEMM2<4>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha, const vpMatrix & C , const double &beta, vpMatrix &D){
217  for(unsigned int r=0;r<Arows;r++)
218  for(unsigned int c=0;c<Bcols;c++){
219  double sum=0;
220  for(unsigned int n=0;n<Brows;n++)
221  sum+=A[r][n]*B[n][c]*alpha;
222  D[r][c]=sum+C[c][r]*beta;
223  }
224 }
225 
226  template<> inline void GEMM2<5>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha, const vpMatrix & C , const double &beta, vpMatrix &D){
227  for(unsigned int r=0;r<Arows;r++)
228  for(unsigned int c=0;c<Bcols;c++){
229  double sum=0;
230  for(unsigned int n=0;n<Brows;n++)
231  sum+=A[n][r]*B[n][c]*alpha;
232  D[r][c]=sum+C[c][r]*beta;
233  }
234 
235 }
236 
237  template<> inline void GEMM2<6>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha, const vpMatrix & C , const double &beta, vpMatrix &D){
238  for(unsigned int r=0;r<Arows;r++)
239  for(unsigned int c=0;c<Bcols;c++){
240  double sum=0;
241  for(unsigned int n=0;n<Brows;n++)
242  sum+=A[r][n]*B[c][n]*alpha;
243  D[r][c]=sum+C[c][r]*beta;
244  }
245 }
246 
247  template<> inline void GEMM2<7>(const unsigned int &Arows,const unsigned int &Brows, const unsigned int &Bcols, const vpMatrix & A,const vpMatrix & B, const double & alpha, const vpMatrix & C , const double &beta, vpMatrix &D){
248  //vpMatrix &D = *dynamic_cast<double***>(Dptr);
249  for(unsigned int r=0;r<Arows;r++)
250  for(unsigned int c=0;c<Bcols;c++){
251  double sum=0;
252  for(unsigned int n=0;n<Brows;n++)
253  sum+=A[n][r]*B[c][n]*alpha;
254  D[r][c]=sum+C[c][r]*beta;
255  }
256 }
257 
258  template<unsigned int T> inline void vpTGEMM(const vpMatrix & A,const vpMatrix & B, const double & alpha ,const vpMatrix & C, const double & beta, vpMatrix & D){
259 
260  unsigned int Arows;
261  unsigned int Acols;
262  unsigned int Brows;
263  unsigned int Bcols;
264 
265 // std::cout << T << std::endl;
266  GEMMsize<T>(A,B,Arows,Acols,Brows,Bcols);
267 // std::cout << Arows<<" " <<Acols << " "<< Brows << " "<< Bcols<<std::endl;
268 
269  try
270  {
271  if ((Arows != D.getRows()) || (Bcols != D.getCols())) D.resize(Arows,Bcols);
272  }
273  catch(vpException me)
274  {
275  vpERROR_TRACE("Error caught") ;
276  std::cout << me << std::endl ;
277  throw ;
278  }
279 
280  if (Acols != Brows)
281  {
282  vpERROR_TRACE("\n\t\tvpMatrix mismatch size in vpGEMM") ;
283  throw(vpMatrixException(vpMatrixException::incorrectMatrixSizeError,"\n\t\tvpMatrix mismatch size in vpGEMM")) ;
284  }
285 
286  if(C.getRows()!=0 && C.getCols()!=0){
287 
288  if ((Arows != C.getRows()) || (Bcols != C.getCols()))
289  {
290  vpERROR_TRACE("\n\t\tvpMatrix mismatch size in vpGEMM") ;
291  throw(vpMatrixException(vpMatrixException::incorrectMatrixSizeError,"\n\t\tvpMatrix mismatch size in vpGEMM")) ;
292  }
293 
294 
295  GEMM2<T>(Arows,Brows,Bcols,A,B,alpha,C,beta,D);
296  }else{
297  GEMM1<T>(Arows,Brows,Bcols,A,B,alpha,D);
298  }
299 
300 }
301 
331 inline void vpGEMM(const vpMatrix & A,const vpMatrix & B, const double & alpha ,const vpMatrix & C, const double & beta, vpMatrix & D, const unsigned int &ops=0){
332  switch(ops){
333  case 0 :
334  vpTGEMM<0>( A, B, alpha , C, beta, D);
335  break;
336  case 1 :
337  vpTGEMM<1>( A, B, alpha , C, beta, D);
338  break;
339  case 2 :
340  vpTGEMM<2>( A, B, alpha , C, beta, D);
341  break;
342  case 3 :
343  vpTGEMM<3>( A, B, alpha , C, beta, D);
344  break;
345  case 4 :
346  vpTGEMM<4>( A, B, alpha , C, beta, D);
347  break;
348  case 5 :
349  vpTGEMM<5>( A, B, alpha , C, beta, D);
350  break;
351  case 6 :
352  vpTGEMM<6>( A, B, alpha , C, beta, D);
353  break;
354  case 7 :
355  vpTGEMM<7>( A, B, alpha , C, beta, D);
356  break;
357  default:
358  vpERROR_TRACE("\n\t\tvpMatrix mismatch operation in vpGEMM") ;
359  throw(vpMatrixException(vpMatrixException::incorrectMatrixSizeError,"\n\t\tvpMatrix mismatch operation in vpGEMM")) ;
360  break;
361  }
362 }
363 
364 #endif
Definition of the vpMatrix class.
Definition: vpMatrix.h:98
void resize(const unsigned int nrows, const unsigned int ncols, const bool nullify=true)
Definition: vpMatrix.cpp:199
#define vpERROR_TRACE
Definition: vpDebug.h:395
Error that can be emitted by ViSP classes.
Definition: vpException.h:76
void vpGEMM(const vpMatrix &A, const vpMatrix &B, const double &alpha, const vpMatrix &C, const double &beta, vpMatrix &D, const unsigned int &ops=0)
This function performs generalized matrix multiplication: D = alpha*op(A)*op(B) + beta*op(C)...
Definition: vpGEMM.h:331
double sum() const
Definition: vpMatrix.cpp:903
unsigned int getCols() const
Return the number of columns of the matrix.
Definition: vpMatrix.h:163
Error that can be emitted by the vpMatrix class and its derived classes.
unsigned int getRows() const
Return the number of rows of the matrix.
Definition: vpMatrix.h:161
vpGEMMmethod
Enumeration of the operations applied on matrices in vpGEMM function.
Definition: vpGEMM.h:63