1. /* ===========================================================
  2. * JFreeChart : a free chart library for the Java(tm) platform
  3. * ===========================================================
  4. *
  5. * (C) Copyright 2000-2005, by Object Refinery Limited and Contributors.
  6. *
  7. * Project Info: http://www.jfree.org/jfreechart/index.html
  8. *
  9. * This library is free software; you can redistribute it and/or modify it under the terms
  10. * of the GNU Lesser General Public License as published by the Free Software Foundation;
  11. * either version 2.1 of the License, or (at your option) any later version.
  12. *
  13. * This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
  14. * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  15. * See the GNU Lesser General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU Lesser General Public License along with this
  18. * library; if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,
  19. * Boston, MA 02111-1307, USA.
  20. *
  21. * [Java is a trademark or registered trademark of Sun Microsystems, Inc.
  22. * in the United States and other countries.]
  23. *
  24. * ---------------
  25. * Regression.java
  26. * ---------------
  27. * (C) Copyright 2002-2005, by Object Refinery Limited.
  28. *
  29. * Original Author: David Gilbert (for Object Refinery Limited);
  30. * Contributor(s): -;
  31. *
  32. * $Id: Regression.java,v 1.2 2005/01/14 17:30:46 mungady Exp $
  33. *
  34. * Changes
  35. * -------
  36. * 30-Sep-2002 : Version 1 (DG);
  37. * 18-Aug-2003 : Added 'abstract' (DG);
  38. * 15-Jul-2004 : Switched getX() with getXValue() and getY() with getYValue() (DG);
  39. *
  40. */
  41. package org.jfree.data.statistics;
  42. import org.jfree.data.xy.XYDataset;
  43. /**
  44. * A utility class for fitting regression curves to data.
  45. */
  46. public abstract class Regression {
  47. /**
  48. * Returns the parameters 'a' and 'b' for an equation y = a + bx, fitted to the data using
  49. * ordinary least squares regression.
  50. * <p>
  51. * The result is returned as a double[], where result[0] --> a, and result[1] --> b.
  52. *
  53. * @param data the data.
  54. *
  55. * @return the parameters.
  56. */
  57. public static double[] getOLSRegression(double[][] data) {
  58. int n = data.length;
  59. if (n < 2) {
  60. throw new IllegalArgumentException("Not enough data.");
  61. }
  62. double sumX = 0;
  63. double sumY = 0;
  64. double sumXX = 0;
  65. double sumXY = 0;
  66. for (int i = 0; i < n; i++) {
  67. double x = data[i][0];
  68. double y = data[i][1];
  69. sumX += x;
  70. sumY += y;
  71. double xx = x * x;
  72. sumXX += xx;
  73. double xy = x * y;
  74. sumXY += xy;
  75. }
  76. double sxx = sumXX - (sumX * sumX) / n;
  77. double sxy = sumXY - (sumX * sumY) / n;
  78. double xbar = sumX / n;
  79. double ybar = sumY / n;
  80. double[] result = new double[2];
  81. result[1] = sxy / sxx;
  82. result[0] = ybar - result[1] * xbar;
  83. return result;
  84. }
  85. /**
  86. * Returns the parameters 'a' and 'b' for an equation y = a + bx, fitted to the data using
  87. * ordinary least squares regression.
  88. * <p>
  89. * The result is returned as a double[], where result[0] --> a, and result[1] --> b.
  90. *
  91. * @param data the data.
  92. * @param series the series (zero-based index).
  93. *
  94. * @return the parameters.
  95. */
  96. public static double[] getOLSRegression(XYDataset data, int series) {
  97. int n = data.getItemCount(series);
  98. if (n < 2) {
  99. throw new IllegalArgumentException("Not enough data.");
  100. }
  101. double sumX = 0;
  102. double sumY = 0;
  103. double sumXX = 0;
  104. double sumXY = 0;
  105. for (int i = 0; i < n; i++) {
  106. double x = data.getXValue(series, i);
  107. double y = data.getYValue(series, i);
  108. sumX += x;
  109. sumY += y;
  110. double xx = x * x;
  111. sumXX += xx;
  112. double xy = x * y;
  113. sumXY += xy;
  114. }
  115. double sxx = sumXX - (sumX * sumX) / n;
  116. double sxy = sumXY - (sumX * sumY) / n;
  117. double xbar = sumX / n;
  118. double ybar = sumY / n;
  119. double[] result = new double[2];
  120. result[1] = sxy / sxx;
  121. result[0] = ybar - result[1] * xbar;
  122. return result;
  123. }
  124. /**
  125. * Returns the parameters 'a' and 'b' for an equation y = ax^b, fitted to the data using
  126. * a power regression equation.
  127. * <p>
  128. * The result is returned as an array, where double[0] --> a, and double[1] --> b.
  129. *
  130. * @param data the data.
  131. *
  132. * @return the parameters.
  133. */
  134. public static double[] getPowerRegression(double[][] data) {
  135. int n = data.length;
  136. if (n < 2) {
  137. throw new IllegalArgumentException("Not enough data.");
  138. }
  139. double sumX = 0;
  140. double sumY = 0;
  141. double sumXX = 0;
  142. double sumXY = 0;
  143. for (int i = 0; i < n; i++) {
  144. double x = Math.log(data[i][0]);
  145. double y = Math.log(data[i][1]);
  146. sumX += x;
  147. sumY += y;
  148. double xx = x * x;
  149. sumXX += xx;
  150. double xy = x * y;
  151. sumXY += xy;
  152. }
  153. double sxx = sumXX - (sumX * sumX) / n;
  154. double sxy = sumXY - (sumX * sumY) / n;
  155. double xbar = sumX / n;
  156. double ybar = sumY / n;
  157. double[] result = new double[2];
  158. result[1] = sxy / sxx;
  159. result[0] = Math.pow(Math.exp(1.0), ybar - result[1] * xbar);
  160. return result;
  161. }
  162. /**
  163. * Returns the parameters 'a' and 'b' for an equation y = ax^b, fitted to the data using
  164. * a power regression equation.
  165. * <p>
  166. * The result is returned as an array, where double[0] --> a, and double[1] --> b.
  167. *
  168. * @param data the data.
  169. * @param series the series to fit the regression line against.
  170. *
  171. * @return the parameters.
  172. */
  173. public static double[] getPowerRegression(XYDataset data, int series) {
  174. int n = data.getItemCount(series);
  175. if (n < 2) {
  176. throw new IllegalArgumentException("Not enough data.");
  177. }
  178. double sumX = 0;
  179. double sumY = 0;
  180. double sumXX = 0;
  181. double sumXY = 0;
  182. for (int i = 0; i < n; i++) {
  183. double x = Math.log(data.getXValue(series, i));
  184. double y = Math.log(data.getYValue(series, i));
  185. sumX += x;
  186. sumY += y;
  187. double xx = x * x;
  188. sumXX += xx;
  189. double xy = x * y;
  190. sumXY += xy;
  191. }
  192. double sxx = sumXX - (sumX * sumX) / n;
  193. double sxy = sumXY - (sumX * sumY) / n;
  194. double xbar = sumX / n;
  195. double ybar = sumY / n;
  196. double[] result = new double[2];
  197. result[1] = sxy / sxx;
  198. result[0] = Math.pow(Math.exp(1.0), ybar - result[1] * xbar);
  199. return result;
  200. }
  201. }