प्रति-आर्म सुविधाओं के साथ बहु-सशस्त्र डाकुओं पर एक ट्यूटोरियल

शुरू हो जाओ

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

यह ट्यूटोरियल प्रासंगिक डाकुओं की समस्याओं के लिए TF-Agents लाइब्रेरी का उपयोग करने के तरीके पर एक चरण-दर-चरण मार्गदर्शिका है, जहां क्रियाओं (हथियारों) की अपनी विशेषताएं हैं, जैसे कि सुविधाओं द्वारा दर्शाई गई फिल्मों की सूची (शैली, रिलीज का वर्ष, ...)

शर्त

यह माना जाता है कि पाठक TF-एजेंटों की बेंडिट पुस्तकालय के साथ कुछ हद तक परिचित है, विशेष रूप से, के माध्यम से काम किया है TF-एजेंटों में डाकू के लिए ट्यूटोरियल इस ट्यूटोरियल पढ़ने से पहले।

आर्म फीचर्स के साथ मल्टी-आर्म्ड बैंडिट्स

"क्लासिक" प्रासंगिक मल्टी-आर्म्ड बैंडिट्स सेटिंग में, एक एजेंट को हर समय कदम पर एक संदर्भ वेक्टर (उर्फ अवलोकन) प्राप्त होता है और उसे अपने संचयी इनाम को अधिकतम करने के लिए गिने-चुने कार्यों (हथियारों) के एक सीमित सेट से चुनना होता है।

अब उस परिदृश्य पर विचार करें जहां एक एजेंट उपयोगकर्ता को अगली फिल्म देखने की सिफारिश करता है। हर बार जब कोई निर्णय लेना होता है, तो एजेंट को संदर्भ के रूप में उपयोगकर्ता के बारे में कुछ जानकारी प्राप्त होती है (इतिहास देखें, शैली वरीयता, आदि ...), साथ ही साथ चुनने के लिए फिल्मों की सूची।

हम संदर्भ के रूप में उपयोगकर्ता जानकारी होने से इस समस्या को तैयार करने के लिए कोशिश कर सकते और हथियारों होगा movie_1, movie_2, ..., movie_K , लेकिन इस दृष्टिकोण कई कमियों है:

  • सिस्टम में सभी फिल्मों के लिए क्रियाओं की संख्या होनी चाहिए और एक नई फिल्म जोड़ना बोझिल है।
  • एजेंट को हर फिल्म के लिए एक मॉडल सीखना होता है।
  • फिल्मों के बीच समानता को ध्यान में नहीं रखा जाता है।

फिल्मों को क्रमांकित करने के बजाय, हम कुछ अधिक सहज ज्ञान युक्त कर सकते हैं: हम शैली, लंबाई, कास्ट, रेटिंग, वर्ष इत्यादि सहित सुविधाओं के एक सेट के साथ फिल्मों का प्रतिनिधित्व कर सकते हैं। इस दृष्टिकोण के फायदे कई गुना हैं:

  • फिल्मों में सामान्यीकरण।
  • एजेंट केवल एक इनाम फ़ंक्शन सीखता है जो मॉडल उपयोगकर्ता और मूवी सुविधाओं के साथ पुरस्कृत करता है।
  • सिस्टम से निकालना, या नई फिल्मों को पेश करना आसान है।

इस नई सेटिंग में, हर समय चरण में क्रियाओं की संख्या समान होना भी आवश्यक नहीं है।

टीएफ-एजेंटों में प्रति-आर्म बैंडिट्स

TF-Agents Bandit सुइट को विकसित किया गया है ताकि कोई इसे प्रति-आर्म केस के लिए भी उपयोग कर सके। प्रति-आर्म वातावरण हैं, और अधिकांश नीतियां और एजेंट प्रति-आर्म मोड में भी काम कर सकते हैं।

इससे पहले कि हम एक उदाहरण कोडिंग में उतरें, हमें आवश्यक आयात की आवश्यकता है।

इंस्टालेशन

pip install tf-agents

आयात

import functools
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.bandits.environments import stationary_stochastic_per_arm_py_environment as p_a_env
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import tf_py_environment
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

nest = tf.nest

पैरामीटर -- चारों ओर खेलने के लिए स्वतंत्र महसूस करें

