Free Electron
ConjugateGradient.h
Go to the documentation of this file.
1 /* Copyright (C) 2003-2021 Free Electron Organization
2  Any use of this software requires a license. If a valid license
3  was not distributed with this file, visit freeelectron.org. */
4 
5 /** @file */
6 
7 #ifndef __solve_ConjugateGradient_h__
8 #define __solve_ConjugateGradient_h__
9 
// Compile-time diagnostic switches for ConjugateGradient::solve().

// CG_DEBUG: log the inputs (A, b), the result (x, A*x), and the reason for
// any early exit from the solver.
10 #define CG_DEBUG FALSE
// CG_TRACE: log the intermediate scalars and vectors of every iteration.
11 #define CG_TRACE FALSE
12 
// The remaining switches are derived from comparing FE_CODEGEN to FE_DEBUG.
// CG_CHECK_SYMMETRY: compare A(ix,iy) against A(iy,ix) before solving.
13 #define CG_CHECK_SYMMETRY (FE_CODEGEN<FE_DEBUG)
// CG_CHECK_POSITIVE: log any diagonal entry of A that is not positive.
14 #define CG_CHECK_POSITIVE (FE_CODEGEN<FE_DEBUG)
// CG_CHECK_PEAK: log when the largest |A(i,j)| is not on the diagonal.
15 #define CG_CHECK_PEAK (FE_CODEGEN<FE_DEBUG)
// CG_VERIFY: after solving, scan x for invalid scalars and log when
// |A*x-b| exceeds 1.
16 #define CG_VERIFY (FE_CODEGEN<=FE_DEBUG)
// CG_EXCEPTION: when set, a failed symmetry check calls feX() in addition
// to logging.
17 #define CG_EXCEPTION (FE_CODEGEN<=FE_DEBUG)
18 
19 namespace fe
20 {
21 namespace ext
22 {
23 
24 /**************************************************************************//**
25  @brief solve Ax=b for x
26 
27  @ingroup solve
28 
29  Uses Conjugate-Gradient. The matrix must be positive-definite
30  and symmetric.
31 
32  The arguments are templated, so any argument types should work,
33  given that they have the appropriate methods and operators.
34 
35  TODO try seeding with previous x instead of clearing to zero.
36 *//***************************************************************************/
37 template <typename MATRIX, typename VECTOR>
39 {
40  public:
41  ConjugateGradient(void):
42  m_threshold(1e-6f) {}
43 
44  void solve(VECTOR& x, const MATRIX& A, const VECTOR& b);
45 
46  void setThreshold(F64 threshold) { m_threshold=threshold; }
47 
48  private:
49  VECTOR r; //* residual
50  VECTOR d; //* direction
51  VECTOR temp; //* persistent temporary
52  VECTOR q; //* A*d
53  F64 m_threshold;
54 };
55 
56 template <typename MATRIX, typename VECTOR>
57 inline void ConjugateGradient<MATRIX,VECTOR>::solve(VECTOR& x,
58  const MATRIX& A, const VECTOR& b)
59 {
60  U32 N=size(b);
61  if(size(x)!=N)
62  {
63  x=b; // adopt size
64  }
65  if(size(q)!=N)
66  {
67  q=b; // adopt size
68  }
69  set(x);
70 
71 #if CG_DEBUG
72  feLog("\nA\n%s\nb=<%s>\n",c_print(A),c_print(b));
73 #endif
74 
75 #if CG_CHECK_SYMMETRY
76  for(U32 iy=0;iy<N;iy++)
77  {
78  for(U32 ix=iy;ix<N;ix++)
79  {
80  if(fabs(A(ix,iy)-A(iy,ix))>1e-9f)
81  {
82  feLog("ConjugateGradient symmetry error %d,%d %.6G %.6G\n",
83  ix,iy,A(ix,iy),A(iy,ix));
84  if(CG_EXCEPTION)
85  {
86  feX("ConjugateGradient::solve","unsymmetrical");
87  }
88  }
89  }
90  }
91 #endif
92 
93  if(magnitudeSquared(b)<m_threshold)
94  {
95 #if CG_DEBUG
96  feLog("ConjugateGradient::solve has trivial solution\n");
97 #endif
98  return;
99  }
100 
101  const MATRIX* pA=&A;
102  const VECTOR* pb=&b;
103 #if CG_CHECK_POSITIVE
104  for(U32 k=0;k<N;k++)
105  {
106  if((*pA)(k,k)<=0.0)
107  {
108  feLog("ConjugateGradient::solve"
109  " non-positive diagonal %d,%d %.6G\n",
110  k,k,(*pA)(k,k));
111  }
112  }
113 #endif
114 #if CG_CHECK_PEAK
115  F64 peak=0.0;
116  U32 peak_i;
117  U32 peak_j;
118  for(U32 i=0;i<N;i++)
119  {
120  for(U32 j=0;j<N;j++)
121  {
122  F64 value=fabs((*pA)(i,j));
123  if(peak<value)
124  {
125  peak=value;
126  peak_i=i;
127  peak_j=j;
128  }
129  }
130  }
131  if(peak_i!=peak_j)
132  {
133  feLog("ConjugateGradient::solve non-diagonal peak %d,%d %.6G\n",
134  peak_i,peak_j,peak);
135  }
136 #endif
137 
138  U32 i=0;
139  r= *pb;
140  d=r;
141  F64 dnew=dot(r,r);
142 // F64 d0=dnew;
143  while(i<N)
144  {
145  transformVector(*pA,d,q);
146 #if CG_TRACE
147  feLog("\n%d q=A*d >>> %.6G <<<\n",i,dnew);
148  feLog("d<%s>\n",c_print(d));
149  feLog("q<%s>\n",c_print(q));
150 #endif
151 
152  F64 alpha=dnew/dot(d,q);
153 #if CG_TRACE
154  feLog("alpha=dnew/dot(d,q) %.6G=%.6G/%.6G\n",alpha,dnew,dot(d,q));
155 #endif
156 
157 #if CG_TRACE
158  feLog("x<%s>\n",c_print(x));
159 #endif
160 // x=x+alpha*d;
161 
162 // temp=d;
163 // temp*=alpha;
164 // x+=temp;
165 
166  addScaled(x,alpha,d);
167 #if CG_TRACE
168  feLog("x+=alpha*d <%s>\n",c_print(temp));
169  feLog("x<%s>\n",c_print(x));
170 #endif
171 
172 #if CG_TRACE
173  feLog("r<%s>\n",c_print(r));
174 #endif
175 // r=r-alpha*q;
176 
177 // temp=q;
178 // temp*=alpha;
179 // r-=temp;
180 
181  addScaled(r,-alpha,q);
182 #if CG_TRACE
183  feLog("r-=alpha*q <%s>\n",c_print(temp));
184  feLog("r<%s>\n",c_print(r));
185 #endif
186 
187  F64 dold=dnew;
188  dnew=dot(r,r);
189  F64 beta=dnew/dold;
190 #if CG_TRACE
191  feLog("dnew=%.6G dold=%.6G beta=%.6G\n",dnew,dold,beta);
192 #endif
193 
194 // d=r+beta*d;
195 
196 // temp=d;
197 // temp*=beta;
198 // d=r;
199 // d+=temp;
200 
201  scaleAndAdd(d,beta,r);
202 
203 #if CG_TRACE
204  feLog("d=r+beta*d <%s>\n",c_print(temp));
205  feLog("d<%s>\n",c_print(d));
206 #endif
207 
208  if(magnitudeSquared(d)==0.0f)
209  {
210 // feX("ConjugateGradient::solve","direction lost its magnitude");
211 
212 #if CG_DEBUG
213  feLog("ConjugateGradient::solve direction lost its magnitude\n");
214 #endif
215  break;
216  }
217 
218  i++;
219 
220 #if CG_TRACE
221  feLog("ConjugateGradient::solve"
222  " ran %d/%d alpha %.6G |r| %.6G |d| %.6G\n",
223  i,N,alpha,magnitude(r),magnitude(d));
224 #endif
225 
226  if(magnitudeSquared(r)<m_threshold)
227  {
228 #if CG_DEBUG
229  feLog("ConjugateGradient::solve early solve %d/%d\n",i,N);
230 #endif
231  break;
232  }
233 
234  if(i==N)
235  {
236  feLog("ConjugateGradient::solve ran %d/%d\n",i,N);
237  }
238  }
239 
240 #if CG_DEBUG
241  feLog("\nx=<%s>\nA*x=<%s>\n",c_print(x),c_print(A*x));
242  feLog("\nb=<%s>\n",c_print(b));
243 #endif
244 
245 #if CG_VERIFY
246  BWORD invalid=FALSE;
247  for(U32 k=0;k<N;k++)
248  {
249  if(FE_INVALID_SCALAR(x[k]))
250  {
251  invalid=TRUE;
252  }
253  }
254  VECTOR Ax=A*x;
255  F64 distance=magnitude(Ax-b);
256  if(invalid || distance>1.0f)
257  {
258  feLog("ConjugateGradient::solve failed to converge (dist=%.6G)\n",
259  distance);
260  if(size(x)<100)
261  {
262  feLog(" collecting state ...\n");
263  feLog("A=\n%s\nx=<%s>\nA*x=<%s>\nb=<%s>\n",
264  c_print(A),c_print(x),
265  c_print(Ax),c_print(b));
266  }
267 // feX("ConjugateGradient::solve","failed to converge");
268  }
269 #endif
270 }
271 
272 } /* namespace ext */
273 } /* namespace fe */
274 
275 #endif /* __solve_ConjugateGradient_h__ */
kernel
Definition: namespace.dox:3
solve Ax=b for x
Definition: ConjugateGradient.h:38