SMBC Group GREEN×DATA Challenge 2024 チュートリアル

- このチュートリアルは初めてコンペティションに取り組む人向けです。
SMBCグループ GREEN×DATA チャレンジ2024
SMBCグループ GREEN×DATA チャレンジ2024にご参加いただき、ありがとうございます。
このチュートリアルでは、予測アルゴリズムを作成し、投稿ファイルを作る方法をGoogle Colaboratory(以下、Colab)で説明します。
Colabは、ブラウザでPythonプログラミングができる無料サービスで、利用にはGoogleアカウントが必要です。
Colabでは、準備された分析環境をそのまま使えるので、初心者でも環境構築をスキップできます。
チュートリアルの後は、このコードを基に精度を上げてみてください。
また、このチュートリアルで使うNotebookは「データ」ページからダウンロードできます。
皆様のたくさんの投稿をお待ちしております!
目次
- イントロダクション
- 前処理
- 予測モデルの学習・検証
- 予測結果の投稿
- 精度改善のヒント
1.イントロダクション¶
1-1.Colabの使用方法¶
以下では、Colabの起動からGoogle Drive内のデータ読み込むための準備の手順を説明します。
Colabの起動
- Googleアカウントにログイン
- ブラウザで Colab にアクセス
- 「新しいノートブック」をクリックして、新しいColab環境を立ち上げる
Google Driveのディレクトリ構成
今回はマイドライブの中に、以下のような構造でデータを置くことを前提とします。
マイドライブ/
├── data/
│ ├── train.csv
│ ├── test.csv
│ ├── sample_submission.csv
│ └── tutorial.ipynb
Google Driveへの接続方法
- Colabノートブックを開いた状態で、以下のコードを実行して自分のGoogle Driveへ接続
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
- 「Google Driveに接続」をクリックし、Googleアカウントを選択
- 接続が成功すると、「Mounted at /content/drive」というメッセージが表示される
これで、自分のGoogle Drive に接続できました!
以下のコードでマイドライブの中にあるフォルダを確認することもできます。
ls /content/drive/MyDrive/
1-2.ライブラリの読み込み¶
ライブラリを読み込むことで、pythonがある特定の分野に特化した処理を行うことができるようになります。
データ分析を行うための基本的なライブラリを読み込みましょう。
今使用するライブラリの全てを読み込まなくても、必要に応じて都度ライブラリを読み込みながら進めることもできます。
# 基本的なライブラリは、Colabではデフォルトで利用可能
import pandas as pd # データを表のように扱うライブラリ
import numpy as np # 数値計算を速くするライブラリ
import seaborn as sns # きれいなグラフを簡単に作るライブラリ
import matplotlib.pyplot as plt # グラフを作る基本的なライブラリ
%matplotlib inline
from sklearn.model_selection import train_test_split # データを訓練用と検証用に分ける
from sklearn.metrics import mean_squared_log_error # 評価の計算を行うライブラリ
import lightgbm as lgb # 予測モデルに関するライブラリ
import warnings
warnings.simplefilter('ignore') # 不要な警告を表示しない
/usr/local/lib/python3.10/dist-packages/dask/dataframe/__init__.py:42: FutureWarning: Dask dataframe query planning is disabled because dask-expr is not installed. You can install it with `pip install dask[dataframe]` or `conda install dask`. This will raise in a future version. warnings.warn(msg, FutureWarning)
1-3.データの読み込み
それではデータを読み込んでみましょう。
データを読み込むには pd.read_csv() を使います。
マイドライブは/content/drive/の中にあるので、trainデータまでのパスは以下のようになります。
# 予測モデルを訓練するためのデータセット
train = pd.read_csv('/content/drive/MyDrive/data/train.csv', index_col=0)
# 予測モデルに推論(予測)させるデータセット
test = pd.read_csv('/content/drive/MyDrive/data/test.csv', index_col=0)
これでデータを読み込むことができました!次章ではどのようなデータがそれぞれに格納されているのかを確認します。
2.前処理¶
AIモデルが正確に予測できるということは、入力データの特徴量と予測したい結果(目的変数)の関係をよく理解している状態です。しかし、AIモデルがどんなに優秀でも、すべての入力データで自動的にその関係を見つけられるわけではありません。
つまり、AIモデルがその関係を理解しやすくするためには、モデルを訓練する前に入力データを適切に加工(前処理)する必要があります。
適切な加工を実施するために、以下の2ステップを説明いたします。
- データのことを調べて、どのような処理が必要か考えてみる
- 考えた処理を実際に実装してみる
これらのステップで「入力データから目的変数の情報をどれだけ効果的に引き出せるか」は予測性能の高いAIモデルを作るために最も重要な処理の一つです。
そこで、まず始めに訓練用データの中身を確かめてみましょう。
def describe_data(data, name):
print(f"\n{name} のデータ概要:\n")
display(data.info())
display(data.describe())
describe_data(train, 'train')
describe_data(test, 'test')
train のデータ概要:Index: 4655 entries, 0 to 4654 Data columns (total 21 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 FacilityName 4655 non-null object 1 Latitude 4553 non-null float64 2 Longitude 4553 non-null float64 3 LocationAddress 4476 non-null object 4 City 4655 non-null object 5 State 4655 non-null object 6 ZIP 4655 non-null object 7 County 4585 non-null object 8 FIPScode 4582 non-null float64 9 PrimaryNAICS 4655 non-null int64 10 SecondPrimaryNAICS 379 non-null float64 11 IndustryType 4654 non-null object 12 TRI_Air_Emissions_10_in_lbs 1635 non-null float64 13 TRI_Air_Emissions_11_in_lbs 1635 non-null float64 14 TRI_Air_Emissions_12_in_lbs 1635 non-null float64 15 TRI_Air_Emissions_13_in_lbs 1635 non-null float64 16 GHG_Direct_Emissions_10_in_metric_tons 3953 non-null float64 17 GHG_Direct_Emissions_11_in_metric_tons 4284 non-null float64 18 GHG_Direct_Emissions_12_in_metric_tons 4395 non-null float64 19 GHG_Direct_Emissions_13_in_metric_tons 4507 non-null float64 20 GHG_Direct_Emissions_14_in_metric_tons 4655 non-null float64 dtypes: float64(13), int64(1), object(7) memory usage: 800.1+ KB
None
Latitude | Longitude | FIPScode | PrimaryNAICS | SecondPrimaryNAICS | TRI_Air_Emissions_10_in_lbs | TRI_Air_Emissions_11_in_lbs | TRI_Air_Emissions_12_in_lbs | TRI_Air_Emissions_13_in_lbs | GHG_Direct_Emissions_10_in_metric_tons | GHG_Direct_Emissions_11_in_metric_tons | GHG_Direct_Emissions_12_in_metric_tons | GHG_Direct_Emissions_13_in_metric_tons | GHG_Direct_Emissions_14_in_metric_tons | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 4553.000000 | 4553.000000 | 4582.000000 | 4655.000000 | 379.000000 | 1.635000e+03 | 1635.000000 | 1635.000000 | 1635.000000 | 3.953000e+03 | 4.284000e+03 | 4.395000e+03 | 4.507000e+03 | 4.655000e+03 |
mean | 37.502474 | -93.132567 | 29594.314055 | 354701.002578 | 320797.759894 | 6.179064e+04 | 43853.462331 | 53770.293062 | 56007.086034 | 2.485158e+05 | 1.612068e+05 | 3.159905e+05 | 1.834006e+05 | 2.525133e+05 |
std | 5.739955 | 15.680084 | 16565.382504 | 145549.947021 | 36638.391132 | 1.344983e+05 | 55988.952050 | 93977.128341 | 109863.242856 | 5.225110e+05 | 2.641831e+05 | 7.395843e+05 | 4.026237e+05 | 4.854669e+05 |
min | 13.394900 | -166.553496 | 1001.000000 | 111419.000000 | 212113.000000 | 2.279515e+03 | 34.450512 | 2076.649083 | 4656.522747 | 1.089413e+02 | 8.167190e-01 | 2.009966e+02 | 2.689283e+01 | 5.598067e+02 |
25% | 33.471611 | -98.498664 | 17105.500000 | 221112.000000 | 322121.000000 | 2.525215e+04 | 25319.052605 | 22764.672344 | 22905.881327 | 5.123975e+04 | 3.733935e+04 | 4.822487e+04 | 3.547781e+04 | 4.174869e+04 |
50% | 37.873810 | -90.446667 | 29034.000000 | 325180.000000 | 325199.000000 | 3.204187e+04 | 31765.719617 | 29667.091832 | 29305.093678 | 7.440347e+04 | 6.119749e+04 | 7.242684e+04 | 5.744658e+04 | 6.789793e+04 |
75% | 41.152783 | -82.604399 | 45050.000000 | 486210.000000 | 331111.000000 | 4.019123e+04 | 38329.096290 | 36984.431956 | 36755.017845 | 1.676409e+05 | 1.418608e+05 | 2.279668e+05 | 1.195544e+05 | 2.109168e+05 |
max | 70.490861 | 144.807727 | 78010.000000 | 928110.000000 | 562212.000000 | 1.251231e+06 | 478366.459152 | 743548.788013 | 989230.802586 | 3.900222e+06 | 2.698567e+06 | 6.837260e+06 | 4.330236e+06 | 4.614103e+06 |
test のデータ概要:Index: 2508 entries, 4655 to 7162 Data columns (total 20 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 FacilityName 2508 non-null object 1 Latitude 2452 non-null float64 2 Longitude 2452 non-null float64 3 LocationAddress 2395 non-null object 4 City 2508 non-null object 5 State 2508 non-null object 6 ZIP 2508 non-null object 7 County 2463 non-null object 8 FIPScode 2463 non-null float64 9 PrimaryNAICS 2508 non-null int64 10 SecondPrimaryNAICS 184 non-null float64 11 IndustryType 2508 non-null object 12 TRI_Air_Emissions_10_in_lbs 874 non-null float64 13 TRI_Air_Emissions_11_in_lbs 874 non-null float64 14 TRI_Air_Emissions_12_in_lbs 874 non-null float64 15 TRI_Air_Emissions_13_in_lbs 874 non-null float64 16 GHG_Direct_Emissions_10_in_metric_tons 2130 non-null float64 17 GHG_Direct_Emissions_11_in_metric_tons 2297 non-null float64 18 GHG_Direct_Emissions_12_in_metric_tons 2371 non-null float64 19 GHG_Direct_Emissions_13_in_metric_tons 2435 non-null float64 dtypes: float64(12), int64(1), object(7) memory usage: 411.5+ KB
None
Latitude | Longitude | FIPScode | PrimaryNAICS | SecondPrimaryNAICS | TRI_Air_Emissions_10_in_lbs | TRI_Air_Emissions_11_in_lbs | TRI_Air_Emissions_12_in_lbs | TRI_Air_Emissions_13_in_lbs | GHG_Direct_Emissions_10_in_metric_tons | GHG_Direct_Emissions_11_in_metric_tons | GHG_Direct_Emissions_12_in_metric_tons | GHG_Direct_Emissions_13_in_metric_tons | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 2452.000000 | 2452.000000 | 2463.000000 | 2508.000000 | 184.000000 | 8.740000e+02 | 874.000000 | 874.000000 | 8.740000e+02 | 2.130000e+03 | 2.297000e+03 | 2.371000e+03 | 2.435000e+03 |
mean | 37.570662 | -92.363851 | 30404.922452 | 354306.630383 | 318199.500000 | 6.487513e+04 | 40970.199510 | 49969.490551 | 5.099979e+04 | 2.567106e+05 | 1.601432e+05 | 3.124102e+05 | 1.892745e+05 |
std | 5.758665 | 15.263442 | 16241.454602 | 147092.064447 | 48626.377425 | 1.532609e+05 | 50742.809398 | 96607.925446 | 1.092213e+05 | 5.239007e+05 | 2.753633e+05 | 7.448045e+05 | 4.284521e+05 |
min | 13.463579 | -158.125764 | 1001.000000 | 211111.000000 | 115114.000000 | 4.405829e+03 | 5343.548842 | 3144.557141 | 1.221438e+03 | 1.342263e+02 | 5.203845e+01 | 3.003952e+03 | 2.226384e+02 |
25% | 33.139870 | -97.758778 | 17200.000000 | 221112.000000 | 322130.000000 | 2.465727e+04 | 25406.683554 | 23070.544500 | 2.206082e+04 | 5.025370e+04 | 3.543934e+04 | 4.570376e+04 | 3.507816e+04 |
50% | 38.047855 | -90.131700 | 31067.000000 | 325155.000000 | 325199.000000 | 3.137226e+04 | 30861.833946 | 29153.614011 | 2.811996e+04 | 7.313435e+04 | 5.862897e+04 | 6.998558e+04 | 5.659606e+04 |
75% | 41.346859 | -81.908218 | 47032.000000 | 486210.000000 | 331111.000000 | 3.898384e+04 | 37976.075558 | 35656.637976 | 3.498953e+04 | 1.850996e+05 | 1.138890e+05 | 2.150169e+05 | 1.113874e+05 |
max | 71.292071 | 144.678216 | 78030.000000 | 928110.000000 | 562910.000000 | 1.374503e+06 | 481505.852810 | 898432.732339 | 1.010902e+06 | 3.927869e+06 | 2.120411e+06 | 6.112087e+06 | 4.292472e+06 |
データの中身を確認しました。
今回目的変数は施設ごとの2014年のGHG排出量(GHG_Direct_Emissions_14_in_metric_tonsReviewer_Score)です。
特徴量は、施設の場所や業種といったカテゴリデータと観測値の数字データがあります。
また、Non-Null Countを見るといくつか欠損値が含まれていることもわかります。
前処理の方法は以下が挙げられます。
- カテゴリ型(str)
- 場所を表すデータは場所別に特徴がないかを確認しグルーピング
- IndustryTypeで特徴に差があるかを確認してエンコーディング
- 数値型(int, float)
- 欠損値の処理
- 分布を確認して標準化など
それでは実際に確認していきましょう。
2-1.カテゴリ型(str)の処理¶
場所を表す特徴量¶
緯度(Latitude)経度(Longitude)の情報から地図上にマッピングしてみましょう。
それだけでなく、目的変数である2014年のGHG排出量の大小で色分けもしてみます。
難しいかつ、使用頻度が高くないコードですので、覚えずに生成AIを活用してください!
import folium
import matplotlib.cm as cm
import matplotlib.colors as colors
from IPython.display import display
# 必要なデータを抽出し、新しいデータフレームを作成
data = train[['Latitude', 'Longitude', 'GHG_Direct_Emissions_14_in_metric_tons']].copy()
# 緯度、経度、排出量のいずれかが欠損している行を削除
data.dropna(subset=['Latitude', 'Longitude', 'GHG_Direct_Emissions_14_in_metric_tons'], inplace=True)
# 各列のデータ型を float に変換
for col in ['Latitude', 'Longitude', 'GHG_Direct_Emissions_14_in_metric_tons']:
data[col] = data[col].astype(float)
# 地図の中心をデータの緯度と経度の平均位置に設定
map_center = [data['Latitude'].mean(), data['Longitude'].mean()]
m = folium.Map(location=map_center, zoom_start=5)
# 排出量の最大値と最小値を取得
max_emission = data['GHG_Direct_Emissions_14_in_metric_tons'].max()
min_emission = data['GHG_Direct_Emissions_14_in_metric_tons'].min()
# カラーマップを設定(排出量が少ない地点は黄色、多い地点は赤色で表示)
colormap = cm.get_cmap('YlOrRd')
normalize = colors.Normalize(vmin=min_emission, vmax=max_emission)
# 各地点に対して、排出量に応じた色の円マーカーを作成し地図に追加
for idx, row in data.iterrows():
# 排出量に基づいて色を設定
color = colors.rgb2hex(colormap(normalize(row['GHG_Direct_Emissions_14_in_metric_tons'])))
# 円マーカーを作成
folium.CircleMarker(
location=[row['Latitude'], row['Longitude']], # 緯度・経度
radius=5, # 円のサイズ
popup=f"Emissions: {row['GHG_Direct_Emissions_14_in_metric_tons']}", # ポップアップに排出量を表示
color=color, # 枠の色
fill=True, # 円を塗りつぶす
fill_color=color # 塗りつぶしの色
).add_to(m)
# 地図を表示(Jupyter Notebook 上で表示可能)
display(m)
ブラウザ上では表示できませんが、マウス操作で動かせるマップに、データ内で相対的に数値が高いところは赤くそうでないところは黄色く表示されます。
ぜひご自身のノートブックで表示して考察してみてください。
業種を表す特徴量¶
業種とGHG排出量はなんらかの関係があると仮説を立てることができます。業種を表すIndustryTypeを確認してみましょう。
train['IndustryType'].value_counts()
count | |
---|---|
IndustryType | |
Petroleum and Natural Gas Systems | 960 |
Power Plants | 943 |
Waste | 859 |
Other | 723 |
Minerals | 246 |
Chemicals | 209 |
Metals | 199 |
Other,Waste | 86 |
Pulp and Paper | 81 |
Natural Gas and Natural Gas Liquids Suppliers,Petroleum and Natural Gas Systems | 65 |
Pulp and Paper,Waste | 57 |
Petroleum Product Suppliers,Refineries | 37 |
Chemicals,Petroleum Product Suppliers,Refineries | 28 |
Chemicals,Suppliers of CO2 | 24 |
Other,Suppliers of CO2 | 19 |
Metals,Waste | 11 |
Refineries | 10 |
Other,Suppliers of CO2,Waste | 10 |
Chemicals,Industrial Gas Suppliers | 7 |
Injection of CO2,Petroleum and Natural Gas Systems,Suppliers of CO2 | 6 |
Pulp and Paper,Suppliers of CO2,Waste | 6 |
Minerals,Waste | 6 |
Petroleum and Natural Gas Systems,Suppliers of CO2 | 5 |
Chemicals,Refineries,Suppliers of CO2 | 5 |
Chemicals,Waste | 5 |
Other,Power Plants | 5 |
Chemicals,Petroleum Product Suppliers,Refineries,Suppliers of CO2 | 5 |
Injection of CO2,Petroleum and Natural Gas Systems | 3 |
Power Plants,Waste | 3 |
Chemicals,Refineries | 3 |
Petroleum Product Suppliers,Petroleum and Natural Gas Systems | 3 |
Natural Gas and Natural Gas Liquids Suppliers,Power Plants | 3 |
Import and Export of Equipment Containing Fluorintaed GHGs,Other | 2 |
Chemicals,Petroleum Product Suppliers | 1 |
Import and Export of Equipment Containing Fluorintaed GHGs,Industrial Gas Suppliers,Other | 1 |
Petroleum Product Suppliers,Power Plants,Refineries | 1 |
Injection of CO2,Power Plants,Suppliers of CO2 | 1 |
Chemicals,Petroleum Product Suppliers,Refineries,Waste | 1 |
Power Plants,Chemicals,Coal-based Liquid Fuel Supply,Suppliers of CO2 | 1 |
Natural Gas and Natural Gas Liquids Suppliers,Other,Petroleum and Natural Gas Systems | 1 |
Injection of CO2,Other,Suppliers of CO2,Waste | 1 |
Chemicals,Other | 1 |
Chemicals,Industrial Gas Suppliers,Minerals | 1 |
Chemicals,Industrial Gas Suppliers,Waste | 1 |
Chemicals,Petroleum Product Suppliers,Power Plants,Refineries | 1 |
Natural Gas and Natural Gas Liquids Suppliers,Petroleum and Natural Gas Systems,Power Plants | 1 |
Chemicals,Power Plants | 1 |
Power Plants,Suppliers of CO2 | 1 |
Metals,Power Plants | 1 |
Chemicals,Other,Petroleum and Natural Gas Systems,Waste | 1 |
Pulp and Paper,Suppliers of CO2 | 1 |
Petroleum and Natural Gas Systems,Power Plants | 1 |
Chemicals,Other,Petroleum Product Suppliers,Power Plants,Refineries | 1 |
思いのほか、多くの要素がありました。
今回は実施しませんが、これらのカテゴリを数値化する手法としてエンコーディングというものがあります。
参考
エンコーディングとは何か
エンコーディングとは、AIモデルが理解できるように、文字やカテゴリなどの非数値データを数値データに変換することです。これにより、AIモデルはデータを適切に処理し、学習することができます。
以下は代表的なエンコーディング手法である「ラベルエンコーディング」と「ワンホットエンコーディング」を整理した表です。
項目 | ラベルエンコーディング | ワンホットエンコーディング |
---|---|---|
手法 | カテゴリを整数値に変換する。 例:「赤」=0、「青」=1、「緑」=2 |
各カテゴリごとに新しい二進数の特徴量を作成する。 例:「赤」=[1,0,0]、「青」=[0,1,0]、「緑」=[0,0,1] |
メリット | - シンプルで実装が容易 |
- カテゴリ間の関係性を持たない - モデルがカテゴリを平等に扱う |
デメリット | - カテゴリに順位がないのに数値が割り当てられるため、モデルが誤解する可能性がある | - データの次元数が増加し、高次元になると計算量が増える |
選択基準 | - カテゴリに明確な順位や大小関係がある場合 | - カテゴリに順位がない場合 - カテゴリ数が多すぎない場合 |
2-2.数値型(int, float)の処理¶
続いて数値型の前処理です。
まずは可視化して分布を確認します。
numerical_features = [
'TRI_Air_Emissions_10_in_lbs', 'TRI_Air_Emissions_11_in_lbs', 'TRI_Air_Emissions_12_in_lbs', 'TRI_Air_Emissions_13_in_lbs',
'GHG_Direct_Emissions_10_in_metric_tons', 'GHG_Direct_Emissions_11_in_metric_tons', 'GHG_Direct_Emissions_12_in_metric_tons', 'GHG_Direct_Emissions_13_in_metric_tons',
]
num_bins = 100
min_value = train[numerical_features].min().min()
max_value = train[numerical_features].max().max()
bins = np.linspace(min_value, max_value, num_bins)
for col in numerical_features:
plt.figure(figsize=(12, 3))
sns.histplot(train[col], bins=bins, kde=True)
plt.title(f'{col}の分布')
plt.xlabel(col)
plt.ylabel('頻度')
plt.show()
print()
すべての数値データが小さな値に偏っており、各特徴量の数値の範囲に多少差があります。
このような偏った分布では、平均値が大きくなりやすいため、欠損値は中央値で補います。
# 欠損値の補完
for col in numerical_features:
train[col].fillna(train[col].median(), inplace=True)
test[col].fillna(train[col].median(), inplace=True)
以上で前処理は終了です。
今回は数値データのみを利用してAIモデルを構築します。
3. 予測モデルの学習・検証¶
本章ではモデルの学習と検証を行います。
そこでまず、次の操作を行います。
- 学習用データセットに含まれる特徴量と目的変数の分離
- 学習用データセットを学習用と検証用のデータセットへ分割
# 訓練用データセットからターゲットを分離する
X = train[numerical_features]
y = train['GHG_Direct_Emissions_14_in_metric_tons']
# 投稿のためのテストデータも同様の処理を行う
test_X = test[numerical_features]
# 訓練用データセットを訓練用と検証用に分割する
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=42)
# 結果の確認(データフレームの形状)
print(f"X_train: {X_train.shape}, X_valid: {X_valid.shape}")
print(f"y_train: {y_train.shape}, y_valid: {y_valid.shape}")
X_train: (3724, 8), X_valid: (931, 8) y_train: (3724,), y_valid: (931,)
3-1. 今回使用する予測モデルの紹介¶
次に利用する予測モデルを用意します。表データ用の予測モデルを構築する手法は数多く提案されていますが、本チュートリアルでは、その中でも最も有名で強力な手法の一つであるLightGBMを利用します。
かいつまんだ説明にはなりますが、この手法では以下の工夫により、高速に予測能力の高いモデルの構築を実現しています。
- 内部に持つ複数の予測器のアンサンブルによる精度の高い予測(キーワード:勾配ブースティング法, アンサンブル学習)
- 情報の少ない列をまとめて扱うことによる高速な訓練(キーワード:Exclusive Feature Bundling)
次節ではLightGBMを使った学習を行います。
3-2. 学習¶
# LightGBM用のデータセットに変換
train_data = lgb.Dataset(X_train, label=y_train)
valid_data = lgb.Dataset(X_valid, label=y_valid, reference=train_data)
# LightGBMのパラメータ設定
params = {
'objective': 'regression', # 回帰タスク
'metric': 'rmse', # RMSEで評価
'boosting_type': 'gbdt', # 勾配ブースティング木
'learning_rate': 0.1, # 学習率
'verbose': -1, # 詳細な出力を抑制
'random_state':42 # 乱数の固定
}
# モデルの訓練
model = lgb.train(
params,
train_data,
valid_sets=[valid_data],
)
これでモデルの学習が完了しました。
あらかじめ分けておいた検証用データで訓練済みモデルの予測性能を確認しましょう。
本コンペティションで使用する評価指標はRMSLEですのでscikit-learnからmean_squared_log_errorを呼び出して評価します。
# 検証用データセットに対する予測
y_pred = model.predict(X_valid)
# 負の数値を0に変換
y_pred = [0 if val <= 0 else val for val in y_pred]
# RMSLEで評価
from sklearn.metrics import mean_squared_log_error
rmsle_score = mean_squared_log_error(y_valid, y_pred, squared=False)
print(f"RMSLE: {rmsle_score:.4f}")
RMSLE: 1.0106
評価結果(モデルの精度)が出力されました。
初めて参加される方は、是非この値を超えるようなモデルの作成を目指してください!
4. 予測・結果の投稿¶
最後に学習したモデルをテストデータに対して予測をし、投稿ファイルを作成します。
test_pred = model.predict(test_X)
test_pred
array([ 58161.9139442 , 398883.2977091 , 47403.62892573, ..., 210242.59604802, 286929.71415799, 838362.78518116])
テストデータの予測ができました。
見本のsample_submission.csvを用いて投稿ファイルを作成しましょう。
# 投稿ファイル作成
submit = pd.read_csv('/content/drive/MyDrive/data/sample_submission.csv', header=None)
submit[1] = test_pred
submit.to_csv('/content/drive/MyDrive/data/submission_tutorial.csv', header=None, index=False)
# 投稿ファイルの中身を確認
submit.head()
0 | 1 | |
---|---|---|
0 | 4655 | 58161.913944 |
1 | 4656 | 398883.297709 |
2 | 4657 | 47403.628926 |
3 | 4658 | 13513.476993 |
4 | 4659 | 79085.397812 |
これで本チュートリアルは終了です。
早速作成したsubmission_tutorial.csvを投稿してみましょう。
コンペサイトの投稿ボタンから投稿可能です。
5. 精度改善のヒント¶
投稿が完了しましたが、コンペティションはこれで終わりではありません。次に精度を改善する必要があります。
ぜひチャレンジしてみてください。
今後の改善点の例:
- 他の特徴量から、新たな特徴量を作成する
- 欠損値を別の方法で埋める
- LightGBM以外のアルゴリズムや、アンサンブル(複数のモデルの予測結果の平均などを取る処理)を行う