# The dimension of the global features.
GLOBAL_DIM = 40 
# The elements of the global feature will be integers in [-GLOBAL_BOUND, GLOBAL_BOUND).
GLOBAL_BOUND = 10 
# The dimension of the per-arm features.
PER_ARM_DIM = 50 
# The elements of the PER-ARM feature will be integers in [-PER_ARM_BOUND, PER_ARM_BOUND).
PER_ARM_BOUND = 6 
# The variance of the Gaussian distribution that generates the rewards.
VARIANCE = 100.0 
# The elements of the linear reward parameter will be integers in [-PARAM_BOUND, PARAM_BOUND).
PARAM_BOUND = 10 

NUM_ACTIONS = 70 
BATCH_SIZE = 20 

# Parameter for linear reward function acting on the
# concatenation of global and per-arm features.
reward_param = list(np.random.randint(
      -PARAM_BOUND, PARAM_BOUND, [GLOBAL_DIM + PER_ARM_DIM]))

एक साधारण प्रति-आर्म पर्यावरण

स्थिर स्टोकेस्टिक पर्यावरण, अन्य में विस्तार से बताया ट्यूटोरियल , एक प्रति हाथ के समकक्ष है।

प्रति-हाथ के वातावरण को आरंभ करने के लिए, किसी को उन कार्यों को परिभाषित करना होगा जो उत्पन्न करते हैं

  • वैश्विक और प्रति-हाथ विशेषताएं: इन कार्यों कोई इनपुट पैरामीटर है और जब नामक केवल एक (वैश्विक या प्रति-हाथ) फीचर वेक्टर उत्पन्न करते हैं।
  • पुरस्कार: यह समारोह एक वैश्विक और एक प्रति हाथ के फीचर वेक्टर के संयोजन पैरामीटर के रूप में लेता है, और एक इनाम उत्पन्न करता है। मूल रूप से यह वह कार्य है जिसे एजेंट को "अनुमान" करना होगा। यहां यह ध्यान देने योग्य है कि प्रति-हाथ के मामले में इनाम समारोह प्रत्येक हाथ के लिए समान है। यह क्लासिक दस्यु मामले से एक मूलभूत अंतर है, जहां एजेंट को प्रत्येक हाथ के लिए स्वतंत्र रूप से इनाम कार्यों का अनुमान लगाना होता है।
def global_context_sampling_fn():
  """This function generates a single global observation vector."""
  return np.random.randint(
      -GLOBAL_BOUND, GLOBAL_BOUND, [GLOBAL_DIM]).astype(np.float32)

def per_arm_context_sampling_fn():
  """"This function generates a single per-arm observation vector."""
  return np.random.randint(
      -PER_ARM_BOUND, PER_ARM_BOUND, [PER_ARM_DIM]).astype(np.float32)

def linear_normal_reward_fn(x):
  """This function generates a reward from the concatenated global and per-arm observations."""
  mu = np.dot(x, reward_param)
  return np.random.normal(mu, VARIANCE)

अब हम अपने पर्यावरण को आरंभ करने के लिए सुसज्जित हैं।

per_arm_py_env = p_a_env.StationaryStochasticPerArmPyEnvironment(
    global_context_sampling_fn,
    per_arm_context_sampling_fn,
    NUM_ACTIONS,
    linear_normal_reward_fn,
    batch_size=BATCH_SIZE
)
per_arm_tf_env = tf_py_environment.TFPyEnvironment(per_arm_py_env)

नीचे हम जांच सकते हैं कि यह वातावरण क्या पैदा करता है।

print('observation spec: ', per_arm_tf_env.observation_spec())
print('\nAn observation: ', per_arm_tf_env.reset().observation)

action = tf.zeros(BATCH_SIZE, dtype=tf.int32)
time_step = per_arm_tf_env.step(action)
print('\nRewards after taking an action: ', time_step.reward)
observation spec:  {'global': TensorSpec(shape=(40,), dtype=tf.float32, name=None), 'per_arm': TensorSpec(shape=(70, 50), dtype=tf.float32, name=None)}

