-
Notifications
You must be signed in to change notification settings - Fork 0
Train #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Introduces train and market commands for K-Means clustering of market states (Bullish/Bearish/Sideway). Adds MarketStateService for model training and prediction, updates CLI registration, expands README and documentation, and adds required dependencies (pandas, numpy, scikit-learn, joblib).
Introduces RandomForest-based classifier training and forecasting for market state prediction. Updates CLI, documentation, and service layer to support train-classifier and forecast commands, including new workflow and options in README and guide.
Reviewer's GuideThis PR extends the CLI tool with unsupervised market-state labeling (K-Means clustering) and supervised forecasting (RandomForest), introduces new commands to train and apply these models, and updates documentation and wiring to integrate the features throughout the application. Sequence diagram for training and using market state modelssequenceDiagram
actor User
participant CLI
participant MarketStateService
participant KlineService
participant FileSystem
User->>CLI: run train -c <CRYPTO>
CLI->>MarketStateService: train_model(crypto_name)
MarketStateService->>KlineService: get_kline_data(crypto_name)
KlineService-->>MarketStateService: kline data
MarketStateService->>FileSystem: save trained model
MarketStateService-->>CLI: training summary
User->>CLI: run market -c <CRYPTO>
CLI->>MarketStateService: predict_market_state(crypto_name)
MarketStateService->>FileSystem: load trained model
MarketStateService->>KlineService: get_kline_data(crypto_name)
KlineService-->>MarketStateService: kline data
MarketStateService-->>CLI: latest state and distribution
Sequence diagram for classifier training and forecastingsequenceDiagram
actor User
participant CLI
participant MarketClassifierService
participant MarketStateService
participant FileSystem
User->>CLI: run train-classifier -c <CRYPTO>
CLI->>MarketClassifierService: train_classifier(crypto_name)
MarketClassifierService->>MarketStateService: get_labeled_feature_dataset(crypto_name)
MarketStateService->>FileSystem: load clustering model
MarketStateService-->>MarketClassifierService: labeled feature dataset
MarketClassifierService->>FileSystem: save classifier model
MarketClassifierService-->>CLI: training summary
User->>CLI: run forecast -c <CRYPTO>
CLI->>MarketClassifierService: forecast_next_state(crypto_name)
MarketClassifierService->>MarketStateService: get_labeled_feature_dataset(crypto_name)
MarketStateService->>FileSystem: load clustering model
MarketStateService-->>MarketClassifierService: labeled feature dataset
MarketClassifierService->>FileSystem: load classifier model
MarketClassifierService-->>CLI: prediction and probabilities
Class diagram for new and updated market state and classifier servicesclassDiagram
class MarketStateService {
+train_model(crypto_name, n_clusters, min_clusters, max_clusters)
+predict_market_state(crypto_name)
+prepare_feature_dataset(crypto_name)
+get_labeled_feature_dataset(crypto_name)
-_model_path(symbol, interval)
-_load_dataframe(crypto_name)
-_compute_features(frame)
-_assign_labels(cluster_returns)
-_resolve_cluster_count(X, forced_clusters, min_clusters, max_clusters)
}
class MarketClassifierService {
+train_classifier(crypto_name, test_size, random_state, n_estimators, max_depth)
+forecast_next_state(crypto_name)
-_model_path(symbol, interval)
-state_service: MarketStateService
}
class KlineService {
+get_kline_data(crypto_name)
}
MarketClassifierService --> MarketStateService
MarketStateService --> KlineService
class ModelTrainingError
class MarketModelNotFoundError
class ClassifierTrainingError
class ClassifierModelNotFoundError
MarketStateService ..> ModelTrainingError
MarketStateService ..> MarketModelNotFoundError
MarketClassifierService ..> ClassifierTrainingError
MarketClassifierService ..> ClassifierModelNotFoundError
Class diagram for new CLI commandsclassDiagram
class train_command {
+train_command(crypto, clusters, min_clusters, max_clusters)
}
class market_command {
+market_command(crypto, show_history)
}
class train_classifier_command {
+train_classifier_command(crypto, test_size, estimators, max_depth)
}
class forecast_command {
+forecast_command(crypto)
}
train_command --> MarketStateService
market_command --> MarketStateService
train_classifier_command --> MarketClassifierService
forecast_command --> MarketClassifierService
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey there - I've reviewed your changes and they look great!
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location> `src/service/market_state_service.py:152-154` </location>
<code_context>
+ self, X: np.ndarray, forced_clusters: Optional[int], min_clusters: int, max_clusters: int
+ ) -> int:
+ sample_count = len(X)
+ if sample_count < 2:
+ raise ModelTrainingError("Need at least two samples to train KMeans.")
+
</code_context>
<issue_to_address>
**suggestion:** ModelTrainingError for sample_count < 2 may not be sufficient for clustering.
Increase the minimum sample requirement to ensure KMeans produces stable, non-trivial clusters.
```suggestion
sample_count = len(X)
min_required_samples = max(6, min_clusters * 2)
if sample_count < min_required_samples:
raise ModelTrainingError(
f"Need at least {min_required_samples} samples to train KMeans (got {sample_count})."
)
```
</issue_to_address>
### Comment 2
<location> `src/service/market_state_service.py:163-180` </location>
<code_context>
+
+ best_k = None
+ best_score = -1.0
+ for k in range(min_clusters, max_clusters + 1):
+ if sample_count <= k:
+ break
+ model = KMeans(n_clusters=k, n_init=10, random_state=42)
+ labels = model.fit_predict(X)
+ if len(set(labels)) == 1:
+ continue
+ score = silhouette_score(X, labels)
+ if score > best_score:
+ best_score = score
</code_context>
<issue_to_address>
**suggestion:** Silhouette score may be unreliable for small sample sizes.
For small datasets, consider implementing a minimum sample size check or alternative logic to avoid unreliable silhouette scores.
```suggestion
MIN_SILHOUETTE_SAMPLES = 10
best_k = None
best_score = -1.0
if sample_count < MIN_SILHOUETTE_SAMPLES:
# For small datasets, avoid silhouette_score and use min_clusters or raise an error
if sample_count <= min_clusters:
raise ModelTrainingError("Not enough samples for the minimum number of clusters.")
return min_clusters
for k in range(min_clusters, max_clusters + 1):
if sample_count <= k:
break
model = KMeans(n_clusters=k, n_init=10, random_state=42)
labels = model.fit_predict(X)
if len(set(labels)) == 1:
continue
score = silhouette_score(X, labels)
if score > best_score:
best_score = score
best_k = k
if best_k is None:
raise ModelTrainingError("Unable to determine an appropriate number of clusters.")
return best_k
```
</issue_to_address>
### Comment 3
<location> `src/service/market_classifier_service.py:64-70` </location>
<code_context>
+ scaler = StandardScaler()
+ X_scaled = scaler.fit_transform(X)
+
+ stratify = y if len(np.unique(y)) > 1 else None
+ X_train, X_test, y_train, y_test = train_test_split(
+ X_scaled, y, test_size=test_size, random_state=random_state, stratify=stratify
</code_context>
<issue_to_address>
**suggestion:** Stratification may fail if class distribution is highly imbalanced.
Handle potential exceptions from train_test_split when stratification fails, or add a warning for cases where stratification cannot be performed.
```suggestion
import warnings
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
stratify = y if len(np.unique(y)) > 1 else None
try:
X_train, X_test, y_train, y_test = train_test_split(
X_scaled, y, test_size=test_size, random_state=random_state, stratify=stratify
)
except ValueError as e:
warnings.warn(
f"Stratified train_test_split failed due to: {e}. Proceeding without stratification.",
UserWarning
)
X_train, X_test, y_train, y_test = train_test_split(
X_scaled, y, test_size=test_size, random_state=random_state, stratify=None
)
```
</issue_to_address>
### Comment 4
<location> `src/service/market_classifier_service.py:53-54` </location>
<code_context>
+ feature_cols = labeled_dataset["feature_columns"]
+ meta = labeled_dataset["meta"]
+
+ if frame["state"].isna().all():
+ raise ModelTrainingError("No market-state labels available. Train the clustering model first.")
+
</code_context>
<issue_to_address>
**suggestion (bug_risk):** All NaN states may indicate a deeper issue with clustering or feature engineering.
Log diagnostic details to identify whether the issue originates from feature computation or clustering.
```suggestion
if frame["state"].isna().all():
# Diagnostic logging for feature engineering and clustering issues
import logging
logger = logging.getLogger(__name__)
num_rows = len(frame)
num_state_nans = frame["state"].isna().sum()
feature_nan_counts = frame[feature_cols].isna().sum()
feature_nan_summary = feature_nan_counts.to_dict()
logger.error(
"All market-state labels are NaN. Diagnostics: "
f"num_rows={num_rows}, num_state_nans={num_state_nans}, "
f"feature_nan_summary={feature_nan_summary}"
)
logger.error(
"Feature columns summary statistics:\n%s",
frame[feature_cols].describe(include='all').to_string()
)
raise ModelTrainingError("No market-state labels available. Train the clustering model first.")
```
</issue_to_address>
### Comment 5
<location> `README.md:42` </location>
<code_context>
- Time range analysis
- Clean CLI interface with detailed logging
- Cross-platform support (Windows & Linux)
+- Trainable K-Means clustering to classify market states (Bullish/Bearish/Sideway)
+- RandomForest classifier to forecast the next market state
## Installation
</code_context>
<issue_to_address>
**issue (typo):** Replace 'Sideway' with 'Sideways' for correct terminology.
Update all instances of 'Sideway' to 'Sideways' for accuracy and consistency.
```suggestion
- Trainable K-Means clustering to classify market states (Bullish/Bearish/Sideways)
```
</issue_to_address>
### Comment 6
<location> `README.md:162` </location>
<code_context>
+./scripts/linux/run forecast -c <CRYPTO>
+```
+
+> Recommended workflow: fetch data with `dataset`, train the clustering model once with `train`, optionally train the classifier with `train-classifier`, then re-run `market` (current state) and `forecast` (next state) whenever new klines are pulled. Xem thêm hướng dẫn chi tiết tại `docs/market_state_guide.md`.
+
### Command Options
</code_context>
<issue_to_address>
**issue:** Translate the Vietnamese sentence to English for consistency.
Please translate the Vietnamese sentence to English to match the rest of the documentation and improve accessibility.
</issue_to_address>
### Comment 7
<location> `src/commands/forecast.py:18-27` </location>
<code_context>
def forecast_command(crypto: str = typer.Option(..., "--crypto", "-c", help="Crypto symbol to forecast (e.g., BTC)")):
logger = logging.getLogger(__name__)
try:
service = MarketClassifierService()
result = service.forecast_next_state(crypto)
logger.info("Forecast for %s (%s)", result["symbol"], result["interval"])
logger.info("=" * 50)
logger.info("Prediction timestamp: %s", result["prediction_timestamp"])
logger.info("Predicted next state: %s", result["predicted_state"])
if result["state_probabilities"]:
logger.info("Probabilities:\n%s", json.dumps(result["state_probabilities"], indent=2))
logger.info("Model: %s", result["model_path"])
except (
ClassifierModelNotFoundError,
ClassifierTrainingError,
MarketModelNotFoundError,
ModelTrainingError,
KlineNotFoundError,
ValueError,
) as exc:
logger.error(f"Forecast failed: {exc}")
except Exception as exc:
logger.error(f"Unexpected error: {exc}")
</code_context>
<issue_to_address>
**issue (code-quality):** Extract code out into function ([`extract-method`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/extract-method/))
</issue_to_address>
### Comment 8
<location> `src/commands/market.py:18` </location>
<code_context>
def market_command(
crypto: str = typer.Option(..., "--crypto", "-c", help="Crypto symbol to analyze (e.g., BTC)"),
show_history: bool = typer.Option(
False, "--show-history", help="Print cluster distribution in JSON for deeper analysis"
),
):
logger = logging.getLogger(__name__)
try:
service = MarketStateService()
summary = service.predict_market_state(crypto)
logger.info("Market state for %s (%s)", summary["symbol"], summary["interval"])
logger.info("=" * 50)
latest = summary["latest_state"]
logger.info("Latest timestamp: %s", latest["timestamp"])
logger.info("Closing price: $%s", f"{latest['close']:,.2f}")
logger.info("Cluster %d → %s", latest["cluster"], latest["state"])
logger.info("Feature snapshot:")
for name, value in latest["features"].items():
logger.info(" %s: %.6f", name, value)
logger.info("State distribution:")
for state, count in summary["state_distribution"].items():
logger.info(" %s: %d", state, count)
if show_history:
logger.info("Full distribution JSON:\n%s", json.dumps(summary["state_distribution"], indent=2))
except (MarketModelNotFoundError, ModelTrainingError, KlineNotFoundError, ValueError) as exc:
logger.error(f"Market analysis failed: {exc}")
except Exception as exc:
logger.error(f"Unexpected error: {exc}")
</code_context>
<issue_to_address>
**issue (code-quality):** Extract code out into function ([`extract-method`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/extract-method/))
</issue_to_address>
### Comment 9
<location> `src/service/market_state_service.py:68` </location>
<code_context>
def _load_dataframe(self, crypto_name: str) -> Dict[str, Any]:
data = self.kline_service.get_kline_data(crypto_name)
columns = [
"open_time",
"open",
"high",
"low",
"close",
"volume",
"close_time",
"quote_asset_volume",
"trade_count",
"taker_buy_base",
"taker_buy_quote",
"ignore",
]
frame = pd.DataFrame(data["klines"], columns=columns)
numeric_cols = [
"open",
"high",
"low",
"close",
"volume",
"quote_asset_volume",
"taker_buy_base",
"taker_buy_quote",
]
for col in numeric_cols:
frame[col] = frame[col].astype(float)
frame["open_time"] = pd.to_datetime(frame["open_time"], unit="ms")
frame["close_time"] = pd.to_datetime(frame["close_time"], unit="ms")
frame.sort_values("open_time", inplace=True)
frame.reset_index(drop=True, inplace=True)
return {"data": data, "frame": frame}
</code_context>
<issue_to_address>
**suggestion (code-quality):** Don't use `inplace` for methods that always create a copy under the hood ([`pandas-avoid-inplace`](https://docs.sourcery.ai/Reference/Default-Rules/suggestions/pandas-avoid-inplace/))
```suggestion
frame = frame.sort_values("open_time")
```
<br/><details><summary>Explanation</summary>Some `DataFrame` methods can never operate inplace. Their operation (like reordering rows) requires copying, so they create a copy even if you provide `inplace=True`.
For these methods, `inplace` doesn't bring a performance gain.
It's only a "syntactic sugar for reassigning the new result to the calling DataFrame/Series."
[PDEP-8](https://github.com/pandas-dev/pandas/pull/51466) suggests to deprecate the `inplace` option for these methods.
Best practice: Explicitly reassign the result to the caller `DataFrame`.
E.g.
```python
df = df.sort_values("language")
```</details>
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| sample_count = len(X) | ||
| if sample_count < 2: | ||
| raise ModelTrainingError("Need at least two samples to train KMeans.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: ModelTrainingError for sample_count < 2 may not be sufficient for clustering.
Increase the minimum sample requirement to ensure KMeans produces stable, non-trivial clusters.
| sample_count = len(X) | |
| if sample_count < 2: | |
| raise ModelTrainingError("Need at least two samples to train KMeans.") | |
| sample_count = len(X) | |
| min_required_samples = max(6, min_clusters * 2) | |
| if sample_count < min_required_samples: | |
| raise ModelTrainingError( | |
| f"Need at least {min_required_samples} samples to train KMeans (got {sample_count})." | |
| ) |
| best_k = None | ||
| best_score = -1.0 | ||
| for k in range(min_clusters, max_clusters + 1): | ||
| if sample_count <= k: | ||
| break | ||
| model = KMeans(n_clusters=k, n_init=10, random_state=42) | ||
| labels = model.fit_predict(X) | ||
| if len(set(labels)) == 1: | ||
| continue | ||
| score = silhouette_score(X, labels) | ||
| if score > best_score: | ||
| best_score = score | ||
| best_k = k | ||
|
|
||
| if best_k is None: | ||
| raise ModelTrainingError("Unable to determine an appropriate number of clusters.") | ||
|
|
||
| return best_k |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Silhouette score may be unreliable for small sample sizes.
For small datasets, consider implementing a minimum sample size check or alternative logic to avoid unreliable silhouette scores.
| best_k = None | |
| best_score = -1.0 | |
| for k in range(min_clusters, max_clusters + 1): | |
| if sample_count <= k: | |
| break | |
| model = KMeans(n_clusters=k, n_init=10, random_state=42) | |
| labels = model.fit_predict(X) | |
| if len(set(labels)) == 1: | |
| continue | |
| score = silhouette_score(X, labels) | |
| if score > best_score: | |
| best_score = score | |
| best_k = k | |
| if best_k is None: | |
| raise ModelTrainingError("Unable to determine an appropriate number of clusters.") | |
| return best_k | |
| MIN_SILHOUETTE_SAMPLES = 10 | |
| best_k = None | |
| best_score = -1.0 | |
| if sample_count < MIN_SILHOUETTE_SAMPLES: | |
| # For small datasets, avoid silhouette_score and use min_clusters or raise an error | |
| if sample_count <= min_clusters: | |
| raise ModelTrainingError("Not enough samples for the minimum number of clusters.") | |
| return min_clusters | |
| for k in range(min_clusters, max_clusters + 1): | |
| if sample_count <= k: | |
| break | |
| model = KMeans(n_clusters=k, n_init=10, random_state=42) | |
| labels = model.fit_predict(X) | |
| if len(set(labels)) == 1: | |
| continue | |
| score = silhouette_score(X, labels) | |
| if score > best_score: | |
| best_score = score | |
| best_k = k | |
| if best_k is None: | |
| raise ModelTrainingError("Unable to determine an appropriate number of clusters.") | |
| return best_k |
| scaler = StandardScaler() | ||
| X_scaled = scaler.fit_transform(X) | ||
|
|
||
| stratify = y if len(np.unique(y)) > 1 else None | ||
| X_train, X_test, y_train, y_test = train_test_split( | ||
| X_scaled, y, test_size=test_size, random_state=random_state, stratify=stratify | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Stratification may fail if class distribution is highly imbalanced.
Handle potential exceptions from train_test_split when stratification fails, or add a warning for cases where stratification cannot be performed.
| scaler = StandardScaler() | |
| X_scaled = scaler.fit_transform(X) | |
| stratify = y if len(np.unique(y)) > 1 else None | |
| X_train, X_test, y_train, y_test = train_test_split( | |
| X_scaled, y, test_size=test_size, random_state=random_state, stratify=stratify | |
| ) | |
| import warnings | |
| scaler = StandardScaler() | |
| X_scaled = scaler.fit_transform(X) | |
| stratify = y if len(np.unique(y)) > 1 else None | |
| try: | |
| X_train, X_test, y_train, y_test = train_test_split( | |
| X_scaled, y, test_size=test_size, random_state=random_state, stratify=stratify | |
| ) | |
| except ValueError as e: | |
| warnings.warn( | |
| f"Stratified train_test_split failed due to: {e}. Proceeding without stratification.", | |
| UserWarning | |
| ) | |
| X_train, X_test, y_train, y_test = train_test_split( | |
| X_scaled, y, test_size=test_size, random_state=random_state, stratify=None | |
| ) |
| if frame["state"].isna().all(): | ||
| raise ModelTrainingError("No market-state labels available. Train the clustering model first.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion (bug_risk): All NaN states may indicate a deeper issue with clustering or feature engineering.
Log diagnostic details to identify whether the issue originates from feature computation or clustering.
| if frame["state"].isna().all(): | |
| raise ModelTrainingError("No market-state labels available. Train the clustering model first.") | |
| if frame["state"].isna().all(): | |
| # Diagnostic logging for feature engineering and clustering issues | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| num_rows = len(frame) | |
| num_state_nans = frame["state"].isna().sum() | |
| feature_nan_counts = frame[feature_cols].isna().sum() | |
| feature_nan_summary = feature_nan_counts.to_dict() | |
| logger.error( | |
| "All market-state labels are NaN. Diagnostics: " | |
| f"num_rows={num_rows}, num_state_nans={num_state_nans}, " | |
| f"feature_nan_summary={feature_nan_summary}" | |
| ) | |
| logger.error( | |
| "Feature columns summary statistics:\n%s", | |
| frame[feature_cols].describe(include='all').to_string() | |
| ) | |
| raise ModelTrainingError("No market-state labels available. Train the clustering model first.") |
| - Time range analysis | ||
| - Clean CLI interface with detailed logging | ||
| - Cross-platform support (Windows & Linux) | ||
| - Trainable K-Means clustering to classify market states (Bullish/Bearish/Sideway) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (typo): Replace 'Sideway' with 'Sideways' for correct terminology.
Update all instances of 'Sideway' to 'Sideways' for accuracy and consistency.
| - Trainable K-Means clustering to classify market states (Bullish/Bearish/Sideway) | |
| - Trainable K-Means clustering to classify market states (Bullish/Bearish/Sideways) |
| ./scripts/linux/run forecast -c <CRYPTO> | ||
| ``` | ||
|
|
||
| > Recommended workflow: fetch data with `dataset`, train the clustering model once with `train`, optionally train the classifier with `train-classifier`, then re-run `market` (current state) and `forecast` (next state) whenever new klines are pulled. Xem thêm hướng dẫn chi tiết tại `docs/market_state_guide.md`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue: Translate the Vietnamese sentence to English for consistency.
Please translate the Vietnamese sentence to English to match the rest of the documentation and improve accessibility.
| service = MarketClassifierService() | ||
| result = service.forecast_next_state(crypto) | ||
|
|
||
| logger.info("Forecast for %s (%s)", result["symbol"], result["interval"]) | ||
| logger.info("=" * 50) | ||
| logger.info("Prediction timestamp: %s", result["prediction_timestamp"]) | ||
| logger.info("Predicted next state: %s", result["predicted_state"]) | ||
| if result["state_probabilities"]: | ||
| logger.info("Probabilities:\n%s", json.dumps(result["state_probabilities"], indent=2)) | ||
| logger.info("Model: %s", result["model_path"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (code-quality): Extract code out into function (extract-method)
Summary by Sourcery
Extend the CLI tool with trainable market-state analysis and forecasting by adding K-Means clustering, state classification, RandomForest-based next-state prediction, corresponding commands, and documentation
New Features:
traincommand to fit and persist K-Means clustering models on historical kline datamarketcommand to label the latest data point as Bullish, Bearish, or Sideway using a trained clustering modeltrain-classifiercommand to train a supervised RandomForest classifier for next-state predictionforecastcommand to predict and output the next market state with probabilitiesEnhancements:
Documentation: