{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 6 - Machine Learning with Quantum Annealing (QBoost)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/OpenJij/OpenJijTutorial/blob/master/source/ja/006-Machine_Learning_by_QA.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial takes up machine learning as an example application of optimization by quantum annealing.  \n", "In the first half, we perform clustering with PyQUBO and OpenJij.  \n", "In the second half, we perform QBoost, an ensemble learning method, with PyQUBO and the D-Wave sampler." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Clustering\n", "\n", "Clustering is the task of dividing given data into $n$ clusters, where $n$ is assumed to be given in advance. For simplicity, we consider the case of two clusters here.\n", "\n", "### Importing the required libraries\n", "\n", "We import scikit-learn, a machine learning library, along with the other packages we need." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import libraries\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from sklearn import cluster\n", "import pandas as pd\n", "from scipy.spatial import distance_matrix\n", "# Note: solve_qubo is provided by older PyQUBO releases and may be missing from newer ones\n", "from pyqubo import Array, Constraint, Placeholder, solve_qubo\n", "import openjij as oj\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating artificial data\n", "\n", "Here we generate artificial data on the two-dimensional plane that are clearly linearly separable." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data = []\n", "label = []\n", "for i in range(100):\n", "    # Draw a uniform random number in [0, 1]\n", "    p = np.random.uniform(0, 1)\n", "    # Assign class 1 if p > 0.5, and class -1 otherwise\n", "    cls = 1 if p > 0.5 else -1\n", "    # Draw a point from a normal distribution centered at (cls, cls)\n", "    data.append(np.random.normal(0, 0.5, 2) + np.array([cls, cls]))\n", "    label.append(cls)\n", "# Arrange the data as a DataFrame\n", "df1 = pd.DataFrame(data, columns=[\"x\", \"y\"], index=range(len(data)))\n", "df1[\"label\"] = label" ] }, { "cell_type": "code", "execution_count": 3, 
"metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# データセットの可視化\n", "df1.plot(kind='scatter', x=\"x\", y=\"y\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "今回は、以下のハミルトニアンを最小化することでクラスタリングを行います。\n", "\n", "$$\n", "H = - \\sum_{i, j} \\frac{1}{2}d_{i,j} (1 - \\sigma _i \\sigma_j)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$i, j$はサンプルの番号、$d_{i,j}$は2つのサンプル間の距離、$\\sigma_i=\\{-1,1\\}$は2つのクラスターのどちらかに属しているかを表すスピン変数です。 \n", "このハミルトニアンの和の各項は \n", "\n", "- $\\sigma_i = \\sigma_j $のとき、0\n", "- $\\sigma_i \\neq \\sigma_j $のとき、$d_{i,j}$ \n", "\n", "となります。右辺のマイナスに注意すると、ハミルトニアン全体では「異なるクラスに属しているサンプル同士の距離を最大にする$\\{\\sigma _1, \\sigma _2 \\ldots \\}$の組を選びなさい」という問題に帰着することがわかります。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PyQUBOによるクラスタリング\n", "\n", "まずは、PyQUBOで上のハミルトニアンを定式化します。そして`solve_qubo`を用いてシミュレーテッドアニーリング(SA)を行います。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def clustering_pyqubo(df):\n", " # 距離行列の作成\n", " d_ij = distance_matrix(df, df)\n", " # スピン変数の設定\n", " spin = Array.create(\"spin\", shape= len(df), vartype=\"SPIN\")\n", " # 全ハミルトニアンの設定\n", " H = - 0.5* sum(\n", " [d_ij[i,j]* (1 - spin[i]* spin[j]) for i in range(len(df)) for j in range(len(df))]\n", " )\n", " # コンパイル\n", " model = H.compile()\n", " # QUBOに変換\n", " qubo, offset = model.to_qubo()\n", " # SAで解を求める\n", " raw_solution = solve_qubo(qubo, num_reads=10)\n", " # 解を見やすい形にデコード\n", " decoded_solution, broken, energy= model.decode_solution(raw_solution, vartype=\"SPIN\")\n", " # ラベルを抽出\n", " labels = [decoded_solution[\"spin\"][idx] for idx in range(len(df))]\n", " return labels, energy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "実行をおよび解の確認を行います。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label [0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 
1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]\n", "energy -13861.844537404875\n" ] } ], "source": [ "labels, energy =clustering_pyqubo(df1[[\"x\", \"y\"]])\n", "print(\"label\", labels)\n", "print(\"energy\", energy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可視化をしてみましょう。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWW0lEQVR4nO3df4hdZ53H8c83kwQSUhAnA4W2uSOsLFuKIBlkF/9YcPtHLbJFQVBuy2CFYFihggsq+Tt/CUJBQQL+yHYGRVCp7FZqC7uUlVWclCLpxkqRJgbLOsn+0YbI5td3/zi5zcydc+49557nnPM857xfcEnn5t57njltv/PM9/k+38fcXQCAdO3regAAgHoI5ACQOAI5ACSOQA4AiSOQA0Di9ndx0aNHj/rq6moXlwaAZJ07d+6Ku69MP99JIF9dXdXW1lYXlwaAZJnZxbznSa0AQOII5ACQOAI5ACSOQA4AiSOQA0DiCOQA4re5Ka2uSvv2ZX9ubnY9oqh0Un4IAKVtbkonTkjXr2dfX7yYfS1J43F344oIM3IAcTt16l4Qn7h+PXsekgjkAGJ36VK15weIQA4gbseOVXt+gAjkAOJ2+rR0+PDu5w4fzp6HJAI5gNiNx9KZM9JoJJllf545w0LnDgRyoI/6Vq43HktvvSXduZP9SRDfhfJDoG8o1xscZuRA31CuNzgEcqBvKNcbHAI50DeU6w0OgRzoG8r1BodADvQN5XqDU7tqxcwekvQvku6XdEfSGXd/tu7nAqhhPCZwD0iI8sNbkr7s7q+a2X2SzpnZS+7+3wE+GwAwR+3Uiru/7e6v3v3ndyVdkPRA3c8FAJQTNEduZquSPizp1yE/FwBQLFggN7Mjkn4s6Uvu/k7O358wsy0z29re3g51WQAYvCCB3MwOKAvim+7+k7zXuPsZd19z97WVlZUQlwUAKEAgNzOT9B1JF9z9G/WHBACoIsSM/KOSnpL0MTN77e7j8QCfC2Co+ta9sWG1yw/d/T8lWYCxAADdGxfAzk4AcaF7Y2UEcgBxoXtjZQRyYKhizUPTvbEyAjkwRJM89MWLkvu9PHQMwZzujZURyIEhijkPTffGyszdW7/o2tqab21ttX5dAHft25fNxKeZZQccI0pmds7d16afZ0YODFFRvvn97+8mbx5rvj4RBHJgiPLy0AcPSu+
8sztv/uST0tGjzQbWmPP1iSCQA31Sdmabl4e+7z7p5s29r716tdnAGnO+PhHkyIG+mN4RKWWz7rILhUV584nRSHrrrdrDLH1d8vV7kCMH+q7uzHZenXZTG3KoG6+NQA70Rd0dkXl5852aCqzUjddGIAf6ou7MdpI3X17e+3dNBtY6deNUu0gikAP9EWJmOx5LV65IGxvtbsgZj7P8+5072Z9lgzjVLpJY7AT6ZXMzy4lfupTNxE+f7u+OyNXVLHhPa2pRNgJFi50EcgBpGmC1C1UrAPqFapf3EMgBpIlql/cQyAGkiS6J76l9ZicAdGY8HmTgnsaMHAASRyAHgMQRyAEgcQRyAEgcgRxA3OinMhdVKwDiNd1jfdJPRaJaZQdm5ADixelBpRDIAcSrbo/1gSCQA4gX/VRKIZADiFdePxUz6fHHZ79vYAukBHKgb8oEsVQC3Xgsra9nwXvCXTp7tnjMQzxwwt1bfxw/ftwBNGBjw/3wYfcshGWPw4ez56u8JuR4RiN3s+zPRa4xGu0e6+QxGoV5fUIkbXlOTOVgCaBPypya09bJOtOlg1KWJqnaobDoAAkpG/P0aUg9PnCCgyWAIShT5dFWJUio0sGihU2z/PTJABdICeRAn5QJYm0FulA/MIoWPKdn3ZMfEgM8cIJADvTJrCA2WeC8eHH34uHO14RU9wfGZLxPPSUdOiQtL987QKIo1XLp0iAPnAgSyM3su2b2ZzM7H+LzACyoKIhJ9yo5pCwQToJ5U4Guzsx4uvLk6lXpL3+Rnnsuy+OPRvnvm/yQGI+z1925k/3Z4yAuhZuRf1/SY4E+C0AdeUEsL1/tfm+Bs0yg21myePRo9phVvjj5obK8fO+5Q4fKfQ/z8usppU/aKPXMK2VZ5CFpVdL5Mq+l/BBomVl+SZ5ZuffnlSyWKV9ctNSxzHhDlDY2LXCpp5ouPzSzVUn/6u6PFPz9CUknJOnYsWPHL+aVPwFoRt2Sw6L3z/usRa/bVolk0wJ/H52XH7r7GXdfc/e1lZWVti4LQMpPRRw8KF27Vu5X/jKVJnmvWbRyJaXUySwtlXpStQIMwfQi6PLyvUXEMtvYy1Sa5L1m0cqVvPEeOpRVsMTcUmBaS6WeBHJgKHYugh45It28ufvvZ23WyZsh77Rztrxzce/atWzmX/TaMuN97rmsYqXsD52YtPWbRV7ivOpD0g8kvS3ppqTLkj4/6/UsdgId2tgoXrSctfi5c3FxeTl7TC805i3uHTiQ/9qyUu+dEnBRVvRaAZDb/2SnuouJTSxS9rh3SlWdL3YCiEBeffZE0a/8Veqgm1jcG2DvlKoI5MCQzAqoebs7q/b2biLo9qWCpUEEcmBIigLqaJS/u7NqB8Mmgm5fKlgaRCAHhqRqoK2aKmmqYVUfKlgaRCAH+iovt1010C6SKmmyYVWoHudlpHIcnqT9XQ8AQAOmq1MmM1cpC6xlg+vp0/mn/HSVn27rUIx59y8yzMiBPgo1c+2yt3fejLjsbwh1Z9NtzvwDoI4c6KPUa6+LzvtcX5fOnp19DmiIs0IjvX/UkQNDMm/mGnv+t2hG/MIL839DKHrv+nr57ze12vW87Z5NP9iiDzRsVh/skD2yT550X1rKPmNpKfs6hDr904veW+X7DdxHPBQVbNEnkAN9VdTjI1TvkpMn8z8nRDCvM8ai91b9rAgPrigK5OTIgb7b3MzSDZcuZamBogMiquZ/9++Xbt/e+/zSknTr1mJjnaiT557XT2YilfWCHciRA0OUt8V+cujytKr537wgPuv5KupUy0y/d2kp/3Wx5rsXQCAHUlR2sbLo0OXpYD6rNrzoWkUBsuj5qupsLNr53rNn+9+rJS/f0vSDHDlQQ5WFuFkLf2Xyv7Ou1WSOvK7p/PbJk9HluxchFjuBnqiyEFh3YXPe+5uqWpll3iJkpBUnIRQFchY7gdRU2axSd3NMbBtjynw/TRxuEQkWO4G+qLJZpe4W+9g2xpTZOt9WP5a
IEMiB1FRtRVtn0TC2Qx3KBOnYfvi0gEAOpKbNRlZdNs3KUyZIx/bDpwXkyAGko2zOf3oT1OnTUbafrYocOYD0jcdZ86tJrfrSUvb1dJBu8nCLCBHIAaRjczPb4DPZPXr7dvZ1bN0bW0YgB5COxA58aAuBHEhN7L3EmzTA0sIyCORASvKaYA3pFPkBlhaWQSAHYjY9+37mmWGnFgZYWlgGgRyIVd7s++rV/Nf2IbVQJmUUW117JKgjB2JV1DMkT+p9REIcmDwA1JEDqSk7y+5DaoFqlFoI5ECsihbwlpf7l1qgGqUWAjkQq6KFvWef7d+uRapRaiGQA7Ea0sIe1Si17O96AABmGI/7GbinTb7HHja6agMzciRhyJsZB2Ngja5CYkaO6E1Xpk02M0r8vw5IgWbkZvaYmb1hZm+a2VdDfCYwUVSZtr7OzByQAgRyM1uS9C1JH5f0sKTPmtnDdT8XmCiqQLt9e1htRoAiIWbkH5H0prv/wd1vSPqhpCcCfC4gaXYFGntGgDCB/AFJf9zx9eW7z+1iZifMbMvMtra3twNcFkORV5m2E3tGMHQhArnlPLengYu7n3H3NXdfW1lZCXBZDMWknHpyutc09oxg6EIE8suSHtrx9YOS/hTgc9GRGEv9xuPsRC/2jAB7hQjkv5H0QTP7gJkdlPQZST8L8LnoQMznFgxpoyNQRe1A7u63JH1R0ouSLkj6kbu/Xvdz0Y2um9DN+22APSPAXkE2BLn7C5JeCPFZ6FaXTejY+AMshi362KXLJnRd/zYApIpAjl26bEJHS2pgMQRy7BJiQXHRqhdaUgOLIZBjjzoLinWqXmhJDSyGQI6g6uS5y/w2EGONO9A1c9+zCbNxa2trvrW11fp10bx9+7KZ+DSzbIZfBwetY+jM7Jy7r00/z4x8wJqY3TaZ56aqBchHIB+opnZwNpnnpqoFyJdMICc3GlaZ2e0i97zJbfRUtQAF3L31x/Hjx72KjQ33w4fds7lj9jh8OHseizHbfT8nD7Ps72O85zGOCWiTpC3PialJzMjJje4W4reTebPbru75rO+NpllAgbzo3vSj6ox83uyxTzY23Eej7HsbjfbONkPMSjc23JeX997PnZ/TxT1nxg3MpoIZeRKBfDTKDyqj0QJ3ImJlAlnde5F3DSkL7CGvs4ih/HsGFlUUyJNIrQxlx1+ZdEbdyo28a0jSkSO7UxRd3HOqUoDFJBHIh5IbLRPIinLb7uXy5WWDZRf3nKoU9FqTpXd50/SmH1VTK0NRJrVQlBopm1OOOX1Bjhy9Feg/bqWcWhmKMumMnTPlPPMqS2JOUw3lNy8MUNNlYHnRvekHM/Ji86pWdlq0sqTKNfqM+4DWBCoDEzPy9tRJhVVpIbtoTplzL+M+ZBo91PACEIE8sDYDRKg0yRDbH7DJDK1qOqeZN01v+tHn1Erbi4l10wNDXWBMbZMZaaAeCPAvUQWpFfqRB9ZkP+4mHD0qXb269/nRKEu79NXqavbb0rQYv2/6sGOCfuQtSakWenMzP4hL4TbhxJq2ibl6ZxppIMxDIA8stQBRJMQPnnnrBV0G+fFYWl+Xlpayr5eWsq9jnOGy4xVz5eVbmn70OUfunk4+syhPLIUZ86z1gq5z811fv4qYN3GhXSJHjmlFeeLlZenKlfqfP2u94NixbnPU5MiRInLk2KMoDfTss2E+f9Z6Qdfpgq6vXwU7XjHPIAJ5rAtuXQsRIGbd21nrBV0vCnd9/arYxIWZ8vItTT/azJGnlAtNTZne5kXrBV3/e+n6+sAilPLBEnWwUNSconsruR88mAX0WQu+XS8Kd319oKqiQN77xc7UNuikpOje5mFxDqhvsIudqeVCU1LlHl6/Lj355L08OusWQDi9D+QpbdBJTd69nefiRelzn5OefprOg0AovQ/klG41Z3Jvl5erve/mTenGjd3PseUcWFzvA7lE6VaTxuNs89DGxr0flsvL0oED1T+rbA03aRlgt1qB3Mw
+bWavm9kdM9uTgMdw7PxheeWK9L3vFR9HV6RMzp0DIYC96s7Iz0v6lKRXAowFEQg1250E9o2NvXn0Awekgwf3vufatfkNtegECOy1v86b3f2CJJlZmNGgU9M9PSazXWnxdNTkfadOZamTY8fuLTQ/88zuNrpXr2bX++UvpbNn88eR0tZ6oC1B6sjN7D8k/bO7lyoOp2lWnMo0ktrc3BuUFw3yRddbWpJu384fh5ROsysgtIXryM3sZTM7n/N4ouIATpjZlpltbW9vV3krWjJvths6P110vbwgPnk95aTAXnMDubs/6u6P5Dyer3Ihdz/j7mvuvraysrL4iAeorSqNeZunQueni643Oewh7/WUkwJ7DaL8MAVFwTr0LHjRboVS+Px00fVOnJg9DspJgSl5DVjKPiR9UtJlSf8n6X8kvVjmfX0/IaiqWZ34Qjb9KtPxb1YjqSYakM3qjkhDK2A3DbVpVgpmLTJeuhSu6VfdU3E4qQbo1mCbZqVgVsoiZNOvuqmRpg+iALAYAnkEZgXrkFUaIX4o1MlPsysTaAaBPAKzgnXIKo2uS/fYlQk0gxx5JEJutInhOnk45AOohxx5x+blhtsqqWvyOvO+x0VTO+TVgTnySlmafgyt/DCv7M/M/eTJrkcWTtnSxqoHHnNIMnCPhnr4cgyK6q/N+hOQytaYV60Pb+PwbGrWkYqiQE6OvAWzDiluqiFV25rKfzedV6c2HikhR96hWTngphpSta2pQ66bPjybShr0AYG8BadPZzPIPE01pGpbU6WNTZdM0t8cfUAgb8F4LH3hC3uDeZMNqSbaqvhYtN69TDVPk90Om57xA63IS5w3/RjaYudEFw2pYq74aHt8efc/9nsE7CSqVuLWREBpo+KjjjbHN+v+UrWCVBQFcqpWIhK6aqVqxUfbVTNt7vSs2/kRiEFR1Uqtw5cR1ngcNnAeO5YfvPLyv00cvBxyfHWxqIk+Y7Gzx8pWfGxuSuvr7VfNtNnEi0VN9BmBvMfKVHxMZuKzDjzucnyhdN35EWgSOfKBK8odTywvS1eutDacRqW8cxaQ2NmJAvNm3O++m87u0nk4tBl9RSDvsTKbgebliG/cSGd3aVNoo4vYEcg71lSQKNu7JS93PG3IlR2p98DBMBDIO9RkkCjbu2XngmORrio7YpgJp94DB8PAYmeHmtyksshmm5hausYyFo6nQ0xY7IxQk5tUFqmbbrMccJ5YZsLUnyMFBPIONRkkFq2bjqWyI5admNSfIwUE8g41GSRiml0vIpaZcOr3EcNAjrxjbFLJF0uOHIgJTbMiFbpRVl9M7gk/5ID5COSIFj/kgHLIkQNA4gjkAJA4AjkAJI5ADgCJI5ADQOII5ACQOAI5ACSuViA3s6+b2e/M7Ldm9lMze1+gcQEASqo7I39J0iPu/iFJv5f0tfpDAgBUUSuQu/sv3P3W3S9/JenB+kMCAFQRMkf+tKSfF/2lmZ0wsy0z29re3g54WQAYtrm9VszsZUn35/zVKXd//u5rTkm6JanwMC53PyPpjJR1P1xotACAPebOyN39UXd/JOcxCeLrkj4haexd9MRFVGI4ZxMYmlrdD83sMUlfkfT37n593uvRb9M9xCeHSUt0MQSaVDdH/k1J90l6ycxeM7NvBxgTEhXLOZvA0NSakbv7X4UaCNIXyzmbwNCwsxPBxHLOJjA0BHIEw4nzQDcI5AiGE+eBbnBmJ4LinE2gfczIASBxBHIASByBHAASRyAHgMQRyAEgcdZFnysz25Z0sfULl3dU0pWuBxEp7s1s3J9i3JtiZe/NyN1Xpp/sJJDHzsy23H2t63HEiHszG/enGPemWN17Q2oFABJHIAeAxBHI853pegAR497Mxv0pxr0pVuvekCMHgMQxIweAxBHIASBxBPICZvZ1M/udmf3WzH5qZu/rekyxMLNPm9nrZnbHzCgnU3Z+rZm9YWZvmtlXux5PTMzsu2b2ZzM73/VYYmJmD5nZv5vZhbv
/Pz2z6GcRyIu9JOkRd/+QpN9L+lrH44nJeUmfkvRK1wOJgZktSfqWpI9LeljSZ83s4W5HFZXvS3qs60FE6JakL7v730j6W0n/tOh/NwTyAu7+C3e/dffLX0l6sMvxxMTdL7j7G12PIyIfkfSmu//B3W9I+qGkJzoeUzTc/RVJ/9v1OGLj7m+7+6t3//ldSRckPbDIZxHIy3la0s+7HgSi9YCkP+74+rIW/B8Sw2Rmq5I+LOnXi7x/0CcEmdnLku7P+atT7v783decUvYr0GabY+tamXuD91jOc9T1ohQzOyLpx5K+5O7vLPIZgw7k7v7orL83s3VJn5D0Dz6wgvt59wa7XJb00I6vH5T0p47GgoSY2QFlQXzT3X+y6OeQWilgZo9J+oqkf3T3612PB1H7jaQPmtkHzOygpM9I+lnHY0LkzMwkfUfSBXf/Rp3PIpAX+6ak+yS9ZGavmdm3ux5QLMzsk2Z2WdLfSfo3M3ux6zF16e6i+BclvahswepH7v56t6OKh5n9QNJ/SfprM7tsZp/vekyR+KikpyR97G6Mec3MHl/kg9iiDwCJY0YOAIkjkANA4gjkAJA4AjkAJI5ADgCJI5ADQOII5ACQuP8HA9JzzEdZHScAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for idx, label in enumerate(labels):\n", " if label:\n", " plt.scatter(df1.loc[idx][\"x\"], df1.loc[idx][\"y\"], color=\"b\") \n", " else:\n", " plt.scatter(df1.loc[idx][\"x\"], df1.loc[idx][\"y\"], color=\"r\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Openjijのソルバーを用いたクラスタリング\n", "\n", "次はQUBOの定式化にPyQUBOを用いて、ソルバー部分をOpenJijにしてクラスタリングを行います。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def clustering_openjij(df):\n", " # 距離行列の作成\n", " d_ij = distance_matrix(df, df)\n", " # スピン変数の設定\n", " spin = Array.create(\"spin\", shape= len(df), vartype=\"SPIN\")\n", " # 全ハミルトニアンの設定\n", " H = - 0.5* sum(\n", " [d_ij[i,j]* (1 - spin[i]* spin[j]) for i in range(len(df)) for j in range(len(df))]\n", " )\n", " # コンパイル\n", " model = H.compile()\n", " # QUBOに変換\n", " qubo, offset = model.to_qubo()\n", " # OpenJijのSAをサンプラーに設定\n", " sampler = oj.SASampler(num_reads=10, num_sweeps=100)\n", " # サンプラーで解を求める\n", " response = sampler.sample_qubo(qubo)\n", " # 生データの抽出\n", " raw_solution = dict(zip(response.indices, response.states[np.argmin(response.energies)]))\n", " # 解を見やすい形にデコード\n", " decoded_solution, broken, energy= model.decode_solution(raw_solution, vartype=\"SPIN\")\n", " # ラベルを抽出\n", " labels = [int(decoded_solution[\"spin\"][idx] ) for idx in range(len(df))]\n", " return labels, sum(response.energies)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "実行および解の確認を行います。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label [1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 
1]\n", "energy -138618.44537404884\n" ] } ], "source": [ "labels, energy =clustering_openjij(df1[[\"x\", \"y\"]])\n", "print(\"label\", labels)\n", "print(\"energy\", energy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "こちらも、可視化をしてみましょう。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWVklEQVR4nO3dX4hlV5XH8d/qSjekSUCsLhCS9C1hZJgQBOlCZvBhwPEhBpmgICg3oTBCYzNCBAdU+rmfBCGgIA3+6UkViqASmYnECDPIyChWJEgyrRIk3TaGsbrnIQktk/6z5uH0TVXdOufe82efc/Y+5/uBS6du33vPrpNk1a61117b3F0AgHQd6XsAAIBmCOQAkDgCOQAkjkAOAIkjkANA4u7q46InTpzw9fX1Pi4NAMl64YUXrrr72vzzvQTy9fV17ezs9HFpAEiWmV3Ke57UCgAkjkAOAIkjkANA4gjkAJA4AjkAJI5ADiB629vS+rp05Ej25/Z23yOKSy/lhwBQ1va2dPq0dP169vWlS9nXkjSd9jeumDAjBxC1s2f3gvjM9evZ88gQyAFE7fLlas+PEYEcQNROnqz2/BgRyAFE7dw56fjxg88dP549jwyBHEDUplPp/HlpMpHMsj/Pn2ehcz8COTBAQyvXm06lV1+Vbt/O/iSIH0T5ITAwlOuNDzNyYGAo1xsfAjkwMJTrjQ+BHBgYyvXGh0AODAzleuNDIAcGhnK98WlctWJmD0j6F0nvknRb0nl3f6rp5wKobzolcI9JiPLDm5I+7+6/NrN7Jb1gZs+7+38H+GwAwBKNUyvu/pq7//rOP78h6aKk+5p+LgCgnKA5cjNbl/Q+Sb8M+bkAgGLBArmZ3SPp+5I+5+6v5/z9aTPbMbOd3d3dUJcFgNELEsjN7KiyIL7t7j/Ie427n3f3DXffWFtbC3FZAIACBHIzM0nfkHTR3b/SfEgAgCpCzMg/IOlxSR80sxfvPB4J8LkARmpo3Rvb1rj80N3/U5IFGAsA0L2xBnZ2AogK3RurI5ADiArdG6sjkAMjFWsemu6N1RHIgRGa5aEvXZLc9/LQMQRzujdWRyAHRijmPDTdG6szd+/8ohsbG76zs9P5dQFkjhzJZuLzzLIDjhEnM3vB3Tfmn2dGDoxQUb75ne/sJ28ea74+FQRyYITy8tDHjkmvv34wb/7YY9KJE+0G1pjz9akgkAMDUnZmm5eHvvde6caNw6+9dq3dwBpzvj4V5MiBgZjfESlls+6yC4VFefOZyUR69dXGwyx9XfL1h5EjBwau6cx2WZ12WxtyqBtvjkAODETTHZF5efP92gqs1I03RyAHBqLpzHaWN19dPfx3bQbWJnXjVLtkCOTAQISY2U6n0tWr0tZWtxtyptMs/377dvZn2SBOtUuGxU5gQLa3s5z45cvZTPzcueHuiFxfz4L3vLYWZWNQtNhJIAeQpDFWu1C1AmBQqHbZQyAHkCSqXfYQyAEkiS6Jexqf2QkAfZlOxxm45zEjB4DEEcgBIHEEcgBIHIEcABJHIAcQNfqpLEfVCoBozfdYn/VTkahW2Y8ZOYBocXpQOQRyANFq2mN9LAjkAKJFP5VyCOQAopXXT8VMeuSRxe8b2wIpgRwYmDJBLJVAN51Km5t
Z8J5xly5cKB7zKA+ccPfOH6dOnXIA4W1tuR8/7p6FsOxx/Hj2fJXXhBzPZOJulv1Z5xqTycGxzh6TSZjXp0TSjufEVA6WAAakzKk5XZ2sM186KGVpkqodCosOkJCyMc+fhjTkAyc4WAIYgTJVHl1VgoQqHSxa2DTLT5+McYGUQA4MSJkg1lWgC/UDo2jBc37WPfshMcYDJwjkwIAsCmKzBc5Llw4uHu5/TUhNf2DMxvv449Ldd0urq3sHSBSlWi5fHueBE0ECuZl908z+bGYvhfg8APUUBTFpr5JDygLhLJi3FeiazIznK0+uXZP+8hfp6aezPP5kkv++2Q+J6TR73e3b2Z9DDuJSuBn5tyU9HOizADSQF8Ty8tXuewucZQLd/pLFEyeyx6LyxdkPldXVvefuvrvc97Asv55S+qSTUs+8UpY6D0nrkl4q81rKD4FumeWX5JmVe39eyWKZ8sW6pY5lxhuitLFtoUs91Xb5oZmtS/pXd3+o4O9PSzotSSdPnjx1Ka/+CUArmpYcFr1/2WfVvW5XJZJtC/199F5+6O7n3X3D3TfW1ta6uiwA5acijh2T3nyz3K/8ZSpN8l5Tt3IlpdTJIl2VelK1AozA/CLo6ureImKZbexlKk3yXlO3ciVvvHffnVWwxNxSYF5XpZ4EcmAk9i+C3nOPdOPGwb9ftFknb4a83/7Z8v7FvTffzGb+Ra8tM96nn84qVsr+0IlJZ79Z5CXOqz4kfUfSa5JuSLoi6dOLXs9iJ9Cfra3iRctFi5/7FxdXV7PH/EJj3uLe0aP5ry0r9d4pIRdlRa8VAHn9T/ZrupjYxiLlkHunVNX7YieA/uXVZ88U/cpfpQ66jcW9MfZOqYpADozIooCat7uzam/vNoLuUCpY2kQgB0akKKBOJvm7O6t2MGwj6A6lgqVNBHJgRKoG2qqpkrYaVg2hgqVNBHJgoPJy21UDbZ1USZsNq0L1OC8jlePwJOmuvgcAILz56pTZzFXKAmvZ4HruXP4pP33lp7vaKbns/sWGGTkwQKFmrn329s6bEZf9DaHpbLrLmX8I1JEDA5R67XXReZ+bm9KFC4vPAQ1xVmis9486cmBEls1cY8//Fs2In312+W8IRe/d3Cz//SZXu5633bPtB1v0gXYt6oMdskf2mTPuKyvZZ6ysZF+H0KR/etF7q3y/ofuIh6KCLfoEcmCginp8hOpdcuZM/ueECOZNxlj03qqfFePBFUWBnBw5MHDb21m64fLlLDVQdEBE1fzvXXdJt24dfn5lRbp5s95YZ5rkuZf1k5npO99dBzlyYITyttjPDl2eVzX/mxfEFz1fRZNqmfn3rqzkvy7afHcNBHIgQWUXK4sOXZ4P5otqw4uuVRQgi56vqsnGov3vvXBhBL1a8vItbT/IkQP1VVmIW7TwVyb/u+habebIm5rPb585E1++uw6x2AkMQ5WFwKYLm8ve31bVyiLLFiFjrTgJoSiQs9gJJKbKZpWmm2Ni2xhT5vtp43CLWLDYCQxElc0qTbfYx7YxpszW+a76scSEQA4kpmor2iaLhrEd6lAmSMf2w6cLBHIgMV02suqzaVaeMkE6th8+XSBHDiAZZXP+85ugzp2Ls/1sVeTIASRvOs2aX81q1VdWsq/ng3Sbh1vEiEAOIBnb29kGn9nu0Vu3sq9j697YNQI5gGSkduBDVwjkQGJi7yXepjGWFpZBIAcSktcEa0ynyI+xtLAMAjkQsfnZ95NPjju1MMbSwjII5ECk8mbf167lv3YIqYUyKaPY6tpjQR05EKminiF5Uu8jEuLA5DGgjhxITNlZ9hBSC1SjNEMgByJVtIC3ujq81ALVKM0QyIFIFS3sPfXU8HYtUo3SDIEciNSYFvaoRmnmrr4HAKDYdDrMwD1v9j0OsdFVF5iRIw1j3s44EmNrdBUSM3LEb742bbadUeL/dkCBZuRm9rCZ/c7MXjGzL4b4TOBtRbVpm5vMzAEFCORmtiLpa5I+LOlBSZ8
0swebfi7wtqIatFu3xtVoBCgQYkb+fkmvuPsf3P0tSd+V9GiAzwUyi2rQ2DUCBAnk90n6476vr9x57gAzO21mO2a2s7u7G+CyGI282rT92DWCkQsRyC3nuUMNXNz9vLtvuPvG2tpagMtiNGYF1bPzveaxawQjFyKQX5H0wL6v75f0pwCfi77EWOo3nWZnerFrBDgkRCD/laT3mNm7zeyYpE9I+lGAz0UfYj65YExbHYEKGgdyd78p6bOSnpN0UdL33P3lpp+LnvTdhm7ZbwPsGgEOCbIhyN2flfRsiM9Cz/psQ8fGH6AWtujjoD7b0PX92wCQKAI5DuqzDR1NqYFaCOQ4KMSCYt2qF5pSA7UQyHFYkwXFJlUvNKUGaiGQI6wmee4yvw3EWOMO9MzcD23CbN3Gxobv7Ox0fl104MiRbCY+zyyb4TfBUesYOTN7wd035p9nRj5mbcxu28xzU9UC5CKQj1VbOzjbzHNT1QLkSieQkxsNq8zsts49b3MbPVUtQD537/xx6tQpr2Rry/34cfds7pg9jh/Pnkc9Zgfv5+xhlv19jPc8xjEBHZK04zkxNY0ZObnRg0L8drJsdtvXPV/0vdE0C8iXF93bflSekS+bPQ7J1pb7ZJJ9b5PJ4dlmiFnp1pb76urh+7n/c/q458y4gYVUMCNPI5BPJvlBZTKpfidiViaQNb0XedeQssAe8jp1jOXfM1BTUSBPI7Uylh1/ZdIZTSs38q4hSffcczBF0cc9pyoFqCWNQD6W3GiZQFaU23Yvly8vGyz7uOdUpWDAWi28y5umt/2onFoZizKphaLUSNmccszpC3LkGKhQ/2kr6dTKWJRJZ+yfKedZVlkSc5pqLL95YXRaLwLLi+5tP5iRL7CsamW/upUlVa4xZNwHdCRUEZiYkXeoSTKsSgvZujllzr2M+5BpDE7byz8E8tC6DBCh0iRjbH/AJjN0qPWMZt40ve3HoFMrXS8mNk0PjHWBMbVNZqSBkhfiX6EKUiv0Iw+tzX7cbThxQrp27fDzk0mWdhmq9fXst6V5MX7f9GHHHfQj70pKtdDb2/lBXAq3CSfWtE3M1TvzSANhCQJ5aKkFiCIhfvAsWy/oM8hPp9LmprSykn29spJ9HeMMlx2vWCYv39L2Y9A5cvd08plFeWIpzJgXrRf0nZvv+/pVxLyJC50SOXIcUpQnXl2Vrl5t/vmL1gtOnuw3R02OHAkiR47DitJATz0V5vMXrRf0nS7o+/pVsOMVS4wjkMe64Na3EAFi0b1dtF7Q96Jw39evik1cWCQv39L2o9MceUq50NSU6W1etF7Q97+Xvq8P1KCkD5ZogoWi9hTdW8n92LEsoC9a8O17Ubjv6wMVFQXy4S92prZBJyVF9zYPi3NAY+Nd7EwtF5qSKvfw+nXpscf28uisWwDBDD+Qp7RBJzV593aZS5ekT31KeuIJOg8CgQw/kFO61Z7ZvV1drfa+Gzekt946+BxbzoHahh/IJUq32jSdZpuHtrb2fliurkpHj1b/rLI13KRlgAMaBXIz+7iZvWxmt83sUAIeI7L/h+XVq9K3vlV8HF2RMjl3DoQADmk6I39J0sck/SzAWBCDULPdWWDf2jqcRz96VDp27PB73nxzeUMtOgECh9zV5M3uflGSzCzMaNCv+Z4es9muVD8dNXvf2bNZ6uTkyb2F5iefPNhG99q17Ho//7l04UL+OFLaWg90JEgduZn9h6R/dvdSxeE0zYpUmUZS29uHg3LdIF90vZUV6dat/HFI6TS7AgKrXUduZj81s5dyHo9WHMBpM9sxs53d3d0qb0VXls12Q+eni66XF8Rnr6ecFDhkaSB39w+5+0M5j2eqXMjdz7v7hrtvrK2t1R/xGHVVpbFs81To/HTR9WaHPeS9nnJS4JBxlB+moChYh54F1+1WKIXPTxdd7/TpxeOgnBQ4KK8BS9mHpI9KuiLp/yT9j6Tnyrxv8CcEVbWoE1/Ipl9lOv4taiTVRgOyRd0
RaWgFHKDRNs1KwaJFxsuXwzX9anoqDifVAL0ab9OsFCxKWYRs+tU0NdL2QRQAaiGQx2BRsA5ZpRHih0KT/DS7MoFWEMhjsChYh6zS6Lt0j12ZQCvIkcci5EabGK6Th0M+gEbIkfdtWW64q5K6Nq+z7Husm9ohrw4sllfK0vZjdOWHeWV/Zu5nzvQ9snDKljZWPfCYQ5KBt2m0hy/HoKj+2mw4AalsjXnV+vAuDs+mZh2JKArk5Mi7sOiQ4rYaUnWtrfx323l1auOREHLkfVqUA26rIVXX2jrkuu3Ds6mkwQAQyLtw7lw2g8zTVkOqrrVV2th2yST9zTEABPIuTKfSZz5zOJi32ZBqpquKj7r17mWqedrsdtj2jB/oQl7ivO3H6BY7Z/poSBVzxUfX48u7/7HfI2AfUbUSuTYCShcVH010Ob5F95eqFSSiKJBTtRKT0FUrVSs+uq6a6XKnZ9POj0AEiqpWGh2+jMCm07CB8+TJ/OCVl/9t4+DlkONrikVNDBiLnUNWtuJje1va3Oy+aqbLJl4samLACORDVqbiYzYTX3TgcZ/jC6Xvzo9Ai8iRj11R7nhmdVW6erWz4bQq5Z2zgNjZiSLLZtxvvJHO7tJlOLQZA0UgH7Iym4GW5Yjfeiud3aVtoY0uIkcg71tbQaJs75a83PG8MVd2pN4DB6NAIO9Tm0GibO+W/QuORfqq7IhhJpx6DxyMAoudfWpzk0qdzTYxtXSNZSwcT4eIsNgZozY3qdSpm+6yHHCZWGbC1J8jAQTyPrUZJOrWTcdS2RHLTkzqz5EAAnmf2gwSMc2u64hlJpz6fcQokCPvG5tU8sWSIwciQtOsWIVulDUUs3vCDzlgKQI54sUPOaAUcuQAkDgCOQAkjkAOAIkjkANA4gjkAJA4AjkAJI5ADgCJaxTIzezLZvZbM/uNmf3QzN4RaFwAgJKazsifl/SQu79X0u8lfan5kAAAVTQK5O7+E3e/eefLX0i6v/mQAABVhMyRPyHpx0V/aWanzWzHzHZ2d3cDXhYAxm1prxUz+6mkd+X81Vl3f+bOa85Kuimp8Cwudz8v6byUdT+sNVoAwCFLZ+Tu/iF3fyjnMQvim5I+ImnqffTERVxiOGcTGJlG3Q/N7GFJX5D09+5+fdnrMXDzPcRnh0lLdDEEWtQ0R/5VSfdKet7MXjSzrwcYE1IVyzmbwMg0mpG7+1+FGggGIJZzNoGRYWcnwonlnE1gZAjkCIcT54FeEMgRDifOA73gzE6ExTmbQOeYkQNA4gjkAJA4AjkAJI5ADgCJI5ADQOKsjz5XZrYr6VLnFy7vhKSrfQ8iUtybxbg/xbg3xcrem4m7r80/2Usgj52Z7bj7Rt/jiBH3ZjHuTzHuTbGm94bUCgAkjkAOAIkjkOc73/cAIsa9WYz7U4x7U6zRvSFHDgCJY0YOAIkjkANA4gjkBczsy2b2WzP7jZn90Mze0feYYmFmHzezl83stplRTqbs/Foz+52ZvWJmX+x7PDExs2+a2Z/N7KW+xxITM3vAzP7dzC7e+f/pybqfRSAv9rykh9z9vZJ+L+lLPY8nJi9J+pikn/U9kBiY2Yqkr0n6sKQHJX3SzB7sd1RR+bakh/seRIRuSvq8u/+NpL+V9E91/7shkBdw95+4+807X/5C0v19jicm7n7R3X/X9zgi8n5Jr7j7H9z9LUnflfRoz2OKhrv/TNL/9j2O2Lj7a+7+6zv//Iaki5Luq/NZBPJynpD0474HgWjdJ+mP+76+opr/Q2KczGxd0vsk/bLO+0d9QpCZ/VTSu3L+6qy7P3PnNWeV/Qq03eXY+lbm3uBtlvMcdb0oxczukfR9SZ9z99frfMaoA7m7f2jR35vZpqSPSPoHH1nB/bJ7gwOuSHpg39f3S/pTT2NBQszsqLIgvu3uP6j7OaRWCpjZw5K+IOkf3f163+NB1H4l6T1m9m4zOybpE5J+1POYEDkzM0nfkHTR3b/S5LMI5MW+KuleSc+b2Yt
m9vW+BxQLM/uomV2R9HeS/s3Mnut7TH26syj+WUnPKVuw+p67v9zvqOJhZt+R9F+S/trMrpjZp/seUyQ+IOlxSR+8E2NeNLNH6nwQW/QBIHHMyAEgcQRyAEgcgRwAEkcgB4DEEcgBIHEEcgBIHIEcABL3/8mgc8zXsf6rAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for idx, label in enumerate(labels):\n", " if label:\n", " plt.scatter(df1.loc[idx][\"x\"], df1.loc[idx][\"y\"], color=\"b\") \n", " else:\n", " plt.scatter(df1.loc[idx][\"x\"], df1.loc[idx][\"y\"], color=\"r\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyQUBOのSAのときと同様の結果を得ることができました。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## QBoost\n", "\n", "QBoostは量子アニーリングを用いたアンサンブル学習の一つです。アンサンブル学習は弱い予測器を多数用意して、その予測器の各予測結果の組み合わせて最終的な予測結果を得ます。 \n", "QBoostでは、与えられた学習データに対して最適な学習器の組み合わせを、量子アニーリングを用いて最適化を行います。今回は分類問題を扱います。 \n", "\n", "$D$個の学習データの集合を$\\{\\vec x^{(d)}\\}(d=1, ..., D)$、対応するラベルを$\\{y^{(d)}\\}(d=1, ..., D), y^{(d)}\\in \\{-1, 1\\}$とします。また、$N$個の弱学習器の(関数の)集合を$\\{C_i\\}(i=1, ..., N)$とします。あるデータ$\\vec x^{(d)}$ に対して、$C_i(\\vec x^{(d)})\\in \\{-1, 1\\}$です。このとき、最終的な分類のラベルは以下のようになります。\n", "\n", "$${\\rm sgn}\\left( \\sum_{i=1}^{N} w_i C_i({\\vec x}^{(d)})\\right)$$\n", "\n", "ここで$w_i\\in\\{0, 1\\} (i=1, ..., N)$とします。これは各予測器の重み(予測器を最終的な予測に採用するか採用しないかのbool値)です。 \n", "QBoostではこの$w_i$を、弱学習器の個数を刈り込みつつ最終的な予測が教師データに一致するように組み合わせの最適化を行います。\n", "この問題におけるハミルトニアンは、以下のようになります。\n", "\n", "$$H(\\vec w) = \\sum_{d=1}^{D} \\left( \\frac{1}{N}\\sum_{i=1}^{N} w_i C_i(\\vec x^{(d)})-y^{(d)} \\right)^2 + \\lambda \\sum _i^N w_i$$\n", "\n", "このハミルトニアンの第一項目は、弱分類器と正解ラベルの差を表しています。第二項目は最終的な分類器に採用する弱分類器の個数の程度を表しています。$\\lambda$は第二項の弱分類器の個数がハミルトニアンにどのくらい影響するかを調節する正則化パラメータにです。 \n", "第一項をコスト(目的関数)、第二項を制約として、このハミルトニアンの最適化を行います。量子アニーリングで最小化することにより、学習データに最も適合するような弱分類器の組み合わせを得ることができます。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### スクリプト\n", "\n", "それでは実際にQBoostを試してみましょう。学習データにはscikit-learnの癌識別のデータセットを使用します。簡単のために、学習に用いるのは\"0\"と\"1\"の2つの文字種のみとします。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# ライブラリをインポート\n", "import pandas as pd \n", "from scipy import stats \n", "from sklearn import datasets\n", "from 
sklearn import metrics" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Load the data\n", "cancerdata = datasets.load_breast_cancer()\n", "# Number of training samples (the remaining samples are used for validation)\n", "num_train = 450" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For demonstration purposes, we consider the case where some features are pure noise: we append 30 random features to the 30 original ones." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(569, 60)\n" ] } ], "source": [ "data_noisy = np.concatenate((cancerdata.data, np.random.rand(cancerdata.data.shape[0], 30)), axis=1)\n", "print(data_noisy.shape)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Convert the labels from {0, 1} to {-1, 1}\n", "labels = (cancerdata.target-0.5) * 2" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Split the dataset into training and validation sets\n", "X_train = data_noisy[:num_train, :]\n", "X_test = data_noisy[num_train:, :]\n", "y_train = labels[:num_train]\n", "y_test = labels[num_train:]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Aggregate the weak learners' predictions: +1 if their mean is positive, otherwise -1\n", "def aggre_mean(Y_list):\n", "    return ((np.mean(Y_list, axis=0)>0)-0.5) * 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Building the weak learners\n", "\n", "We build the weak classifiers with scikit-learn. Here we use decision stumps, that is, decision trees of depth one. Since they serve as weak classifiers, the feature used for the split is chosen at random (you can think of this as a random forest of depth-one trees)." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Import the required library\n", "from sklearn.tree import DecisionTreeClassifier as DTC\n", "\n", "# Number of weak classifiers\n", "num_clf = 32\n", "# Number of samples drawn (with replacement) for each weak classifier in bootstrap sampling\n", "sample_train = 40\n", "# Set up the models\n", "models = [DTC(splitter=\"random\",max_depth=1) for i in range(num_clf)]\n", "for model in models:\n", " # Draw training indices at random\n", " train_idx = np.random.choice(np.arange(X_train.shape[0]), sample_train)\n", " # 
Fit a decision stump from the explanatory and target variables\n", " model.fit(X=X_train[train_idx], y=y_train[train_idx])\n", "y_pred_list_train = []\n", "for model in models:\n", " # Predict with the fitted model\n", " y_pred_list_train.append(model.predict(X_train))\n", "y_pred_list_train = np.asanyarray(y_pred_list_train)\n", "y_pred_train = np.sign(y_pred_list_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us look at the accuracy when all weak learners are used in the final classifier. We refer to this combination as the baseline from here on." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9495798319327731\n" ] } ], "source": [ "y_pred_list_test = []\n", "for model in models:\n", "    # Predict on the validation data\n", "    y_pred_list_test.append(model.predict(X_test))\n", "\n", "y_pred_list_test = np.array(y_pred_list_test)\n", "y_pred_test = np.sign(np.sum(y_pred_list_test,axis=0))\n", "# Compute the accuracy score of the prediction\n", "acc_test_base = metrics.accuracy_score(y_true=y_test, y_pred=y_pred_test)\n", "print(acc_test_base)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Define a class that performs QBoost\n", "class QBoost():\n", "    def __init__(self, y_train, ys_pred):\n", "        self.num_clf = ys_pred.shape[0]\n", "        # Define the binary variables\n", "        self.Ws = Array.create(\"weight\", shape=self.num_clf, vartype=\"BINARY\")\n", "        # Define the strength of the regularization term (a hyperparameter) as a PyQUBO Placeholder\n", "        self.param_lamda = Placeholder(\"norm\")\n", "        # Hamiltonian for the combination of weak classifiers\n", "        self.H_clf = sum([(1/self.num_clf * sum([W*C for W, C in zip(self.Ws, y_clf)]) - y_true)**2 for y_true, y_clf in zip(y_train, ys_pred.T)\n", "        ])\n", "        # Hamiltonian of the regularization term, set as a constraint\n", "        self.H_norm = Constraint(sum([W for W in self.Ws]), \"norm\")\n", "        # Full Hamiltonian\n", "        self.H = self.H_clf + self.H_norm * self.param_lamda\n", "        # Compile the model\n", "        self.model = self.H.compile()\n", "\n", "    # Convert the model to a QUBO\n", "    def to_qubo(self, norm_param=1):\n", "        # Set the value of the hyperparameter\n", "        self.feed_dict = {'norm': norm_param}\n", "        return self.model.to_qubo(feed_dict=self.feed_dict)" ] 
}, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "qboost = QBoost(y_train=y_train, ys_pred=y_pred_list_train)\n", "# Build the QUBO with lambda=3\n", "qubo = qboost.to_qubo(3)[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Running QBoost with the D-Wave sampler" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Import the required libraries\n", "from dwave.system.samplers import DWaveSampler\n", "from dwave.system.composites import EmbeddingComposite" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Replace 'xxxx' with your own API token\n", "dw = DWaveSampler(endpoint='https://cloud.dwavesys.com/sapi/',\n", "                  token='xxxx',\n", "                  solver='DW_2000Q_VFYC_6')\n", "sampler = EmbeddingComposite(dw)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sample with the D-Wave sampler\n", "sampleset = sampler.sample_qubo(qubo, num_reads=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Inspect the result\n", "print(sampleset)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Decode each sample with PyQUBO\n", "decoded_solutions = []\n", "brokens = []\n", "energies = []\n", "\n", "decoded_sol = qboost.model.decode_dimod_response(sampleset, feed_dict=qboost.feed_dict)\n", "for d_sol, broken, energy in decoded_sol:\n", "    decoded_solutions.append(d_sol)\n", "    brokens.append(broken)\n", "    energies.append(energy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us check the accuracy on the training and validation data when using the combinations of weak classifiers obtained with D-Wave." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "accs_train_Dwaves = []\n", "accs_test_Dwaves = []\n", "for decoded_solution in decoded_solutions:\n", " # Collect the indices of the weak classifiers selected in this sample\n", " idx_clf_DWave = []\n", " for key, val in decoded_solution[\"weight\"].items():\n", "     if val == 1:\n", "         idx_clf_DWave.append(int(key))\n", " y_pred_train_DWave = 
np.sign(np.sum(y_pred_list_train[idx_clf_DWave, :], axis=0))\n", " y_pred_test_DWave = np.sign(np.sum(y_pred_list_test[idx_clf_DWave, :], axis=0))\n", " acc_train_DWave = metrics.accuracy_score(y_true=y_train, y_pred=y_pred_train_DWave)\n", " acc_test_DWave = metrics.accuracy_score(y_true=y_test, y_pred=y_pred_test_DWave)\n", " accs_train_Dwaves.append(acc_train_DWave)\n", " accs_test_Dwaves.append(acc_test_DWave)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us plot the accuracy (vertical axis) against the energy (horizontal axis)." ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "image/svg+xml": [ "\r\n", " 
\r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " 
\r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " 
\r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", " \r\n", "\r\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(12, 8))\n", "plt.scatter(energies, accs_train_Dwaves, label=\"train\" )\n", "plt.scatter(energies, accs_test_Dwaves, label=\"test\")\n", "plt.xlabel(\"energy: D-wave\")\n", "plt.ylabel(\"accuracy\")\n", "plt.title(\"relationship between energy and accuracy\")\n", "plt.grid()\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "base accuracy is 0.9411764705882353\n", "max accuracy of QBoost is 0.957983193277311\n", "average accuracy of QBoost is 0.9398183515830576\n" ] } ], "source": [ "print(\"base accuracy is {}\".format(acc_test_base))\n", "print(\"max accuracy of QBoost is {}\".format(max(accs_test_Dwaves)))\n", "print(\"average accuracy of QBoost is {}\".format(np.mean(np.asarray(accs_test_Dwaves))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "D-Waveによるサンプリングは短時間で数百サンプリング以上が実行できます。精度が最大になる結果を使えば、baselineよりも高精度の分類器を作成することが可能です。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 2 }