An observation:  {'global': <tf.Tensor: shape=(20, 40), dtype=float32, numpy=
array([[ -9.,  -4.,  -3.,   3.,   5.,  -9.,   6.,  -5.,   4.,  -8.,  -6.,
         -1.,  -7.,  -5.,   7.,   8.,   2.,   5.,  -8.,   0.,  -4.,   4.,
         -1.,  -1.,  -4.,   6.,   8.,   6.,   9.,  -5.,  -1.,  -1.,   2.,
          5.,  -1.,  -8.,   1.,   0.,   0.,   5.],
       [  5.,   7.,   0.,   3.,  -8.,  -7.,  -5.,  -2.,  -8.,  -7.,  -7.,
         -8.,   5.,  -3.,   5.,   4.,  -5.,   2.,  -6., -10.,  -4.,  -2.,
          2.,  -1.,  -1.,   8.,  -7.,   7.,   2.,  -3., -10.,  -1.,  -4.,
         -7.,   3.,   4.,   8.,  -2.,   9.,   5.],
       [ -6.,  -2.,  -1.,  -1.,   6.,  -3.,   4.,   9.,   2.,  -2.,   3.,
          1.,   0.,  -7.,   5.,   5.,  -8.,  -4.,   5.,   7., -10.,  -4.,
          5.,   6.,   8., -10.,   7.,  -1.,  -8.,  -8.,  -6.,  -6.,   4.,
        -10.,  -8.,   3.,   8.,  -9.,  -5.,   8.],
       [ -1.,   8.,  -8.,  -7.,   9.,   2.,  -6.,   8.,   4.,  -2.,   1.,
          8.,  -4.,   3.,   1.,  -6.,  -9.,   3.,  -5.,   7.,  -9.,   6.,
          6.,  -3.,   1.,   2.,  -1.,   3.,   7.,   4.,  -8.,   1.,   1.,
          3.,  -4.,   1.,  -4.,  -5.,  -9.,   4.],
       [ -9.,   8.,   9.,  -8.,   2.,   9.,  -1.,  -9.,  -1.,   9.,  -8.,
         -4.,   1.,   1.,   9.,   6.,  -6., -10.,  -6.,   2.,   6.,  -4.,
         -7.,  -2.,  -7.,  -8.,  -4.,   5.,  -6.,  -1.,   8.,  -3.,  -7.,
          4.,  -9.,  -9.,   6.,  -9.,   6.,  -2.],
       [  0.,  -6.,  -5.,  -8., -10.,   2.,  -4.,   9.,   9.,  -1.,   5.,
         -7.,  -1.,  -3., -10., -10.,   3.,  -2.,  -7.,  -9.,  -4.,  -8.,
         -4.,  -1.,   7.,  -2.,  -4.,  -4.,   9.,   2.,  -2.,  -8.,   6.,
          5.,  -4.,   7.,   0.,   6.,  -3.,   2.],
       [  8.,   5.,   3.,   5.,   9.,   4., -10.,  -5.,  -4.,  -4.,  -5.,
          3.,   5.,  -4.,   9.,  -2.,  -7.,  -6.,  -2.,  -8.,  -7., -10.,
          0.,  -2.,   3.,   1., -10.,  -8.,   3.,   9.,  -5.,  -6.,   1.,
         -7.,  -1.,   3.,  -7.,  -2.,   1.,  -1.],
       [  3.,   9.,   8.,   6.,  -2.,   9.,   9.,   7.,   0.,   5.,  -5.,
          6.,   9.,   3.,   2.,   9.,   4.,  -1.,  -3.,   3.,  -1.,  -4.,
         -9.,  -1.,  -3.,   8.,   0.,   4.,  -1.,   4.,  -3.,   4.,  -5.,
         -3.,  -6.,  -4.,   7.,  -9.,  -7.,  -1.],
       [  5.,  -1.,   9.,  -5.,   8.,   7.,  -7.,  -5.,   0.,  -4.,  -5.,
          6.,  -3.,  -1.,   7.,   3.,  -7.,  -9.,   6.,   4.,   9.,   6.,
         -3.,   3.,  -2.,  -6.,  -4.,  -7.,  -5.,  -6.,  -2.,  -1.,  -9.,
         -4.,  -9.,  -2.,  -7.,  -6.,  -3.,   6.],
       [ -7.,   1.,  -8.,   1.,  -8.,  -9.,   5.,   1.,  -4.,  -2.,  -5.,
          3.,  -1.,  -4.,  -4.,   5.,   0., -10.,  -4.,  -1.,  -5.,   3.,
          8.,  -5.,  -4., -10.,  -8.,  -6., -10.,  -1.,  -6.,   1.,   7.,
          8.,   6.,  -2.,  -4.,  -9.,   7.,  -1.],
       [ -2.,   3.,   8.,  -5.,   0.,   5.,   8.,  -5.,   6.,  -8.,   5.,
          8.,  -5.,  -5.,  -5., -10.,   4.,   8.,  -4.,  -7.,   4.,  -6.,
         -9.,  -8.,  -5.,   4.,  -1.,  -2.,  -7., -10.,  -6.,  -8.,  -6.,
          3.,   1.,   6.,   9.,   6.,  -8.,  -3.],
       [  9.,  -6.,  -2., -10.,   2.,  -8.,   8.,  -7.,  -5.,   8., -10.,
          4.,  -5.,   9.,   7.,   9.,  -2.,  -9.,  -5.,  -2.,   9.,   0.,
         -6.,   2.,   4.,   6.,  -7.,  -4.,  -5.,  -7.,  -8.,  -8.,   8.,
         -7.,  -1.,  -5.,   0.,  -7.,   7.,  -6.],
       [ -1.,  -3.,   1.,   8.,   4.,   7.,  -1.,  -8.,  -4.,   6.,   9.,
          5., -10.,   4.,  -4.,   5.,  -2.,   0.,   3.,   4.,   3.,  -5.,
         -2.,   7.,   4.,  -4.,  -9.,   9.,  -6.,  -5.,  -8.,   4., -10.,
         -6.,   3.,   0.,   6., -10.,   4.,   3.],
       [  8.,   8.,  -5.,   0.,  -7.,   5.,  -6.,  -8.,   2.,  -3.,  -5.,
          5.,   0.,   6., -10.,   3.,  -4.,   1.,  -8.,  -9.,   6.,  -5.,
          5., -10.,   1.,   0.,   3.,   5.,   2.,  -9.,  -6.,   9.,   7.,
          9., -10.,   4.,  -4., -10.,  -5.,   1.],
       [  8.,   3.,  -5.,  -2.,  -8.,  -6.,   6.,  -7.,   8.,   1.,  -8.,
          0.,  -2.,   3.,  -6.,   0., -10.,   6.,  -8.,  -2.,  -5.,   4.,
         -1.,  -9.,  -7.,   3.,  -1.,  -4.,  -1., -10.,  -3.,  -7.,  -3.,
          4.,  -7.,  -6.,  -1.,   9.,  -3.,   2.],
       [  8.,   7.,   6.,  -5.,  -3.,   0.,   1.,  -2.,   0.,  -3.,   9.,
         -8.,   5.,   1.,   1.,   1.,  -5.,   4.,  -4.,   0.,  -4.,  -3.,
          7., -10.,   3.,   6.,   4.,   5.,   2.,  -7.,   0.,  -3.,  -5.,
          2.,  -6.,   4.,   5.,   8.,  -1.,  -3.],
       [  8.,  -9.,  -4.,   8.,  -2.,   9.,   5.,   5.,  -3.,  -4.,   0.,
         -5.,   5.,  -2., -10.,  -4.,  -3.,   5.,   8.,   6.,  -2.,  -2.,
         -1.,  -8.,  -5.,  -9.,   1.,  -1.,   5.,   6.,   4.,   9.,  -5.,
          6.,  -2.,   7.,  -7.,  -9.,   4.,   2.],
       [  2.,   4.,   6.,   2.,   6.,  -6.,  -2.,   5.,   8.,   1.,   3.,
          8.,   6.,   9.,  -3.,  -1.,   4.,   7.,  -5.,   7.,   0., -10.,
          9.,  -6.,  -4.,  -7.,   1.,  -2.,  -2.,   3.,  -1.,   2.,   5.,
          8.,   4.,  -9.,   1.,  -4.,   9.,   6.],
       [ -8.,  -5.,   9.,   3.,   9., -10.,  -8.,   3.,  -8.,   0.,  -4.,
         -8.,  -3.,  -4.,  -3.,   0.,   8.,   3., -10.,   7.,   7.,  -3.,
          8.,   4.,  -3.,   9.,   3.,   7.,   2.,   7.,  -8.,  -3.,  -4.,
         -7.,   3.,  -9., -10.,   2.,   5.,   7.],
       [  5.,  -7.,  -8.,   6.,  -8.,   1.,  -8.,   4.,   2.,   6.,  -6.,
         -5.,   4.,  -1.,   3.,  -8.,  -3.,   6.,   5.,  -5.,   1.,  -7.,
          8., -10.,   8.,   1.,   3.,   7.,   2.,   2.,  -1.,   1.,  -3.,
          7.,   1.,   6.,  -6.,   0.,  -9.,   6.]], dtype=float32)>, 'per_arm': <tf.Tensor: shape=(20, 70, 50), dtype=float32, numpy=
array([[[ 5., -6.,  4., ..., -3.,  3.,  4.],
        [-5., -6., -4., ...,  3.,  4., -4.],
        [ 1., -1.,  5., ..., -1., -3.,  1.],
        ...,
        [ 3.,  3., -5., ...,  4.,  4.,  0.],
        [ 5.,  1., -3., ..., -2., -2., -3.],
        [-6.,  4.,  2., ...,  4.,  5., -5.]],

       [[-5., -3.,  1., ..., -2., -1.,  1.],
        [ 1.,  4., -1., ..., -1., -4., -4.],
        [ 4., -6.,  5., ...,  2., -2.,  4.],
        ...,
        [ 0.,  4., -4., ..., -1., -3.,  1.],
        [ 3.,  4.,  5., ..., -5., -2., -2.],
        [ 0.,  4., -3., ...,  5.,  1.,  3.]],

       [[-2., -6., -6., ..., -6.,  1., -5.],
        [ 4.,  5.,  5., ...,  1.,  4., -4.],
        [ 0.,  0., -3., ..., -5.,  0., -2.],
        ...,
        [-3., -1.,  4., ...,  5., -2.,  5.],
        [-3., -6., -2., ...,  3.,  1., -5.],
        [ 5., -3., -5., ..., -4.,  4., -5.]],

       ...,

       [[ 4.,  3.,  0., ...,  1., -6.,  4.],
        [-5., -3.,  5., ...,  0., -1., -5.],
        [ 0.,  4.,  3., ..., -2.,  1., -3.],
        ...,
        [-5., -2., -5., ..., -5., -5., -2.],
        [-2.,  5.,  4., ..., -2., -2.,  2.],
        [-1., -4.,  4., ..., -5.,  2., -3.]],

       [[-1., -4.,  4., ..., -3., -5.,  4.],
        [-4., -6., -2., ..., -1., -6.,  0.],
        [ 0.,  0.,  5., ...,  4., -4.,  0.],
        ...,
        [ 2.,  3.,  5., ..., -6., -5.,  5.],
        [-5., -5.,  2., ...,  0.,  4., -2.],
        [ 4., -5., -4., ..., -5., -5., -1.]],

       [[ 3.,  0.,  2., ...,  2.,  1., -3.],
        [-5., -4.,  3., ..., -6.,  0., -2.],
        [-4., -5.,  3., ..., -6., -3.,  0.],
        ...,
        [-6., -6.,  4., ..., -1., -5., -2.],
        [-4.,  3., -1., ...,  1.,  4.,  4.],
        [ 5.,  2.,  2., ..., -3.,  1., -4.]]], dtype=float32)>}

