Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,5 @@ dist

# IDE Files
.vscode/
.idea/
.idea/
.dccache
32 changes: 32 additions & 0 deletions src/linear_model/LinearRegression.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,38 @@ describe('LinearRegression', function () {
expect(roughlyEqual(lr.intercept as number, 0)).toBe(true)
}, 30000)

it('Works on arrays (small example) with custom callbacks', async function () {
let trainingHasStarted = false
const onTrainBegin = async (logs: any) => {
trainingHasStarted = true
console.log('training begins')
}
const lr = new LinearRegression({
modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] }
})
await lr.fit([[1], [2]], [2, 4])
expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true)
expect(roughlyEqual(lr.intercept as number, 0)).toBe(true)
expect(trainingHasStarted).toBe(true)
}, 30000)

it('Works on arrays (small example) with custom callbacks', async function () {
let trainingHasStarted = false
const onTrainBegin = async (logs: any) => {
trainingHasStarted = true
console.log('training begins')
}
const lr = new LinearRegression({
modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] }
})
await lr.fit([[1], [2]], [2, 4])

const serialized = await lr.toJSON()
const newModel = await fromJSON(serialized)
expect(tensorEqual(newModel.coef, tf.tensor1d([2]), 0.1)).toBe(true)
expect(roughlyEqual(newModel.intercept as number, 0)).toBe(true)
}, 30000)

it('Works on small multi-output example (small example)', async function () {
const lr = new LinearRegression()
await lr.fit(
Expand Down
14 changes: 10 additions & 4 deletions src/linear_model/LinearRegression.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import { SGDRegressor } from './SgdRegressor'
import { getBackend } from '../tf-singleton'
import { ModelFitArgs } from '../types'

/**
* LinearRegression implementation using gradient descent
Expand All @@ -39,6 +40,7 @@ export interface LinearRegressionParams {
* **default = true**
*/
fitIntercept?: boolean
modelFitOptions?: Partial<ModelFitArgs>
}

/*
Expand All @@ -50,7 +52,7 @@ Next steps:
/** Linear Least Squares
* @example
* ```js
* import {LinearRegression} from 'scikitjs'
* import { LinearRegression } from 'scikitjs'
*
* let X = [
* [1, 2],
Expand All @@ -60,13 +62,16 @@ Next steps:
* [10, 20]
* ]
* let y = [3, 5, 8, 8, 30]
* const lr = new LinearRegression({fitIntercept: false})
* const lr = new LinearRegression({ fitIntercept: false })
await lr.fit(X, y)
lr.coef.print() // probably around [1, 1]
* ```
*/
export class LinearRegression extends SGDRegressor {
constructor({ fitIntercept = true }: LinearRegressionParams = {}) {
constructor({
fitIntercept = true,
modelFitOptions
}: LinearRegressionParams = {}) {
let tf = getBackend()
super({
modelCompileArgs: {
Expand All @@ -80,7 +85,8 @@ export class LinearRegression extends SGDRegressor {
verbose: 0,
callbacks: [
tf.callbacks.earlyStopping({ monitor: 'mse', patience: 30 })
]
],
...modelFitOptions
},
denseLayerArgs: {
units: 1,
Expand Down
8 changes: 6 additions & 2 deletions src/linear_model/LogisticRegression.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import { SGDClassifier } from './SgdClassifier'
import { getBackend } from '../tf-singleton'
import { ModelFitArgs } from '../types'

// First pass at a LogisticRegression implementation using gradient descent
// Trying to mimic the API of scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
Expand All @@ -35,6 +36,7 @@ export interface LogisticRegressionParams {
C?: number
/** Whether or not the intercept should be estimator not. **default = true** */
fitIntercept?: boolean
modelFitOptions?: Partial<ModelFitArgs>
}

/** Builds a linear classification model with associated penalty and regularization
Expand Down Expand Up @@ -63,7 +65,8 @@ export class LogisticRegression extends SGDClassifier {
constructor({
penalty = 'l2',
C = 1,
fitIntercept = true
fitIntercept = true,
modelFitOptions
}: LogisticRegressionParams = {}) {
// Assume Binary classification
// If we call fit, and it isn't binary then update args
Expand All @@ -80,7 +83,8 @@ export class LogisticRegression extends SGDClassifier {
verbose: 0,
callbacks: [
tf.callbacks.earlyStopping({ monitor: 'loss', patience: 50 })
]
],
...modelFitOptions
},
denseLayerArgs: {
units: 1,
Expand Down