Rewards after taking an action:  tf.Tensor(
[-130.17787    344.98013    371.39893     75.433975   396.35742
 -176.46881     56.62174   -158.03278    491.3239    -156.10696
   -1.0527252 -264.42285     22.356699  -395.89832    125.951546
  142.99467   -322.3012     -24.547596  -159.47539    -44.123775 ], shape=(20,), dtype=float32)

हम देखते हैं कि अवलोकन युक्ति दो तत्वों वाला एक शब्दकोश है:

  • कुंजी के साथ एक 'global' : इस वैश्विक संदर्भ हिस्सा है, पैरामीटर मिलान आकार के साथ GLOBAL_DIM
  • कुंजी के साथ एक 'per_arm' : इस प्रति हाथ के संदर्भ में होता है, और इसके आकार है [NUM_ACTIONS, PER_ARM_DIM] । यह हिस्सा हर हाथ के लिए एक समय चरण में बांह सुविधाओं के लिए प्लेसहोल्डर है।

लिनयूसीबी एजेंट

LinUCB एजेंट समान रूप से नामित बैंडिट एल्गोरिथम को लागू करता है, जो रैखिक इनाम फ़ंक्शन के पैरामीटर का अनुमान लगाता है, जबकि अनुमान के आसपास एक आत्मविश्वास दीर्घवृत्त भी बनाए रखता है। एजेंट उस हाथ को चुनता है जिसमें उच्चतम अनुमानित अपेक्षित इनाम होता है, यह मानते हुए कि पैरामीटर आत्मविश्वास दीर्घवृत्त के भीतर है।

एक एजेंट बनाने के लिए अवलोकन और कार्रवाई विनिर्देश के ज्ञान की आवश्यकता होती है। जब एजेंट को परिभाषित करने के लिए, हम सेट बूलियन पैरामीटर accepts_per_arm_features करने के लिए सेट True

observation_spec = per_arm_tf_env.observation_spec()
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    dtype=tf.int32, shape=(), minimum=0, maximum=NUM_ACTIONS - 1)

agent = lin_ucb_agent.LinearUCBAgent(time_step_spec=time_step_spec,
                                     action_spec=action_spec,
                                     accepts_per_arm_features=True)

प्रशिक्षण डेटा का प्रवाह

यह खंड यांत्रिकी में एक झलक देता है कि नीति से प्रशिक्षण तक प्रति-हाथ की विशेषताएं कैसे जाती हैं। बेझिझक अगले खंड (रिग्रेट मेट्रिक को परिभाषित करना) पर जाएं और यदि दिलचस्पी हो तो बाद में यहां वापस आएं।

सबसे पहले, आइए एजेंट में डेटा विनिर्देश पर एक नज़र डालें। training_data_spec एजेंट निर्दिष्ट क्या तत्वों और संरचना प्रशिक्षण डेटा होना चाहिए की विशेषता।

print('training data spec: ', agent.training_data_spec)
training data spec:  Trajectory(
{'action': BoundedTensorSpec(shape=(), dtype=tf.int32, name=None, minimum=array(0, dtype=int32), maximum=array(69, dtype=int32)),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': {'global': TensorSpec(shape=(40,), dtype=tf.float32, name=None)},
 'policy_info': PerArmPolicyInfo(log_probability=(), predicted_rewards_mean=(), multiobjective_scalarized_predicted_rewards_mean=(), predicted_rewards_optimistic=(), predicted_rewards_sampled=(), bandit_policy_type=(), chosen_arm_features=TensorSpec(shape=(50,), dtype=tf.float32, name=None)),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')})

अगर हम करने के लिए करीब से देख है observation कल्पना का हिस्सा है, हम देखते हैं कि यह प्रति हाथ विशेषताएं नहीं होती!

print('observation spec in training: ', agent.training_data_spec.observation)
observation spec in training:  {'global': TensorSpec(shape=(40,), dtype=tf.float32, name=None)}

प्रति-हाथ सुविधाओं का क्या हुआ? इस सवाल का जवाब करने के लिए, पहले हम ध्यान दें कि जब LinUCB एजेंट गाड़ियों, यह सभी हथियारों का प्रति हाथ सुविधाओं की जरूरत नहीं है, यह केवल चुने हुए हाथ की उन की जरूरत है। इसलिए, यह आकार टेन्सर ड्रॉप करने समझ में आता है [BATCH_SIZE, NUM_ACTIONS, PER_ARM_DIM] , के रूप में यह बहुत बेकार है, खासकर यदि क्रियाओं की संख्या बड़ी है।

लेकिन फिर भी, चुने हुए हाथ की प्रति-हाथ की विशेषताएं कहीं न कहीं होनी चाहिए! यह अंत करने के लिए, हम यह सुनिश्चित कर लें कि LinUCB नीति भंडार के भीतर चुना बांह की सुविधाओं policy_info प्रशिक्षण डेटा के क्षेत्र:

print('chosen arm features: ', agent.training_data_spec.policy_info.chosen_arm_features)
chosen arm features:  TensorSpec(shape=(50,), dtype=tf.float32, name=None)

हम आकार कि से देख chosen_arm_features क्षेत्र एक हाथ की केवल फीचर वेक्टर है, और वह चुना हाथ हो जाएगा। ध्यान दें कि policy_info , और इसके साथ chosen_arm_features के रूप में हम प्रशिक्षण डेटा कल्पना निरीक्षण से देखा था,, प्रशिक्षण डेटा का हिस्सा है, और इस प्रकार यह प्रशिक्षण समय पर उपलब्ध है।

रिग्रेट मेट्रिक को परिभाषित करना

प्रशिक्षण लूप शुरू करने से पहले, हम कुछ उपयोगिता कार्यों को परिभाषित करते हैं जो हमारे एजेंट के अफसोस की गणना करने में मदद करते हैं। ये फ़ंक्शन क्रियाओं के सेट (उनके हाथ की विशेषताओं द्वारा दिए गए) और एजेंट से छिपे हुए रैखिक पैरामीटर को देखते हुए इष्टतम अपेक्षित इनाम को निर्धारित करने में मदद करते हैं।

def _all_rewards(observation, hidden_param):
  """Outputs rewards for all actions, given an observation."""
  hidden_param = tf.cast(hidden_param, dtype=tf.float32)
  global_obs = observation['global']
  per_arm_obs = observation['per_arm']
  num_actions = tf.shape(per_arm_obs)[1]
  tiled_global = tf.tile(
      tf.expand_dims(global_obs, axis=1), [1, num_actions, 1])
  concatenated = tf.concat([tiled_global, per_arm_obs], axis=-1)
  rewards = tf.linalg.matvec(concatenated, hidden_param)
  return rewards

def optimal_reward(observation):
  """Outputs the maximum expected reward for every element in the batch."""
  return tf.reduce_max(_all_rewards(observation, reward_param), axis=1)

regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward)

अब हम अपना दस्यु प्रशिक्षण लूप शुरू करने के लिए पूरी तरह तैयार हैं। नीचे दिया गया ड्राइवर नीति का उपयोग करके कार्यों को चुनने, रिप्ले बफर में चुने गए कार्यों के पुरस्कारों को संग्रहीत करने, पूर्वनिर्धारित खेद मीट्रिक की गणना करने और एजेंट के प्रशिक्षण चरण को निष्पादित करने का ध्यान रखता है।

num_iterations = 20 # @param
steps_per_loop = 1 # @param

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.policy.trajectory_spec,
    batch_size=BATCH_SIZE,
    max_length=steps_per_loop)

observers = [replay_buffer.add_batch, regret_metric]

driver = dynamic_step_driver.DynamicStepDriver(
    env=per_arm_tf_env,
    policy=agent.collect_policy,
    num_steps=steps_per_loop * BATCH_SIZE,
    observers=observers)

regret_values = []

for _ in range(num_iterations):
  driver.run()
  loss_info = agent.train(replay_buffer.gather_all())
  replay_buffer.clear()
  regret_values.append(regret_metric.result())
WARNING:tensorflow:From /tmp/ipykernel_12052/1190294793.py:21: ReplayBuffer.gather_all (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=True)` instead.

अब परिणाम देखते हैं। यदि हमने सब कुछ ठीक किया, तो एजेंट लीनियर रिवॉर्ड फंक्शन का अच्छी तरह से अनुमान लगाने में सक्षम होता है, और इस प्रकार पॉलिसी ऐसी कार्रवाइयाँ चुन सकती है, जिनका अपेक्षित इनाम इष्टतम के करीब होता है। यह हमारे ऊपर परिभाषित खेद मीट्रिक द्वारा इंगित किया गया है, जो नीचे जाता है और शून्य तक पहुंचता है।

plt.plot(regret_values)
plt.title('Regret of LinUCB on the Linear per-arm environment')
plt.xlabel('Number of Iterations')
_ = plt.ylabel('Average Regret')

पीएनजी

आगे क्या होगा?

ऊपर के उदाहरण है कार्यान्वित हमारे codebase में जहां रूप में अच्छी तरह सहित अन्य एजेंटों से चुन सकते हैं, तंत्रिका एप्सिलॉन-लालची एजेंट