diff --git a/.gitignore b/.gitignore index ba50ec8..6baa3ae 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,112 @@ +# Personal *.DS_Store - -# log and output files *.hlt *.log - -# developer environment .idea/ + +### Python template __pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.benchmarks/ +public/halite +public/models/variables/* +src/core/Halite.o +src/main.o +src/networking/Networking.o +visualize/other_hlt/ diff --git a/.gitmessage b/.gitmessage new file mode 100644 index 0000000..6194a76 --- /dev/null +++ b/.gitmessage @@ -0,0 +1,16 @@ + +Why this change was necessary: + +* + +This change addresses the need by: + +* + +Potential side-effects: + +* + +# 50-character subject line +# +# 72-character wrapped longer description. diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..ed97dc9 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,325 @@ +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Profiled execution. +profile=no + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + + +[MESSAGES CONTROL] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable=indexing-exception,old-raise-syntax + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". 
If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=arguments-differ,len-as-condition,invalid-unary-operand-type,design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager + + +# Set the cache size for astng objects. +cache-size=500 + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=yes + +# Tells whether to display a full report or only the messages +reports=yes + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Add a comment according to your evaluation note. This is used by the global +# evaluation report (RP0004). +comment=yes + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# When zope mode is activated, add a predefined set of Zope acquired attributes +# to generated-members. +zope=no + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=yes + +# A regular expression matching the beginning of the name of dummy variables +# (i.e. not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# List of builtins function names that should not be used, separated by a comma +bad-functions=apply,input,reduce + + +# Disable the report(s) with the given id(s). +# All non-Google reports are disabled by default. 
+disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 + +# Regular expression which should only match correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression which should only match correct module level names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression which should only match correct function names +function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct method names +method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct instance attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct attribute names in class +# bodies +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main) + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=120 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=y + +# List of optional constructs for which whitespace checking is disabled +no-space-check= + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes= + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[CLASSES] + +# List of interface methods to ignore, separated by a comma. 
This is used for +# instance to not check methods defines in Zope's Interface base class. +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls,class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception,StandardError,BaseException + + +[AST] + +# Maximum line length for lambdas +short-func-length=1 + +# List of module members that should be marked as deprecated. +# All of the string functions are listed in 4.1.4 Deprecated string functions +# in the Python 2.4 docs. +deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc + + +[DOCSTRING] + +# List of exceptions that do not need to be mentioned in the Raises section of +# a docstring. +ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError + + + +[TOKENS] + +# Number of spaces of indent required when the last token on the preceding line +# is an open (, [, or {. +indent-after-paren=4 + + +[Louis LINES] + +# Regexp for a proper copyright notice. +copyright=Copyright \d{4} Louis R?mus\. +All [Rr]ights [Rr]eserved\. diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..e362eab --- /dev/null +++ b/.travis.yml @@ -0,0 +1,38 @@ +sudo: false + +language: python + +python: + - 3.5 + +addons: + apt: + sources: + - ubuntu-toolchain-r-test + packages: + - g++-4.9 + +env: + global: + - CXX=g++-4.9 + +install: + - pip install -r requirements.txt + - make + +script: + # Tests + - python -m unittest discover -v + - find . 
-iname "*.py" | xargs pylint + + # Coverage checks + - python networking/start_game.py -sp "OpponentBot.py" + - py.test --cov=train tests/ + +after_success: + coveralls + +notifications: + email: + on_success: change + on_failure: change diff --git a/Makefile b/Makefile index 5ecfaae..47f6a4f 100644 --- a/Makefile +++ b/Makefile @@ -4,4 +4,34 @@ all: .PHONY: clean clean: - rm public/halite; cd src/; make clean; cd ..; \ No newline at end of file + rm *.hlt *.log public/halite; cd src/; make clean; cd ..; + +.PHONY: sync-nefeli +sync-nefeli: + rsync -a . mehlman@nefeli.math-info.univ-paris5.fr:/home/mehlman/Halite-Python-RL/ --delete + +.PHONY: get-nefeli +get-nefeli: + rsync -a --exclude 'public/halite' --exclude '*.o' mehlman@nefeli.math-info.univ-paris5.fr:/home/mehlman/Halite-Python-RL/ . #--delete + +.PHONY: sync-solon +sync-solon: + rsync -a --exclude 'public/halite' --exclude '*.o' . solon:/home/mehlman/Halite-Python-RL/ --delete + +.PHONY: get-solon +get-solon: + rsync -a --exclude 'public/halite' --exclude '*.o' solon:/home/mehlman/Halite-Python-RL/ . #--delete + +.PHONY: clear-agent +clear-agent: + rm -r './public/models/variables/$(AGENT)' + +.PHONY: server +server: + cd visualize;export FLASK_APP=visualize.py;flask run + +.PHONY: debug-server +debug-server: + cd visualize;FLASK_APP=visualize.py FLASK_DEBUG=1 python -m flask run + +#scp mehlman@nefeli.math-info.univ-paris5.fr:/home/mehlman/Halite-Python-RL/public/models/variables/vanilla-2 public/models/variables/vanilla-2 \ No newline at end of file diff --git a/README.md b/README.md index 3fd2a7a..ff705a0 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![Build Status](https://travis-ci.org/Edouard360/Halite-Python-RL.svg?branch=master)](https://travis-ci.org/Edouard360/Halite-Python-RL) [![Coverage Status](https://coveralls.io/repos/github/Edouard360/Halite-Python-RL/badge.svg?branch=master)](https://coveralls.io/github/Edouard360/Halite-Python-RL?branch=master) + # Halite-Python-RL

Halite Challenge Overview
@@ -25,91 +27,6 @@ Indeed, unlike chess or go, in the Halite turn-based game, we can do **multiple In this repository, we will mainly explore the solutions based on **Neural Networks**, and will start by a very simple MLP. This is inspired from a tutorial on Reinforcement Learning agent. +## Documentation & Articles -## Detailing the approach step by step - -We will explain the rules of the game in this section, along with our strategy for training the agent. To start simple, we will try to conquer a 3*3 map, where we are the only player (cf below). As we can see, this trained agent is already pretty efficient at conquering the map. - -
-

-[image: conquermap]

- -### How does it start ? - -Each player starts with a single square of the map, and can either decide: - -- To **stay** in order to increase the strength of its square (action = STILL). - -- To **move** (/conquer) a neighboring square (action = NORTH, SOUTH, EAST, WEST). - -Conquering is only possible once the square's strength is high enough, such that a wise bot would first wait for its strength to increase before attacking any adjacent square, since **squares don't produce when they attack**. - -> To conquer a square, we must move in its direction having a strictly superior strength (action = NORTH, SOUTH, EAST, WEST) - -
- -The white numbers on the map below represent the current strength of the squares. On the left is just a snap of the initial state of the game. On the right you can see the strength of the blue square increment over time. This is because our agent decides to stay (action = STILL). - -

-[image: the strength map]

- -The increase in production is computed according to a fixed production map. In our example, we can see the blue square's strength increases by 4 at each turn. Each square has a different production speed, as represented by the white numbers below the squares. (cf below). On the left is also a snap of the initial game, whereas the game's dynamic is on the right. - -

-[image: production map]

- -This production map production is invariant over time, and is an information we should use to train our agent. Since we are interesting in maximizing our production, we should intuitively train our agent to target the squares with a high production rate. On the other hand, we should also consider the strength map, since squares with low strength are easier to conquer. - -

- -

- -### The Agent - -We will teach our agent with: - -- The successive **Game States**. -- The agent's **Moves** (initially random). -- The corresponding **Reward** for each Move (that we have to compute). - -For now, the Game State is a (3 * 3) * 3 matrix (width * height) * n_features, the features being: - -- The **Strength** of the Square -- The **Production** of the Square -- The **Owner** of the Square - -

-[image: matrix]

- -### The Reward - -
-As for the reward, we focus on the production. Since each square being conquered increase the total production of our land, the action leading to the conquest is rewarded according to the production rate of the conquered square. This strategy will best reward the conquest of highly productive squares. - -

- -

- -### Current results - -We train over 500 games and get significant improvements of the total reward obtained over time. - -

-[image: screen shot 2017-09-26 at 17 34 04]

- -On the right, you can observe the behaviour of the original, untrained bot, with random actions, whereas on the right, you can see the trained bot. - -

- - -

- -#### Isn't that amazing ? \ No newline at end of file +To get started, blog articles and documentation are available at this page. \ No newline at end of file diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..886d01a --- /dev/null +++ b/docs/README.md @@ -0,0 +1,3 @@ +# Documentation + +To see the docs, click [here](https://edouard360.github.io/Halite-Python-RL/). diff --git a/docs/_config.yml b/docs/_config.yml new file mode 100644 index 0000000..291b4ba --- /dev/null +++ b/docs/_config.yml @@ -0,0 +1,14 @@ +# Setup +theme: jekyll-theme-cayman + +title: Halite Challenge +tagline: A data science project + +author: + name: Edouard Mehlman + url: edouard.mehlman@polytechnique.edu + +collections: + documentation: + output: true + permalink: /:collection/:name # This is just display \ No newline at end of file diff --git a/docs/_documentation/first_steps.md b/docs/_documentation/first_steps.md new file mode 100644 index 0000000..7bbfe13 --- /dev/null +++ b/docs/_documentation/first_steps.md @@ -0,0 +1,44 @@ +--- +layout: default +title: "First Steps" + +--- + + +## Run the Bot + +In your console: + +`cd networking; python start_game.py` + +In another tab + +`cd public; python MyBot.py` + +This will run 1 game. Options can be added to starting the game, among which: + +`python start_game.py -g 5 -x 30 -z 50` + +Will run 5 games, of at most 30 turns, which at most squares of strength 50. + +All the options available for start_game might by listed (_with a clear description_) using the -h flag: + +`python start_game.py -h` + +## Visualize the Bot + +In your console: + +`cd visualize export FLASK_APP=visualize.py;flask run` + +Then either: + +Look at http://127.0.0.1:5000/performance.png for performance insights. + +Or at http://127.0.0.1:5000/ for games replay. + +## Working with PyCharm + +To run the Bot in Pycharm, you should provide a **mute** argument, since `MyBot.py` needs to know it's not on the Halite server, but running locally. + +Go to edit configuration and add the script argument `slave` (so that the bot knows it is in slave mode). \ No newline at end of file diff --git a/docs/_includes/center.css b/docs/_includes/center.css new file mode 100644 index 0000000..671898c --- /dev/null +++ b/docs/_includes/center.css @@ -0,0 +1,17 @@ +.list-unstyled { + padding-left: 0; + list-style: none; + } +.list-inline { + padding-left: 0; + margin-left: -5px; + list-style: none; +} +.list-inline > li { + display: inline-block; + padding-right: 5px; + padding-left: 5px; +} +.text-center { + text-align: center; +} \ No newline at end of file diff --git a/docs/_posts/2017-09-26-simple-approach.markdown b/docs/_posts/2017-09-26-simple-approach.markdown new file mode 100644 index 0000000..4107369 --- /dev/null +++ b/docs/_posts/2017-09-26-simple-approach.markdown @@ -0,0 +1,94 @@ +--- +layout: default +title: "A simple approach" +date: 2017-09-26 17:50:00 +categories: main +--- + +## Detailing the approach step by step + +We will explain the rules of the game in this section, along with our strategy for training the agent. To start simple, we will try to conquer a 3*3 map, where we are the only player (cf below). As we can see, this trained agent is already pretty efficient at conquering the map. + +
+[image: conquermap]
+ + +### How does it start ? + +Each player starts with a single square of the map, and can either decide: + +- To **stay** in order to increase the strength of its square (action = STILL). + +- To **move** (/conquer) a neighboring square (action = NORTH, SOUTH, EAST, WEST). + +Conquering is only possible once the square's strength is high enough, such that a wise bot would first wait for its strength to increase before attacking any adjacent square, since **squares don't produce when they attack**. + +> To conquer a square, we must move in its direction having a strictly superior strength (action = NORTH, SOUTH, EAST, WEST) + +
+
+The white numbers on the map below represent the current strength of the squares. On the left is just a snapshot of the initial state of the game. On the right you can see the strength of the blue square increase over time. This is because our agent decides to stay (action = STILL).
+

+[image: the strength map]

+
+The increase in strength is computed according to a fixed production map. In our example, we can see the blue square's strength increases by 4 at each turn. Each square has a different production speed, as represented by the white numbers below the squares (cf. below). On the left is again a snapshot of the initial game, whereas the game's dynamics are shown on the right.
+

+[image: production map]

+
+This production map is invariant over time, and is information we should use to train our agent. Since we are interested in maximizing our production, we should intuitively train our agent to target the squares with a high production rate. On the other hand, we should also consider the strength map, since squares with low strength are easier to conquer.
+

+ +

+ +### The Agent + +We will teach our agent with: + +- The successive **Game States**. +- The agent's **Moves** (initially random). +- The corresponding **Reward** for each Move (that we have to compute). + +For now, the Game State is a (3 * 3) * 3 matrix (width * height) * n_features, the features being: + +- The **Strength** of the Square +- The **Production** of the Square +- The **Owner** of the Square + +

+[image: matrix]
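To make the (width * height) * n_features representation concrete, here is a minimal NumPy sketch of such a state tensor. The arrays and values are invented for illustration; the repository builds its actual states through `public/state/state.py`.

```python
import numpy as np

# Toy 3x3 map: one plane per feature (values are made up for the example).
strength = np.array([[10, 20, 5],
                     [15, 40, 9],
                     [3, 12, 7]])
production = np.array([[1, 4, 2],
                       [3, 4, 1],
                       [2, 2, 3]])
owner = np.array([[0, 0, 0],
                  [0, 1, 0],   # we own the central square
                  [0, 0, 0]])

# Stack the three feature planes into a (height, width, n_features) state tensor.
game_state = np.stack([strength, production, owner], axis=-1)
print(game_state.shape)  # (3, 3, 3)
```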

+ +### The Reward + +
+
+As for the reward, we focus on the production. Since each conquered square increases the total production of our land, the action leading to the conquest is rewarded according to the production rate of the conquered square. This strategy will best reward the conquest of highly productive squares.
+

+ +

+
+### Current results
+
+We train over 500 games and see a significant improvement in the total reward obtained over time.
+

+[image: screen shot 2017-09-26 at 17 34 04]

+
+On the left, you can observe the behaviour of the original, untrained bot, with random actions, whereas on the right, you can see the trained bot.
+

+ + +
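The training loop behind these results follows the classic policy-gradient recipe: the probability of each move the agent actually played is pushed up in proportion to the reward that move received. Below is a minimal, framework-free sketch of one such update for a linear softmax policy over the 27 local features and 5 moves discussed above; it is a simplified illustration, not the TensorFlow `VanillaAgent` added in this change.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def reinforce_update(weights, state, action, reward, lr=1e-2):
    """One REINFORCE step: raise the log-probability of `action` in proportion to `reward`."""
    probs = softmax(state @ weights)      # 5-way distribution over moves
    grad_log = -np.outer(state, probs)    # gradient of log pi(action|state) w.r.t. weights...
    grad_log[:, action] += state          # ...for a linear softmax policy
    return weights + lr * reward * grad_log

rng = np.random.default_rng(0)
weights = np.zeros((27, 5))               # 3*3 local view with 3 features -> 5 moves
state = rng.random(27)
weights = reinforce_update(weights, state, action=2, reward=1.5)
print(weights.shape)
```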

+ diff --git a/docs/_posts/2017-10-05-reward-importance.markdown b/docs/_posts/2017-10-05-reward-importance.markdown new file mode 100644 index 0000000..5b0eb80 --- /dev/null +++ b/docs/_posts/2017-10-05-reward-importance.markdown @@ -0,0 +1,160 @@ +--- +layout: default +title: "The reward importance" +date: 2017-10-05 16:30:00 +categories: main +--- + + + +# The reward importance + +## The impact of the reward choice + +We will see how important it is to set a proper reward, playing with two hyperparameters, ie: + +* The **discount factor** - for the discounted rewards +* The **reward expression**, as a function of the **production** - the hyperparameter being the function itself + +The results of a well-chosen reward can lead to significant improvements ! Our bot, only trained for conquering the map as fast as possible, now systematically wins against the trained bot (that applies heuristics). This is best exemplified by the 3 games below. + +

+ + + +

+
+## Devising the reward
+
+The reward is complex to devise since we take **multiple actions** at each turn, and we have to compute the reward for **each of these individual actions**.
+
+Below is an insightful illustration to understand the process. The numbers written on each square are **the rewards associated with the current action of the square**. Notice that, at each turn, **these rewards are different for each square**, and that, when a square is about to conquer an adjacent square, **the reward for its action is high**.
+
+It is even higher when the square is more productive.
+
+> HINT: highly productive squares have a brighter background, and poorly productive ones have a darker one.
+
+Observe how the rewards evolve over time: there is already a **discount factor** applied, because we encourage (/reward) actions **that will eventually lead to a reward over time**. Indeed, the `STILL` squares are earning rewards!
+
+ +
+[image: reward1 (Discount = 0.6)]
+
+## Understanding the discount factor
+
+To better understand the discount factor, let's push it to its **limits**, and look at the corresponding reward for the exact same game.
+
+* On the left, notice that when the discount factor is set to 0, only the moves that conquer a square are rewarded. This means that the `STILL` action for a square never gets rewarded - which is undesirable.
+* At the other end, with a discount rate of 0.9, the rewards tend to be **overall much higher**. Yet this excessively uniform pattern **doesn't favor much the actions that are actually good**. Too many actions are rewarded, even though they were potentially not efficient.
+
+As expected, these reward strategies fare badly compared to a more balanced discount factor. See the comparison below, followed by a short sketch of how discounted rewards can be computed.
+
+[image: reward2 (Discount = 0.0)]
+[image: reward3 (Discount = 0.9)]
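For reference, the discounted reward itself is a standard computation: each action earns its own raw reward plus the discounted sum of the rewards that follow it. A minimal sketch with made-up numbers (the repository's own implementation, apparently in `train/reward/reward.py`, may differ in details):

```python
import numpy as np

def discounted_rewards(raw_rewards, discount=0.6):
    """Propagate each raw reward backwards, weighted by the discount factor."""
    out = np.zeros(len(raw_rewards))
    running = 0.0
    for t in reversed(range(len(raw_rewards))):
        running = raw_rewards[t] + discount * running
        out[t] = running
    return out

# A square that stays STILL for three turns, then conquers a square of production 3
# (raw reward 3 ** 4 = 81 with the power-4 reward used so far):
print(discounted_rewards([0.0, 0.0, 0.0, 81.0], discount=0.6))
# -> [17.496, 29.16, 48.6, 81.0]: the earlier STILL moves also earn part of the final reward.
```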
+
+## Variation of the raw reward
+
+Each reward is computed according to the production of the conquered square, and then "backtracked" to the actions that led to this reward.
+
+But should this reward be proportional to the production ? Wouldn't it be better to make it **proportional to the square of the production** ? Or even to a higher power ?
+
+Indeed, we want to strongly encourage our bot to conquer highly productive squares, and a way to enforce this learning efficiently is by **giving significantly greater rewards for the highly productive squares**.
+
+All the examples before had a reward proportional to the power 4 of the production. But let's now look at a **power 2** and a **linear** reward.
+
+[image: reward4 (Power: 2, Discount = 0.6)]
+[image: reward5 (Power: 1, Discount = 0.6)]
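To see why the exponent matters, compare the reward for conquering a production-3 square with that for a production-2 square under the three settings; the productions here are purely illustrative:

```python
# Ratio between the rewards of conquering a production-3 vs a production-2 square.
for power in (1, 2, 4):
    print(power, (3 ** power) / (2 ** power))
# power 1 -> 1.5, power 2 -> 2.25, power 4 -> 5.0625
```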
+
+### The ratio changes
+
+Let's **extract one frame of the above** (*see gifs below*). Let's not focus on the absolute value of the rewards, but rather on **the ratio between the rewards of different actions**.
+
+The two actions that we compare here are:
+
+* The square on the top left that conquers its left neighbour (1)
+* The square on the bottom right that conquers its above neighbour (2)
+
+We would want action (1) to be better rewarded than action (2). Indeed, look at the background color of the conquered square. The conquered square in (1) is **brighter** than the conquered square in (2) and therefore **more productive**.
+
+In all cases, `reward(1) > reward(2)`. But if we look at the ratio (*see gifs below*), we have, from left to right:
+
+* 0.65/0.24 ≈ 2.7
+* 0.93/0.49 ≈ 1.9
+* 1.1/0.7 ≈ 1.6
+
+This illustrates that the **higher the exponent** for the reward, **the greater the difference between the rewards** of good and very good actions.
+
+[image: reward1-bis (Power: 4, Discount = 0.6)]
+[image: reward4-bis (Power: 2, Discount = 0.6)]
+[image: reward5-bis (Power: 1, Discount = 0.6)]
+
+## The performance
+
+Depending on the choice of reward, the training can be much slower, or even converge to a worse equilibrium. We should keep this in mind as we explore new strategies in the future.
+
+[image: performance]
+ +## Scaling up + +What about the results on a larger map ? + +Our trained Bot **still wins** all the games against the OpponentBot when we increase the map size. + +
+[images: games on a larger map]
+ +However, we notice that: + +* This solution is too long to compute for each square individually + * Maybe we should only apply it for **the squares on the border** (and find another strategy for the squares in the center) + * We could gain time if we made **only one call to the tensorflow session**. Besides, the extraction of the local game states would probably be faster on the tensorflow side. +* Squares in the middle have a **suboptimal behaviour** - seems like they tend to move to the left systematically. diff --git a/docs/_posts/2017-10-08-dijkstra.markdown b/docs/_posts/2017-10-08-dijkstra.markdown new file mode 100644 index 0000000..33627ec --- /dev/null +++ b/docs/_posts/2017-10-08-dijkstra.markdown @@ -0,0 +1,42 @@ +--- +layout: default +title: "The dijkstra algorithm" +date: 2017-10-08 16:30:00 +categories: main +--- + + + +# Expansion at the border + +As detailed in the previous blog articles, we train jointly the individual agents at the border of the map. As you can see below, we obtain +an agent that perform significantly well at small scale. Indeed, it has learnt to conquer the **highly productive squares** (the bright ones) **in priority**. + +

+ +

+
+To assess the confidence of our agent, we can look at the **entropy** of the learnt policy. For convenience, we implemented an interface that displays the **softmax probabilities** at time t as you click on the agent.
+We can see the 5 NORTH-EAST-SOUTH-WEST-STILL probabilities associated with each move, and how the agent greedily selects them.
+

+ +
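The entropy can be computed directly from the five displayed softmax probabilities: a confident, greedy agent has low entropy, while a hesitant one approaches the uniform maximum of log 5. A small sketch with made-up probability vectors:

```python
import numpy as np

def entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    probs = np.asarray(probs, dtype=float)
    return -np.sum(probs * np.log(probs + 1e-12))

confident = [0.95, 0.02, 0.01, 0.01, 0.01]  # NORTH, EAST, SOUTH, WEST, STILL
hesitant = [0.2, 0.2, 0.2, 0.2, 0.2]
print(entropy(confident), entropy(hesitant))  # ~0.27 vs ~1.61 (= log 5)
```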

+
+# The Dijkstra algorithm
+
+## The power of Dijkstra
+
+So far, we have dealt with the problem of border squares, learning with a neural network.
+
+The Dijkstra algorithm, which runs here in linear time, gives us the ability to handle the squares in the middle of the map:
+
+Now, only the border's behaviour is determined by our trained policy. We adopt a **deterministic strategy for the interior of the map**.
+

+ + +
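The routing of interior squares is essentially a textbook Dijkstra run on the grid graph: compute, for every interior square, its distance to the closest target square and the neighbour to move through to get there. The standalone sketch below shows the principle on a toy graph; it is only an illustration, not the repository's `public/util/dijkstra.py`. With unit edge costs, this amounts to a breadth-first search, which is what makes the linear-time behaviour plausible.

```python
import heapq

def dijkstra(graph, sources):
    """Shortest distance from any source node to every node of `graph`.

    `graph` maps node -> {neighbor: edge_cost}; `sources` is an iterable of nodes.
    Returns (dist, parent) so a path can be walked back towards the sources.
    """
    dist, parent = {}, {}
    heap = [(0, s, None) for s in sources]
    while heap:
        d, node, prev = heapq.heappop(heap)
        if node in dist:
            continue
        dist[node], parent[node] = d, prev
        for neighbor, cost in graph.get(node, {}).items():
            if neighbor not in dist:
                heapq.heappush(heap, (d + cost, neighbor, node))
    return dist, parent

# Tiny 1x4 strip of squares: interior squares route their strength towards the border (node 0).
grid = {0: {1: 1}, 1: {0: 1, 2: 1}, 2: {1: 1, 3: 1}, 3: {2: 1}}
dist, parent = dijkstra(grid, sources=[0])
print(dist)    # {0: 0, 1: 1, 2: 2, 3: 3}
print(parent)  # {0: None, 1: 0, 2: 1, 3: 2}
```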

+ diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 0000000..e7a48ab --- /dev/null +++ b/docs/index.html @@ -0,0 +1,20 @@ +--- +layout: default +title: {{ site.name }} +--- + +
+Documentation
+
+Blog Posts
+
+{% for post in site.posts %}
+  {{ post.title }} ({{ post.date | date_to_string }})
+{% endfor %}
\ No newline at end of file diff --git a/networking/hlt_networking.py b/networking/hlt_networking.py index 5e79010..5433c67 100644 --- a/networking/hlt_networking.py +++ b/networking/hlt_networking.py @@ -1,41 +1,45 @@ +"""The HLT class to handle the connection""" import socket -from public.hlt import translate_cardinal, GameMap + +from public.hlt import GameMap, translate_cardinal class HLT: + """The HLT class to handle the connection""" + def __init__(self, port): - _connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - _connection.connect(('localhost', port)) + connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connection.connect(('localhost', port)) print('Connected to intermediary on port #' + str(port)) - self._connection = _connection + self.connection = connection def get_string(self): - newString = "" + new_string = "" buffer = '\0' while True: - buffer = self._connection.recv(1).decode('ascii') + buffer = self.connection.recv(1).decode('ascii') if buffer != '\n': - newString += str(buffer) + new_string += str(buffer) else: - return newString + return new_string - def sendString(self, s): + def send_string(self, s): s += '\n' - self._connection.sendall(bytes(s, 'ascii')) + self.connection.sendall(bytes(s, 'ascii')) def get_init(self): - myID = int(self.get_string()) + my_id = int(self.get_string()) game_map = GameMap(self.get_string(), self.get_string(), self.get_string()) - return myID, game_map + return my_id, game_map def send_init(self, name): - self.sendString(name) + self.send_string(name) def send_frame(self, moves): - self.sendString(' '.join( + self.send_string(' '.join( str(move.square.x) + ' ' + str(move.square.y) + ' ' + str(translate_cardinal(move.direction)) for move in moves)) - # - # if __name__ =="__main__": - # HLT(2000) + def send_frame_custom(self, moves): + self.send_string(' '.join( + str(x) + ' ' + str(y) + ' ' + str(translate_cardinal(direction)) for (x, y), direction in moves)) diff --git a/networking/kill.sh b/networking/kill.sh new file mode 100755 index 0000000..31f00df --- /dev/null +++ b/networking/kill.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +kill $(ps aux | grep python | grep -v start_game.py | grep $1| awk '{print $2}'); + +kill $(ps aux | grep python | grep -v pipe_socket_translator.py | grep $1| awk '{print $2}'); \ No newline at end of file diff --git a/networking/pipe_socket_translator.py b/networking/pipe_socket_translator.py index 1f324a7..2f11d16 100644 --- a/networking/pipe_socket_translator.py +++ b/networking/pipe_socket_translator.py @@ -1,12 +1,11 @@ +""" +To be launched by the Halite program as an intermediary, +in order to enable a pipe player to join. 
+""" import socket -import sys, traceback -import logging - -# logging.basicConfig(filename='example.log', level=logging.DEBUG) +import sys try: - # Connect - # logging.warning("connecting") socket_ = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) socket_.bind(('localhost', int(sys.argv[1]))) # This is where the port is selected @@ -14,50 +13,47 @@ connection, _ = socket_.accept() - # IO Functions - def sendStringPipe(toBeSent): - sys.stdout.write(toBeSent + '\n') + def send_string_pipe(to_be_sent): + sys.stdout.write(to_be_sent + '\n') sys.stdout.flush() - def getStringPipe(): - str = sys.stdin.readline().rstrip('\n') - return (str) + def get_string_pipe(): + str_pipe = sys.stdin.readline().rstrip('\n') + return str_pipe - def sendStringSocket(toBeSent): - global connection - toBeSent += '\n' - connection.sendall(bytes(toBeSent, 'ascii')) + def send_string_socket(to_be_sent): + to_be_sent += '\n' + connection.sendall(bytes(to_be_sent, 'ascii')) - def getStringSocket(): - global connection - newString = "" + def get_string_socket(): + new_string = "" buffer = '\0' while True: buffer = connection.recv(1).decode('ascii') if buffer != '\n': - newString += str(buffer) + new_string += str(buffer) else: - return newString + return new_string while True: # Handle Init IO - sendStringSocket(getStringPipe()) # Player ID - sendStringSocket(getStringPipe()) # Map Dimensions - sendStringSocket(getStringPipe()) # Productions - sendStringSocket(getStringPipe()) # Starting Map - sendStringPipe(getStringSocket()) # Player Name / Ready Response + send_string_socket(get_string_pipe()) # Player ID + send_string_socket(get_string_pipe()) # Map Dimensions + send_string_socket(get_string_pipe()) # Productions + send_string_socket(get_string_pipe()) # Starting Map + send_string_pipe(get_string_socket()) # Player Name / Ready Response # Run Frame Loop - while (getStringPipe() == 'Get map and play!'): # while True: - sendStringSocket('Get map and play!') - sendStringSocket(getStringPipe()) # Frame Map - sendStringPipe(getStringSocket()) # Move List - sendStringSocket('Stop playing!') + while get_string_pipe() == 'Get map and play!': # while True: + send_string_socket('Get map and play!') + send_string_socket(get_string_pipe()) # Frame Map + send_string_pipe(get_string_socket()) # Move List + send_string_socket('Stop playing!') -except Exception as e: +except ConnectionError as e: # logging.warning(traceback.format_exc()) pass diff --git a/networking/runGame.bat b/networking/runGame.bat index 7d9bd31..8b54593 100644 --- a/networking/runGame.bat +++ b/networking/runGame.bat @@ -1 +1 @@ -.\halite.exe -d "30 30" "python MyBot.py" "python RandomBot.py" +.\halite.exe -d "30 30" "python MyBot.py" "python RandomStrategy.py" diff --git a/networking/runGame.sh b/networking/runGame.sh index 275e549..61c4df4 100755 --- a/networking/runGame.sh +++ b/networking/runGame.sh @@ -1,7 +1,7 @@ #!/bin/bash if hash python3 2>/dev/null; then - ./halite -d "30 30" "python3 MyBot.py" "python3 RandomBot.py" + ./halite -d "30 30" "python3 MyBot.py" "python3 RandomStrategy.py" else - ./halite -d "30 30" "python MyBot.py" "python RandomBot.py" + ./halite -d "30 30" "python MyBot.py" "python RandomStrategy.py" fi diff --git a/networking/runGameDebugConfig.sh b/networking/runGameDebugConfig.sh deleted file mode 100755 index fd1cdcd..0000000 --- a/networking/runGameDebugConfig.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -if hash python3 2>/dev/null; then - kill $(ps aux | grep 
python | grep $1| awk '{print $2}'); ./public/halite -j -z 25 -n 1 -x 25 -t -d "10 10" "python3 networking/pipe_socket_translator.py $1"; -fi diff --git a/networking/start_game.py b/networking/start_game.py new file mode 100644 index 0000000..cee7983 --- /dev/null +++ b/networking/start_game.py @@ -0,0 +1,81 @@ +"""The start_game function to launch the halite.exe""" +import subprocess +import argparse +import os + + +def start_game(port=2000, width=10, height=10, max_strength=25, max_turn=25, max_game=1, + silent_bool=True, timeout=True, quiet=True, + n_pipe_players=1, slave_players=None): + """ + The start_game function to launch the halite.exe. + Execute with the -h option for help. + """ + path_to_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + for i in range(n_pipe_players): + subprocess.call([path_to_root + "/networking/kill.sh", str(port + i)]) # Free the necessary ports + # subprocess.call([path_to_root + "/networking/kill.sh", str(port+1)]) # TODO automatic call to subprocess + halite = path_to_root + '/public/halite ' + dimensions = '-d "' + str(height) + ' ' + str(width) + '" ' + + max_strength = '-z ' + str(max_strength) + ' ' + max_turn = '-x ' + str(max_turn) + ' ' + max_game = '-g ' + str(max_game) + ' ' + silent_bool = '-j ' if silent_bool else '' + timeout = '-t ' if timeout else '' + quiet = '-q ' if quiet else '' + pipe_players = [ + "python3 " + path_to_root + "/networking/pipe_socket_translator.py " + str(port + i) for i in + range(n_pipe_players) + ] + slave_players = [ + "python3 " + path_to_root + "/public/" + slave_player + ' slave' for slave_player in slave_players + ] if slave_players is not None else [] # slave is the slave argument + players = pipe_players + slave_players + # "python3 " + path_to_root + "/networking/pipe_socket_translator.py " + str(port+1) + n_player = '' if len(players) > 1 else '-n 1 ' + + players = '"' + '" "'.join(players) + '"' + print(players) + print("Launching process") + + subprocess.call( + halite + dimensions + n_player + max_strength + max_turn + silent_bool + timeout + quiet + max_game + players, + shell=True) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--port", type=int, + help="the port for the simulation - Useless if there are no pipe_players", + default=2000) + parser.add_argument("-t", "--timeout", help="Doens't timeout if you set this flag is set", + action="store_true", default=False) + parser.add_argument("-j", "--silent", help="Doesn't print *.hlt file", + action="store_true", default=False) + parser.add_argument("-q", "--quiet", help="Doesn't output information to the console", + action="store_true", default=False) + parser.add_argument("-s", "--strength", help="The max strength of the squares, if needed", + type=int, default=25) + parser.add_argument("-dw", "--width", help="The width of the game", + type=int, default=10) + parser.add_argument("-dh", "--height", help="The height of the game", + type=int, default=10) + parser.add_argument("-m", "--maxturn", help="The total number of turns per game (maximum)", + type=int, default=25) + parser.add_argument("-g", "--maxgame", help="The total number of games to play", + type=int, default=1) # -1 for infinite game + parser.add_argument("-pp", "--n_pipe_players", + help="The number of pipe players. You need to handle these players yourself. " + "Each of them has a port assigned.", + type=int, default=0) + parser.add_argument("-sp", "--slave_players", + help="The slave players. 
Handled by the halite.exe. " + "You should write one of these two strings: " + "'MyBot.py' or 'OpponentBot.py' (multiple time if desired) ", + nargs='+', default=[]) + args = parser.parse_args() + start_game(port=args.port, width=args.width, height=args.height, max_strength=args.strength, max_turn=args.maxturn, + silent_bool=args.silent, timeout=args.timeout, max_game=args.maxgame, quiet=args.quiet, + n_pipe_players=args.n_pipe_players, + slave_players=args.slave_players) diff --git a/public/MyBot.py b/public/MyBot.py index 0a9ccbe..8b196eb 100644 --- a/public/MyBot.py +++ b/public/MyBot.py @@ -1,41 +1,30 @@ +"""The MyBot.py file that executes the TrainedBot.py""" +import os import sys +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +try: + from public.models.strategy.TrainedStrategy import TrainedStrategy + from networking.hlt_networking import HLT +except: + raise + mode = 'server' if (len(sys.argv) == 1) else 'local' -mode = 'local' # TODO remove forcing -if mode == 'server': # 'server' mode +if mode == 'server' or sys.argv[1] == 'slave': # 'server' mode import hlt else: # 'local' mode - from networking.hlt_networking import HLT - port = int(sys.argv[1]) if len(sys.argv) > 1 else 2000 hlt = HLT(port=port) -from public.models.bot.trainedBot import TrainedBot - -import tensorflow as tf - -tf.reset_default_graph() - -with tf.device("/cpu:0"): - with tf.variable_scope('global'): - bot = TrainedBot() - global_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='global') - saver = tf.train.Saver(global_variables) - init = tf.global_variables_initializer() - -with tf.Session() as sess: - sess.run(init) - try: - saver.restore(sess, 'models/' + bot.agent.name) - except Exception: - print("Model not found - initiating new one") +bot = TrainedStrategy() +bot.init_session() - while True: - myID, game_map = hlt.get_init() - bot.setID(myID) - hlt.send_init("OpponentBot") +while True: + my_id, game_map = hlt.get_init() + hlt.send_init("MyBot") + bot.set_id(my_id) - while (mode == 'server' or hlt.get_string() == 'Get map and play!'): - game_map.get_frame(hlt.get_string()) - moves = bot.compute_moves(game_map, sess) - hlt.send_frame(moves) + while mode == 'server' or hlt.get_string() == 'Get map and play!': + game_map.get_frame(hlt.get_string()) + moves = bot.compute_moves(game_map) + hlt.send_frame(moves) diff --git a/public/OpponentBot.py b/public/OpponentBot.py index 55d8629..fcac0ad 100644 --- a/public/OpponentBot.py +++ b/public/OpponentBot.py @@ -1,24 +1,29 @@ +"""The Opponent.py file that executes the ImprovedBot.py""" +import os import sys -mode = 'server' if (len(sys.argv) == 1) else 'local' -mode = 'local' # TODO remove forcing +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +try: + from public.models.strategy.ImprovedStrategy import ImprovedStrategy + from networking.hlt_networking import HLT +except: + raise -if mode == 'server': # 'server' mode +mode = 'server' if (len(sys.argv) == 1) else 'local' +if mode == 'server' or sys.argv[1] == 'slave': # 'server' mode import hlt else: # 'local' mode - from networking.hlt_networking import HLT - port = int(sys.argv[1]) if len(sys.argv) > 1 else 2000 hlt = HLT(port=port) -from public.models.bot.improvedBot import ImprovedBot +bot = ImprovedStrategy() while True: - myID, game_map = hlt.get_init() + my_id, game_map = hlt.get_init() hlt.send_init("OpponentBot") - bot = ImprovedBot(myID) + bot.set_id(my_id) - while (mode == 'server' or hlt.get_string() == 'Get map 
and play!'): + while mode == 'server' or hlt.get_string() == 'Get map and play!': game_map.get_frame(hlt.get_string()) moves = bot.compute_moves(game_map) hlt.send_frame(moves) diff --git a/public/hlt.py b/public/hlt.py index 26093ed..c4aad22 100644 --- a/public/hlt.py +++ b/public/hlt.py @@ -1,3 +1,4 @@ +"""The original but corrected hlt.py file for communication with halite.""" import sys from collections import namedtuple from itertools import chain, zip_longest @@ -24,6 +25,8 @@ def opposite_cardinal(direction): class GameMap: + """The GameMap on which to play.""" + def __init__(self, size_string, production_string, map_string=None): self.width, self.height = tuple(map(int, size_string.split())) self.production = tuple( @@ -57,12 +60,14 @@ def __iter__(self): return chain.from_iterable(self.contents) def neighbors(self, square, n=1, include_self=False): - "Iterable over the n-distance neighbors of a given square. For single-step neighbors, the enumeration index provides the direction associated with the neighbor." + """Iterable over the n-distance neighbors of a given square. + For single-step neighbors, the enumeration index provides + the direction associated with the neighbor. + """ assert isinstance(include_self, bool) assert isinstance(n, int) and n > 0 if n == 1: - combos = ((0, -1), (1, 0), (0, 1), (-1, 0), (0, - 0)) # NORTH, EAST, SOUTH, WEST, STILL ... matches indices provided by enumerate(game_map.neighbors(square)) + combos = ((0, -1), (1, 0), (0, 1), (-1, 0), (0, 0)) else: combos = ((dx, dy) for dy in range(-n, n + 1) for dx in range(-n, n + 1) if abs(dx) + abs(dy) <= n) return (self.contents[(square.y + dy) % self.height][(square.x + dx) % self.width] for dx, dy in combos if @@ -96,9 +101,9 @@ def get_string(): def get_init(): - playerID = int(get_string()) + player_id = int(get_string()) m = GameMap(get_string(), get_string()) - return playerID, m + return player_id, m def send_init(name): @@ -106,7 +111,7 @@ def send_init(name): def translate_cardinal(direction): - "Translate direction constants used by this Python-based bot framework to that used by the official Halite game environment." + "Beware the direction are changed! 
Important for visualization" return (direction + 1) % 5 @@ -114,3 +119,17 @@ def send_frame(moves): send_string(' '.join( str(move.square.x) + ' ' + str(move.square.y) + ' ' + str(translate_cardinal(move.direction)) for move in moves)) + + +def send_frame_custom(moves): + send_string(' '.join( + str(x) + ' ' + str(y) + ' ' + str(translate_cardinal(direction)) for (x, y), direction in moves)) + + +def format_moves(game_map, moves): + moves_to_send = [] + for y in range(len(game_map.contents)): + for x in range(len(game_map.contents[0])): + if moves[y][x] != -1: + moves_to_send += [Move(game_map.contents[y][x], moves[y][x])] + return moves_to_send diff --git a/public/models/__init__.py b/public/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/public/models/agent/Agent.py b/public/models/agent/Agent.py new file mode 100644 index 0000000..f4ba78d --- /dev/null +++ b/public/models/agent/Agent.py @@ -0,0 +1,11 @@ +"""The Agent general class""" + + +class Agent: + """The Agent general class""" + + def choose_actions(self, sess, local_game_state_n, train=True): + pass + + def update_agent(self, sess, states, moves, rewards): + pass diff --git a/public/models/agent/VanillaAgent.py b/public/models/agent/VanillaAgent.py new file mode 100644 index 0000000..eb23968 --- /dev/null +++ b/public/models/agent/VanillaAgent.py @@ -0,0 +1,74 @@ +"""The Vanilla Agent""" + +import numpy as np +import tensorflow as tf +import tensorflow.contrib.slim as slim + +from public.models.agent.Agent import Agent + +class VanillaAgent(Agent): + """The Vanilla Agent""" + + def __init__(self, lr=1e-4, s_size=50, h_size=200, a_size=5): # all these are optional ? + super(VanillaAgent, self).__init__() + + # These lines established the feed-forward part of the network. The agent takes a state and produces an action. + self.state_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32) + + hidden = slim.fully_connected(self.state_in, h_size, activation_fn=tf.nn.relu) + + self.policy = slim.fully_connected(hidden, a_size, activation_fn=tf.nn.softmax) + self.predict = tf.argmax(self.policy, 1) + + # The next six lines establish the training proceedure. We feed the reward and predict into the network + # to compute the loss, and use it to update the network. + self.reward_holder = tf.placeholder(shape=[None], dtype=tf.float32) + self.action_holder = tf.placeholder(shape=[None], dtype=tf.int32) + + self.indexes = tf.range(0, tf.shape(self.policy)[0]) * tf.shape(self.policy)[1] + self.action_holder + self.responsible_outputs = tf.gather(tf.reshape(self.policy, [-1]), self.indexes) #TODO... + + loss = -tf.reduce_mean(tf.log(self.responsible_outputs) * self.reward_holder) + + self.tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name) + self.gradients = tf.gradients(loss, self.tvars) + + self.gradient_holders = [] + for idx in range(len(self.tvars)): + placeholder = tf.placeholder(tf.float32, name=str(idx) + '_holder') + self.gradient_holders.append(placeholder) + + global_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'global') + optimizer = tf.train.AdamOptimizer(learning_rate=lr) + + self.update_global = optimizer.apply_gradients(zip(self.gradient_holders, global_vars)) # self.tvars + + def get_policy(self, sess, state): + return sess.run(self.policy, feed_dict={self.state_in: [state.reshape(-1)]}) + + def choose_action(self, sess, state, train=True): + # Here the state is normalized ! 
+ if train: # keep randomness + a_dist = sess.run(self.policy, feed_dict={self.state_in: [state.reshape(-1)]}) + a = np.random.choice(a_dist[0], p=a_dist[0]) + a = np.argmax(a_dist == a) + else: # act greedily + a = sess.run(self.predict, feed_dict={self.state_in: [state.reshape(-1)]}) + return a + + def choose_actions(self, sess, local_game_state_n, train=True): + """Choose all actions using one call to tensorflow""" + if train: + actions = sess.run(self.policy, feed_dict={self.state_in: local_game_state_n}) + actions = [np.argmax(action == np.random.choice(action, p=action)) for action in actions] + else: + actions = sess.run(self.predict, feed_dict={self.state_in: local_game_state_n}) + return actions + + def update_agent(self, sess, states, moves, rewards): + feed_dict = {self.state_in: states, + self.action_holder: moves, + self.reward_holder: rewards} + grads = sess.run(self.gradients, feed_dict=feed_dict) + feed_dict = dict(zip(self.gradient_holders, grads)) + _ = sess.run(self.update_global, feed_dict=feed_dict) diff --git a/public/models/agent/__init__.py b/public/models/agent/__init__.py index 849b75f..e69de29 100644 --- a/public/models/agent/__init__.py +++ b/public/models/agent/__init__.py @@ -1 +0,0 @@ -# TODO: import via the agent package diff --git a/public/models/agent/agent.py b/public/models/agent/agent.py deleted file mode 100644 index 323ed5d..0000000 --- a/public/models/agent/agent.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np -import tensorflow as tf -import tensorflow.contrib.slim as slim - -from train.reward import localStateFromGlobal - - -class Agent: - def __init__(self, name, experience): - self.name = name - self.experience = experience - if self.experience is not None: - try: - self.experience.metric = np.load('models/' + self.name + '.npy') - except: - print("New metric file created") - self.experience.metric = np.array([]) - - def choose_actions(self, sess, game_state, debug=False): - moves = np.zeros_like(game_state[0], dtype=np.int64) - 1 - for y in range(len(game_state[0])): - for x in range(len(game_state[0][0])): - if (game_state[0][y][x] == 1): - moves[y][x] = self.choose_action(sess, localStateFromGlobal(game_state, x, y), debug=debug) - return moves - - def update_agent(self, sess): - pass diff --git a/public/models/agent/vanillaAgent.py b/public/models/agent/vanillaAgent.py deleted file mode 100644 index 5889ff0..0000000 --- a/public/models/agent/vanillaAgent.py +++ /dev/null @@ -1,65 +0,0 @@ -import numpy as np -import tensorflow as tf -import tensorflow.contrib.slim as slim - -from train.reward import localStateFromGlobal -from public.models.agent.agent import Agent - - -class VanillaAgent(Agent): - def __init__(self, experience, lr, s_size, a_size, h_size): # all these are optional ? - super(VanillaAgent, self).__init__('vanilla', experience) - - # These lines established the feed-forward part of the network. The agent takes a state and produces an action. - self.state_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32) - - hidden = slim.fully_connected(self.state_in, h_size, activation_fn=tf.nn.relu) - - self.policy = slim.fully_connected(hidden, a_size, activation_fn=tf.nn.softmax) - self.predict = tf.argmax(self.policy, 1) - - # The next six lines establish the training proceedure. We feed the reward and predict into the network - # to compute the loss, and use it to update the network. 
- self.reward_holder = tf.placeholder(shape=[None], dtype=tf.float32) - self.action_holder = tf.placeholder(shape=[None], dtype=tf.int32) - - self.indexes = tf.range(0, tf.shape(self.policy)[0]) * tf.shape(self.policy)[1] + self.action_holder - self.responsible_outputs = tf.gather(tf.reshape(self.policy, [-1]), self.indexes) - if experience is not None: - loss = -tf.reduce_mean(tf.log(self.responsible_outputs) * self.reward_holder) - - self.tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name) - self.gradients = tf.gradients(loss, self.tvars) - - self.gradientHolders = [] - for idx, var in enumerate(self.tvars): - placeholder = tf.placeholder(tf.float32, name=str(idx) + '_holder') - self.gradientHolders.append(placeholder) - - global_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'global') - optimizer = tf.train.AdamOptimizer(learning_rate=lr) - - self.updateGlobal = optimizer.apply_gradients(zip(self.gradientHolders, global_vars)) # self.tvars - - def choose_action(self, sess, state, frac_progress=1.0, debug=False): # it only a state, not the game state... - if (np.random.uniform() >= frac_progress): - a = np.random.choice(range(5)) - else: - a_dist = sess.run(self.policy, feed_dict={self.state_in: [state.reshape(-1)]}) - a = np.random.choice(a_dist[0], p=a_dist[0]) - a = np.argmax(a_dist == a) - if debug: - a = sess.run(self.predict, feed_dict={self.state_in: [state.reshape(-1)]}) - return a - - def update_agent(self, sess): - # batch_size = min(int(len(self.moves)/2),128) # Batch size - # indices = np.random.randint(len(self.moves)-1, size=batch_size) - states, moves, rewards = self.experience.batch(512) - - feed_dict = {self.state_in: states, - self.action_holder: moves, - self.reward_holder: rewards} - grads = sess.run(self.gradients, feed_dict=feed_dict) - feed_dict = dict(zip(self.gradientHolders, grads)) - _ = sess.run(self.updateGlobal, feed_dict=feed_dict) diff --git a/public/models/bot/__init__.py b/public/models/bot/__init__.py deleted file mode 100644 index 1a1aa6a..0000000 --- a/public/models/bot/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: import via the bot package diff --git a/public/models/bot/bot.py b/public/models/bot/bot.py deleted file mode 100644 index d298dae..0000000 --- a/public/models/bot/bot.py +++ /dev/null @@ -1,9 +0,0 @@ -class Bot: - def __init__(self, myID=1): - self.myID = myID - - def compute_moves(self, game_map, sess=None): - pass - - def setID(self, myID): - self.myID = myID diff --git a/public/models/bot/randomBot.py b/public/models/bot/randomBot.py deleted file mode 100644 index 844e41e..0000000 --- a/public/models/bot/randomBot.py +++ /dev/null @@ -1,13 +0,0 @@ -import random - -from public.models.bot.bot import Bot -from public.hlt import NORTH, EAST, SOUTH, WEST, STILL, Move - - -class RandomBot(Bot): - def __init__(self, myID): - super(RandomBot, self).__init__(myID) - - def compute_moves(self, game_map, sess=None): - [Move(square, random.choice((NORTH, EAST, SOUTH, WEST, STILL))) for square in game_map if - square.owner == self.myID] diff --git a/public/models/bot/trainedBot.py b/public/models/bot/trainedBot.py deleted file mode 100644 index 59531b5..0000000 --- a/public/models/bot/trainedBot.py +++ /dev/null @@ -1,18 +0,0 @@ -from public.models.agent.vanillaAgent import VanillaAgent -from public.models.bot.bot import Bot -from train.reward import getGameState, formatMoves -import tensorflow as tf - - -class TrainedBot(Bot): - def __init__(self, myID=None): - lr = 1e-3; - s_size = 9 * 3; 
- a_size = 5; - h_size = 50 - self.agent = VanillaAgent(None, lr, s_size, a_size, h_size) - super(TrainedBot, self).__init__(myID) - - def compute_moves(self, game_map, sess=None): - game_state = getGameState(game_map, self.myID) - return formatMoves(game_map, self.agent.choose_actions(sess, game_state)) diff --git a/public/models/bot/improvedBot.py b/public/models/strategy/ImprovedStrategy.py similarity index 53% rename from public/models/bot/improvedBot.py rename to public/models/strategy/ImprovedStrategy.py index eaea6d3..e35d2ce 100644 --- a/public/models/bot/improvedBot.py +++ b/public/models/strategy/ImprovedStrategy.py @@ -1,19 +1,18 @@ +"""The Improved Bot""" import random -from public.models.bot.bot import Bot -from public.hlt import NORTH, EAST, SOUTH, WEST, STILL, Move +from public.hlt import Move, NORTH, STILL, WEST +from public.models.strategy.Strategy import Strategy -class ImprovedBot(Bot): - def __init__(self, myID): - super(ImprovedBot, self).__init__(myID) - - def compute_moves(self, game_map, sess=None): +class ImprovedStrategy(Strategy): + def compute_moves(self, game_map): + """Compute the moves given a game_map""" moves = [] for square in game_map: - if square.owner == self.myID: + if square.owner == self.my_id: for direction, neighbor in enumerate(game_map.neighbors(square)): - if neighbor.owner != self.myID and neighbor.strength < square.strength: + if neighbor.owner != self.my_id and neighbor.strength < square.strength: moves += [Move(square, direction)] if square.strength < 5 * square.production: moves += [Move(square, STILL)] diff --git a/public/models/strategy/RandomStrategy.py b/public/models/strategy/RandomStrategy.py new file mode 100644 index 0000000..e0f0424 --- /dev/null +++ b/public/models/strategy/RandomStrategy.py @@ -0,0 +1,12 @@ +"""The Random Bot""" +import random + +from public.hlt import EAST, Move, NORTH, SOUTH, STILL, WEST +from public.models.strategy.Strategy import Strategy + + +class RandomStrategy(Strategy): + def compute_moves(self, game_map): + """Compute the moves given a game_map""" + return [Move(square, random.choice((NORTH, EAST, SOUTH, WEST, STILL))) for square in game_map if + square.owner == self.my_id] diff --git a/public/models/strategy/Strategy.py b/public/models/strategy/Strategy.py new file mode 100644 index 0000000..69a932b --- /dev/null +++ b/public/models/strategy/Strategy.py @@ -0,0 +1,7 @@ +"""The General Bot class""" +class Strategy: + def compute_moves(self, game_map): + pass + + def set_id(self, my_id): + self.my_id = my_id diff --git a/public/models/strategy/TrainedStrategy.py b/public/models/strategy/TrainedStrategy.py new file mode 100644 index 0000000..b94465b --- /dev/null +++ b/public/models/strategy/TrainedStrategy.py @@ -0,0 +1,119 @@ +"""The Trained Bot""" +import json +import os +import warnings + +import numpy as np +import tensorflow as tf +from tensorflow.python.framework.errors_impl import InvalidArgumentError + +from public.hlt import format_moves +from public.models.agent.VanillaAgent import VanillaAgent +from public.models.strategy.Strategy import Strategy +from public.state.state import get_game_state, State1 +from public.util.dijkstra import build_graph_from_state, dijkstra +from public.util.path import move_to, path_to +from train.experience import ExperienceVanilla +from train.reward.reward import Reward + + +class TrainedStrategy(Strategy): + """The trained strategy""" + + def __init__(self, tf_scope='global'): + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + tf.reset_default_graph() + 
warnings.filterwarnings("ignore") + + config = open(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../strategy.json'))).read() + config = json.loads(config) + self.name = config["saving_name"] + self.state = State1(scope=config["agent"]["scope"]) + self.reward = Reward(state=self.state) + self.experience = ExperienceVanilla(self.state.local_size, self.name) + with tf.variable_scope(tf_scope): + self.agent1 = VanillaAgent(s_size=self.state.local_size, h_size=config["agent"]["h_size"]) + + def init_session(self, sess=None, saver=None): + if sess is None: + global_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='global') + self.saver = tf.train.Saver(global_variables) + init = tf.global_variables_initializer() + self.sess = tf.Session() + self.sess.run(init) + try: + self.saver.restore(self.sess, os.path.abspath( + os.path.join(os.path.dirname(__file__), '..')) + + '/variables/' + self.name + '/' + + self.name) + except InvalidArgumentError: + print("Model not found - initiating new one") + else: + self.sess = sess + self.saver = saver + + def set_id(self, my_id): + super(TrainedStrategy, self).set_id(my_id) + self.agent1_moves = [] + self.agent1_game_states = [] + + def compute_moves(self, game_map, train=False): + """Compute the moves given a game_map""" + game_state = get_game_state(game_map, self.my_id) + g = build_graph_from_state(game_state[0]) + dist_dict, closest_dict = dijkstra(g.g, 0) + self.agent1_game_states += [game_state] + self.agent1_moves += [np.zeros_like(game_state[0], dtype=np.int64) - 1] + self.agent1_local_game_states = np.array([]).reshape(0, self.state.local_size) + dijkstra_moves = np.zeros_like(game_state[0], dtype=np.int64) - 1 + agent1_positions = [] + for (y, x), k in np.ndenumerate(game_state[0]): + if k == 1: + if (y, x) in dist_dict and dist_dict[(y, x)] in [1, 2]: + agent1_positions += [(y, x)] + game_state_n = self.state.get_local_and_normalize(game_state, x, y).reshape(1, + self.state.local_size) + self.agent1_local_game_states = np.concatenate((self.agent1_local_game_states, game_state_n), + axis=0) + else: + if game_state[1][y][x] > 10: # Set a minimum strength + y_t, x_t = y, x + y_t, x_t = closest_dict[(y_t, x_t)] + dijkstra_moves[y][x] = move_to( + path_to((x, y), (x_t, y_t), len(game_state[0][0]), len(game_state[0]))) + actions = self.agent1.choose_actions(self.sess, self.agent1_local_game_states, train) + for (y, x), d in zip(agent1_positions, actions): + self.agent1_moves[-1][y][x] = d + + return format_moves(game_map, -(self.agent1_moves[-1] * dijkstra_moves)) + + def add_episode(self): + all_states, all_moves, all_rewards = self.reward.all_rewards_function(self.agent1_game_states, + self.agent1_moves) + self.experience.add_episode(self.agent1_game_states, all_states, all_moves, all_rewards) + + def update_agent(self): + train_states, train_moves, train_rewards = self.experience.batch() + self.agent1.update_agent(self.sess, train_states, train_moves, train_rewards) + + def save(self): + directory = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + '/variables/' + self.name + if not os.path.exists(directory): + print("Creating directory named :" + self.name) + os.makedirs(directory) + self.saver.save(self.sess, directory + '/' + self.name) + self.experience.save_metric(directory + '/' + self.name) + + def get_policies(self, game_states): + policies = [] + for game_state in game_states: + policies += [np.zeros(game_state[0].shape + (5,))] + for (y, x), k in np.ndenumerate(game_state[0]): + if k == 1: + 
policies[-1][y][x] = self.agent1.get_policy( + self.sess, self.state.get_local_and_normalize(game_state, x, y)) + return np.array(policies) + + def close(self): + """Close the tensorflow session""" + self.sess.close() diff --git a/public/models/variables/README.md b/public/models/variables/README.md new file mode 100644 index 0000000..d1b81c2 --- /dev/null +++ b/public/models/variables/README.md @@ -0,0 +1,3 @@ +# Variables + +Here are, for instance, the stored Tensorflow models, with the convention of using the name of the agent for **both the folder and the files**. diff --git a/public/models/visualize_score.py b/public/models/visualize_score.py deleted file mode 100644 index acf0367..0000000 --- a/public/models/visualize_score.py +++ /dev/null @@ -1,14 +0,0 @@ -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt - -rewards = [np.load('./models/vanilla.npy')] - -max_len = max([len(reward) for reward in rewards]) -for i in range(len(rewards)): - rewards[i] = np.append(rewards[i], np.repeat(np.nan, max_len - len(rewards[i]))) - -pd.DataFrame(np.array(rewards).T, columns=['vanilla']).rolling(100).mean().plot( - title="Weighted reward at each game. (Rolling average)") - -plt.show() diff --git a/public/state/state.py b/public/state/state.py new file mode 100644 index 0000000..df705b8 --- /dev/null +++ b/public/state/state.py @@ -0,0 +1,50 @@ +"""The state file""" +import numpy as np + +STRENGTH_SCALE = 255 +PRODUCTION_SCALE = 10 + + +class State: + def __init__(self, local_size): + self.local_size = local_size + + def get_local(self, game_state, x, y): + pass + + def get_local_and_normalize(self, game_state, x, y): + return self.get_local(game_state, x, y) / np.array([1, STRENGTH_SCALE, PRODUCTION_SCALE])[:, np.newaxis] + + +class State1(State): + def __init__(self, scope=1): + self.scope = scope + super(State1, self).__init__(local_size=3 * ((2 * scope + 1) ** 2)) + + def get_local(self, game_state, x, y): + # all the axes remain through the operation, because of range + return np.take(np.take(game_state, range(y - self.scope, y + self.scope + 1), axis=1, mode='wrap'), + range(x - self.scope, x + self.scope + 1), axis=2, mode='wrap').reshape(3, -1) + + +class State2(State): + def __init__(self, scope=2): + self.scope = scope + super(State2, self).__init__(local_size=3 * (2 * (scope ** 2) + 2 * scope + 1)) + + def get_local(self, game_state, x, y): + to_concat = () + for i in range(self.scope + 1): + slice_s = np.take(np.take(game_state, + range(x - (self.scope - i), x + (self.scope - i) + 1), axis=2, mode='wrap'), + [y - i, y + i] if i != 0 else y, axis=1, mode='wrap') + slice_s = slice_s.reshape(3, -1) + to_concat += (slice_s,) + return np.concatenate(to_concat, axis=1) + + +def get_game_state(game_map, my_id): + game_state = np.reshape( + [[(square.owner == my_id) + 0, square.strength, square.production] for square in game_map], + [game_map.height, game_map.width, 3]) + return np.swapaxes(np.swapaxes(game_state, 2, 0), 1, 2) diff --git a/public/strategy.json b/public/strategy.json new file mode 100644 index 0000000..f4ec35a --- /dev/null +++ b/public/strategy.json @@ -0,0 +1,11 @@ +{ + "agent":{ + "type":"vanilla", + "scope":2, + "h_size":200 + }, + "dijkstra":{ + "scope":2 + }, + "saving_name": "vanilla-scope-test" +} \ No newline at end of file diff --git a/public/util/dijkstra.py b/public/util/dijkstra.py new file mode 100644 index 0000000..244fa68 --- /dev/null +++ b/public/util/dijkstra.py @@ -0,0 +1,146 @@ +"""The dijkstra module""" +import numpy as np + + +class 
Prioritydictionary(dict): + """A priority dictionary""" + def __init__(self): + """Initialize Prioritydictionary by creating binary heap of + pairs (value,key). Note that changing or removing a dict entry + will not remove the old pair from the heap until it is found by + smallest() or until the heap is rebuilt.""" + self.__heap = [] + dict.__init__(self) + + def smallest(self): + """Find smallest item after removing deleted items from front of + heap.""" + if len(self) == 0: + raise IndexError("smallest of empty Prioritydictionary") + heap = self.__heap + while heap[0][1] not in self or self[heap[0][1]] != heap[0][0]: + last_item = heap.pop() + insertion_point = 0 + while 1: + small_child = 2 * insertion_point + 1 + if small_child + 1 < len(heap) and \ + heap[small_child] > heap[small_child + 1]: + small_child += 1 + if small_child >= len(heap) or last_item <= heap[small_child]: + heap[insertion_point] = last_item + break + heap[insertion_point] = heap[small_child] + insertion_point = small_child + return heap[0][1] + + def __iter__(self): + """Create destructive sorted iterator of Prioritydictionary.""" + + def iterfn(): + while len(self) > 0: + x = self.smallest() + yield x + del self[x] + + return iterfn() + + def __setitem__(self, key, val): + """Change value stored in dictionary and add corresponding pair + to heap. Rebuilds the heap if the number of deleted items gets + large, to avoid memory leakage.""" + dict.__setitem__(self, key, val) + heap = self.__heap + if len(heap) > 2 * len(self): + self.__heap = [(v, k) for k, v in self.items()] + self.__heap.sort() + # builtin sort probably faster than O(n)-time heapify + else: + new_pair = (val, key) + insertion_point = len(heap) + heap.append(None) + while insertion_point > 0 and \ + new_pair < heap[(insertion_point - 1) // 2]: + heap[insertion_point] = heap[(insertion_point - 1) // 2] + insertion_point = (insertion_point - 1) // 2 + heap[insertion_point] = new_pair + + def setdefault(self, key, val): + """Reimplement setdefault to pass through our customized __setitem__.""" + if key not in self: + self[key] = val + return self[key] + + +def dijkstra(g, start, end=None): + """The dijkstra algorithm""" + d = {} # dictionary of final distances + p = {} # dictionary of predecessors + q = Prioritydictionary() # estimated distances of non-final vertices + q[start] = 0 + + for v in q: + d[v] = q[v] + if v == end: + break + + for w in g[v]: + vw_length = d[v] + g[v][w] + if w in d: + if vw_length < d[w]: + raise ValueError("Dijkstra: found better path to already-final vertex") + elif w not in q or vw_length < q[w]: + q[w] = vw_length + p[w] = v + + return d, p + + +class Graph: + """The Graph object""" + def __init__(self): + self.g = {} + + def add_node(self, value): + if value not in self.g: + self.g[value] = {} + + def add_edge(self, from_node, to_node): + self.g[from_node][to_node] = 1 + self.g[to_node][from_node] = 1 + + def remove_node(self, value): + for to in self.g[value]: + self.g[to].pop(value, None) + self.g.pop(value, None) + + def update(self, new_state, previous_state): + for (i, j), k in np.ndenumerate(new_state - previous_state): + if k == 1: + self.add_node((i, j)) + for i1, j1 in [(i - 1, j), (i, j + 1), (i, j - 1), (i + 1, j)]: + if (i1, j1) in self.g: + self.add_edge((i, j), (i1, j1)) + elif k == -1: + self.remove_node((i, j)) + + +def build_graph_from_state(state): # The keys will be y, x, since we np.ndenumerate + """Build the graph from the state""" + def take(state, i, j): + return np.take(np.take(state, j, axis=1, 
mode='wrap'), i, axis=0, mode='wrap') + + g = Graph() + g.add_node(0) + for (i, j), k in np.ndenumerate(state): + if k == 1: + g.add_node((i, j)) + n, e, w, s = take(state, i - 1, j), take(state, i, j + 1), take(state, i, j - 1), take(state, i + 1, j) + if s: + g.add_node(((i + 1) % state.shape[0], j)) + g.add_edge((i, j), ((i + 1) % state.shape[0], j)) + if e: + g.add_node((i, (j + 1) % state.shape[1])) + g.add_edge((i, j), (i, (j + 1) % state.shape[1])) + if n * e * w * s == 0: + g.add_edge((i, j), 0) + return g diff --git a/public/util/path.py b/public/util/path.py new file mode 100644 index 0000000..dbf3bee --- /dev/null +++ b/public/util/path.py @@ -0,0 +1,32 @@ +"""The path to another point""" +import random + +from public.hlt import EAST, WEST, NORTH, SOUTH + + +def path_to(start, end, width, height): + """Given start = (x,y), end = (x,y), end return dx, dy""" + x1, y1 = start + x2, y2 = end + + def settle(p1, p2, modulo): + dp = min(abs(p1 - p2), modulo - abs(p1 - p2)) + if p1 < p2 and p2 - p1 != dp: # TODO contract formula + dp = -dp + elif p1 > p2 and p1 - p2 == dp: + dp = -dp + return dp + + return settle(x1, x2, width), settle(y1, y2, height) + + +def move_to(dxy): + """Move to the closest square given the tuple (dx, dy)""" + dx, dy = dxy + assert abs(dx) > 0 or abs(dy) > 0, "No closer move possible" + prob_east_west = abs(dx) / (abs(dx) + abs(dy)) + if random.uniform(0, 1) < prob_east_west: # Act east_west + move = EAST if dx > 0 else WEST + else: # Act north_south + move = SOUTH if dy > 0 else NORTH + return move diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e21a3b0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +tensorflow +coverage>=3.6 +pytest-cov +pytest-xdist +pytest-benchmark>=3.1 +coveralls +pylint>=1.6 +flask \ No newline at end of file diff --git a/src/core/Halite.cpp b/src/core/Halite.cpp index bcc4e8f..6bc0e19 100644 --- a/src/core/Halite.cpp +++ b/src/core/Halite.cpp @@ -389,14 +389,18 @@ GameStatistics Halite::runGame(std::vector * names_, unsigned int s stats.timeout_tags = timeout_tags; stats.timeout_log_filenames = std::vector(timeout_tags.size()); //Output gamefile. First try the replays folder; if that fails, just use the straight filename. 
- stats.output_filename = "Replays/" + std::to_string(id) + '-' + std::to_string(seed) + ".hlt"; + stats.output_filename = "visualize/hlt/" + std::to_string(id) + '-' + std::to_string(seed) + ".hlt"; if(!no_file_output){ try { output(stats.output_filename); } catch(std::runtime_error & e) { - stats.output_filename = stats.output_filename.substr(8); - output(stats.output_filename); + try{ + output("../"+stats.output_filename); + }catch(std::runtime_error & e){ + stats.output_filename = stats.output_filename.substr(8); + output(stats.output_filename); + } } if(!quiet_output) std::cout << "Map seed was " << seed << std::endl << "Opening a file at " << stats.output_filename << std::endl; diff --git a/src/main.cpp b/src/main.cpp index 8ec037b..2bb2760 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -47,6 +47,7 @@ int main(int argc, char ** argv) { TCLAP::ValueArg nPlayersArg("n", "nplayers", "Create a map that will accommodate n players [SINGLE PLAYER MODE ONLY].", false, 1, "{1,2,3,4,5,6}", cmd); TCLAP::ValueArg< std::pair > dimensionArgs("d", "dimensions", "The dimensions of the map.", false, { 0, 0 }, "a string containing two space-seprated positive integers", cmd); TCLAP::ValueArg seedArg("s", "seed", "The seed for the map generator.", false, 0, "positive integer", cmd); + TCLAP::ValueArg max_game_args("g", "maxgame", "The max number of games.", false, 0, "integer", cmd); TCLAP::ValueArg custom_max_strength_args("z", "maxstrength", "The max strength.", false, 0, "positive integer", cmd); TCLAP::ValueArg customMaxTurnNumberArg("x", "maxturn", "The number of turns.", false, 0, "positive integer", cmd); //Remaining Args, be they start commands and/or override names. Description only includes start commands since it will only be seen on local testing. @@ -136,8 +137,9 @@ int main(int argc, char ** argv) { std::cout << std::endl << "A map can only accommodate between 1 and 6 players." 
<< std::endl << std::endl; exit(1); } - - while(true){ + int ng = max_game_args.getValue(); + bool infinite_loop = ng==-1; + while(infinite_loop || ng-- > 0){ seed = (std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count() % 4294967295); my_game = new Halite(mapWidth, mapHeight, seed, n_players_for_map_creation, networking, ignore_timeout, custom_max_strength_args.getValue()); @@ -145,7 +147,7 @@ int main(int argc, char ** argv) { if(names != NULL) delete names; - //delete my_game; + } if(names != NULL) delete names; diff --git a/src/networking/Networking.cpp b/src/networking/Networking.cpp index cd8239b..a862fea 100644 --- a/src/networking/Networking.cpp +++ b/src/networking/Networking.cpp @@ -339,7 +339,6 @@ int Networking::handleInitNetworking(unsigned char playerTag, const hlt::Map & m std::chrono::high_resolution_clock::time_point initialTime = std::chrono::high_resolution_clock::now(); response = getString(playerTag, ALLOTTED_MILLIS); unsigned int millisTaken = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - initialTime).count(); - player_logs[playerTag - 1] += response + "\n --- Bot used " + std::to_string(millisTaken) + " milliseconds ---"; *playerName = response.substr(0, 30); diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..007eb95 --- /dev/null +++ b/tests/README.md @@ -0,0 +1 @@ +# Tests diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..7d80e79 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +""" +Contributors: + - Louis Rémus +""" diff --git a/tests/dijkstra_speed_test.py b/tests/dijkstra_speed_test.py new file mode 100644 index 0000000..eed9fde --- /dev/null +++ b/tests/dijkstra_speed_test.py @@ -0,0 +1,47 @@ +"""Testing the speed of 2 dijkstra alternative""" +import numpy as np +import pytest + +from public.util.dijkstra import build_graph_from_state, dijkstra +from tests.util import game_states_from_file + + +def dijkstra_naive(game_states): + for game_state in game_states: + g = build_graph_from_state(game_state[0]) + dist_dict, _ = dijkstra(g.g, 0) + + dist = np.zeros_like(game_state[0]) + for key, value in dist_dict.items(): + dist[key] = value + + +def dijkstra_update(game_states): + g = build_graph_from_state(game_states[0][0]) + for i in range(1, len(game_states)): + g.update(game_states[i][0], game_states[i - 1][0]) + dist_dict, _ = dijkstra(g.g, 0) + + dist = np.zeros_like(game_states[i][0]) + for key, value in dist_dict.items(): + dist[key] = value + + +@pytest.mark.benchmark(group="dijkstra") +def test_dijkstra_naive_speed(benchmark): + """ + Benchmark the time of dijsktra + """ + game_states, _ = game_states_from_file() + benchmark(dijkstra_naive, game_states=game_states) + assert True + + +@pytest.mark.benchmark(group="dijkstra") +def test_dijkstra_update_speed(benchmark): + """ + Benchmark the time of dijsktra + """ + game_states, _ = game_states_from_file() + benchmark(dijkstra_update, game_states=game_states) + assert True diff --git a/tests/dijkstra_test.py b/tests/dijkstra_test.py new file mode 100644 index 0000000..9c49016 --- /dev/null +++ b/tests/dijkstra_test.py @@ -0,0 +1,37 @@ +""" +Tests the dijkstra function +""" +import unittest + +import numpy as np +from public.util.dijkstra import build_graph_from_state, dijkstra + + +class TestDijkstra(unittest.TestCase): + """ + Tests the dijkstra algorithm + """ + + def test_dijkstra(self): + """ + Test the dijkstra algorithm + """ + state = np.array([[0, 
0, 0, 0, 1, 1], + [0, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1]]) + + print(state) + g = build_graph_from_state(state) + dist_dict, _ = dijkstra(g.g, 0) + + dist = np.zeros_like(state) + for key, value in dist_dict.items(): + dist[key] = value + print(dist) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/path_test.py b/tests/path_test.py new file mode 100644 index 0000000..d58b0c6 --- /dev/null +++ b/tests/path_test.py @@ -0,0 +1,24 @@ +"""Test the path_to function""" +import unittest + +from public.util.path import path_to + +class PathTo(unittest.TestCase): + """ + Tests the path_to function + """ + + def test_path_to(self): + """ + Test the path_to function + """ + self.assertTrue(path_to((0, 0), (0, 4), 5, 5) == (0, -1)) + self.assertTrue(path_to((4, 4), (0, 0), 5, 5) == (1, 1)) + self.assertTrue(path_to((0, 0), (4, 4), 5, 5) == (-1, -1)) + self.assertTrue(path_to((0, 0), (4, 4), 5, 10) == (-1, 4)) + self.assertTrue(path_to((0, 0), (4, 4), 6, 10) == (-2, 4)) + self.assertTrue(path_to((0, 0), (4, 4), 7, 10) == (-3, 4)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/reward_test.py b/tests/reward_test.py new file mode 100644 index 0000000..268c256 --- /dev/null +++ b/tests/reward_test.py @@ -0,0 +1,62 @@ +""" +Tests the reward function +""" +import unittest + +import numpy as np + +from public.state.state import State1 +from tests.util import game_states_from_url +from train.experience import ExperienceVanilla +from train.reward.reward import Reward +from train.reward.util import discount_rewards +from train.worker import Worker + + +class TestReward(unittest.TestCase): + """ + Tests the reward function + """ + + def test_length_discount_rewards(self): + """ + Test the length of the discount reward + """ + self.assertTrue(len(discount_rewards(np.array([1]))) == 1) + self.assertTrue(len(discount_rewards(np.array([1, 3]))) == 2) + + def test_reward(self): + """ + Test the length of the discount reward + """ + game_url = 'https://s3.eu-central-1.amazonaws.com/halite-python-rl/hlt-games/trained-bot.hlt' + game_states, moves = game_states_from_url(game_url) + + s = State1() + r = Reward(s) + raw_rewards = r.raw_rewards_function(game_states) + self.assertTrue(len(raw_rewards) == len(game_states) - 1) + + all_states, all_moves, all_rewards = r.all_rewards_function(game_states, moves) + self.assertTrue(len(all_states) >= len(game_states) - 1) + self.assertTrue(len(all_moves) >= len(moves)) + self.assertTrue(len(all_rewards) == len(all_moves) and len(all_states) == len(all_moves)) + + experience = ExperienceVanilla(s.local_size, name='') + experience.add_episode(game_states, all_states, all_moves, all_rewards) + experience.add_episode(game_states, all_states, all_moves, all_rewards) + self.assertTrue(len(experience.moves) == 2 * len(all_moves)) + batch_states, batch_moves, batch_rewards = experience.batch() + self.assertTrue(len(batch_rewards) == len(batch_moves) and len(batch_states) == len(batch_moves)) + + def test_worker(self): + """ + Test if the worker port initiate and terminate with good port + """ + worker = Worker(2000, 2, None) + self.assertTrue(worker.port == 2002) + worker.p.terminate() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/state_speed_test.py b/tests/state_speed_test.py new file mode 100644 index 0000000..18d8a33 --- /dev/null +++ b/tests/state_speed_test.py @@ -0,0 +1,73 @@ +"""Testing the speed of building different game states""" +import numpy as 
np +import pytest + +from public.state.state import State2, State1 +from tests.util import game_states_from_file + + +def build_game_state(game_states, state): + for game_state in game_states: + for (y, x), k in np.ndenumerate(game_state[0]): + if k == 1: + state.get_local(game_state, x, y) + + +@pytest.mark.benchmark(group="state") +def test_state1_scope1(benchmark): + """ + Benchmark the time of dijsktra + """ + game_states, _ = game_states_from_file() + benchmark(build_game_state, game_states=game_states, state=State1(scope=1)) + assert True + + +@pytest.mark.benchmark(group="state") +def test_state1_scope2(benchmark): + """ + Benchmark the time of dijsktra + """ + game_states, _ = game_states_from_file() + benchmark(build_game_state, game_states=game_states, state=State1(scope=2)) + assert True + + +@pytest.mark.benchmark(group="state") +def test_state1_scope3(benchmark): + """ + Benchmark the time of dijsktra + """ + game_states, _ = game_states_from_file() + benchmark(build_game_state, game_states=game_states, state=State1(scope=3)) + assert True + + +@pytest.mark.benchmark(group="state") +def test_state2_scope2(benchmark): + """ + Benchmark the time of dijsktra + """ + game_states, _ = game_states_from_file() + benchmark(build_game_state, game_states=game_states, state=State2(scope=2)) + assert True + + +@pytest.mark.benchmark(group="state") +def test_state2_scope3(benchmark): + """ + Benchmark the time of dijsktra + """ + game_states, _ = game_states_from_file() + benchmark(build_game_state, game_states=game_states, state=State2(scope=3)) + assert True + + +@pytest.mark.benchmark(group="state") +def test_state2_scope4(benchmark): + """ + Benchmark the time of dijsktra + """ + game_states, _ = game_states_from_file() + benchmark(build_game_state, game_states=game_states, state=State2(scope=4)) + assert True diff --git a/tests/tensorflow_call_speed_test.py b/tests/tensorflow_call_speed_test.py new file mode 100644 index 0000000..2c7227d --- /dev/null +++ b/tests/tensorflow_call_speed_test.py @@ -0,0 +1,50 @@ +"""Testing the speed of 2 tensorflow alternatives""" +import numpy as np +import pytest + +from public.models.strategy.TrainedStrategy import TrainedStrategy +from tests.util import game_states_from_file + + +def tensorflow_naive(game_states, sess, agent, state): + for game_state in game_states: + for y in range(len(game_state[0])): + for x in range(len(game_state[0][0])): + if game_state[0][y][x] == 1: + game_state_n = state.get_local_and_normalize(game_state, x, y).reshape(1, -1) + sess.run(agent.policy, feed_dict={agent.state_in: game_state_n}) + + +def tensorflow_combined(game_states, sess, agent, state): + for game_state in game_states: + all_game_state_n = np.array([]).reshape(0, state.local_size) + for y in range(len(game_state[0])): + for x in range(len(game_state[0][0])): + if game_state[0][y][x] == 1: + game_state_n = state.get_local_and_normalize(game_state, x, y).reshape(1, -1) + all_game_state_n = np.concatenate((all_game_state_n, game_state_n), axis=0) + sess.run(agent.policy, feed_dict={agent.state_in: all_game_state_n}) + + +@pytest.mark.benchmark(group="tf") +def test_tensorflow_naive_speed(benchmark): + """ + Benchmark the time of dijsktra + """ + bot = TrainedStrategy() + bot.init_session() + game_states, _ = game_states_from_file() + benchmark(tensorflow_naive, game_states=game_states, sess=bot.sess, agent=bot.agent1, state=bot.state) + assert True + + +@pytest.mark.benchmark(group="tf") +def test_tensorflow_combined_speed(benchmark): + """ + Benchmark the 
time of dijsktra + """ + bot = TrainedStrategy() + bot.init_session() + game_states, _ = game_states_from_file() + benchmark(tensorflow_combined, game_states=game_states, sess=bot.sess, agent=bot.agent1, state=bot.state) + assert True diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 0000000..7e8ca09 --- /dev/null +++ b/tests/util.py @@ -0,0 +1,39 @@ +"""Importing the game from aws""" +import json +import os +import urllib.request + +import numpy as np + + +def game_states_from_url(game_url): + """ + We host known games on aws server and we run the tests according to these games, from which we know the output + :param game_url: The url of the game on the server (string). + :return: + """ + return text_to_game(urllib.request.urlopen(game_url).readline().decode("utf-8")) + + +def text_to_game(text): + """ + Transform the text from the *.hlt files into game_states + """ + game = json.loads(text) + + owner_frames = np.array(game["frames"])[:, :, :, 0][:, np.newaxis, :, :] + strength_frames = np.array(game["frames"])[:, :, :, 1][:, np.newaxis, :, :] + production_frames = np.repeat(np.array(game["productions"])[np.newaxis, np.newaxis, :, :], len(owner_frames), + axis=0) + moves = np.array(game['moves']) + + game_states = np.concatenate(([owner_frames, strength_frames, production_frames]), axis=1) + return game_states, moves + + +def game_states_from_file(): + path_to_hlt = os.path.abspath(os.path.join(os.path.dirname(__file__), '../visualize/hlt/')) # 'visualize/hlt/' + + hlt_files = [hlt_file for hlt_file in os.listdir(path_to_hlt) if hlt_file not in ['.DS_Store', 'README.md']] + filepath = hlt_files[0] + return text_to_game(open(path_to_hlt + '/' + filepath).read()) diff --git a/train/__init__.py b/train/__init__.py index e69de29..7d80e79 100644 --- a/train/__init__.py +++ b/train/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +""" +Contributors: + - Louis Rémus +""" diff --git a/train/experience.py b/train/experience.py index 9b0c485..8f83d25 100644 --- a/train/experience.py +++ b/train/experience.py @@ -1,38 +1,71 @@ +""" +Experience class definition +""" +import os + import numpy as np -from train.reward import allRewards, rawRewards +from train.reward.util import production_increments_function class Experience: - def __init__(self): + """ + Experience class to store moves, rewards and metric values + """ + + def __init__(self, max_size=10000, min_size=5000): + self.max_size = max_size + self.min_size = min_size self.moves = np.array([]) self.rewards = np.array([]) self.metric = np.array([]) - def add_episode(self, game_states, moves): - production_increments = np.sum(np.sum(rawRewards(game_states), axis=2), axis=1) - self.metric = np.append(self.metric, production_increments.dot(np.linspace(2.0, 1.0, num=len(game_states) - 1))) + def add_episode(self, game_states, all_states, all_moves, all_rewards): + pass def batch(self, size): pass + def compute_metric(self, game_states): + production_increments = production_increments_function(game_states) + self.metric = np.append(self.metric, production_increments.dot(np.linspace(2.0, 1.0, num=len(game_states) - 1))) + def save_metric(self, name): np.save(name, self.metric) class ExperienceVanilla(Experience): - def __init__(self): + """ + Stores states in addition to the inherited attributes of Experience + """ + + def __init__(self, s_size, name): super(ExperienceVanilla, self).__init__() - self.states = np.array([]).reshape(0, 27) + self.s_size = s_size + self.states = np.array([]).reshape(0, s_size) + try: + self.metric 
= np.load(os.path.abspath( + os.path.join(os.path.dirname(__file__), '..')) + + '/public/models/variables/' + name + '/' + + name + '.npy') + except FileNotFoundError: + print("Metric file not found") + self.metric = np.array([]) - def add_episode(self, game_states, moves): - super(ExperienceVanilla, self).add_episode(game_states, moves) - all_states, all_moves, all_rewards = allRewards(game_states, moves) + def add_episode(self, game_states, all_states, all_moves, all_rewards): + self.compute_metric(game_states) - self.states = np.concatenate((self.states, all_states.reshape(-1, 27)), axis=0) + self.states = np.concatenate((self.states, all_states.reshape(-1, self.s_size)), axis=0) self.moves = np.concatenate((self.moves, all_moves)) self.rewards = np.concatenate((self.rewards, all_rewards)) + if len(self.states) >= self.max_size: + self.resize() + + def resize(self): + self.states = self.states[self.min_size:] + self.moves = self.moves[self.min_size:] + self.rewards = self.rewards[self.min_size:] def batch(self, size=128): indices = np.random.randint(len(self.states), size=min(int(len(self.states) / 2), size)) diff --git a/train/main.py b/train/main.py index 3b7e621..b63658a 100644 --- a/train/main.py +++ b/train/main.py @@ -1,61 +1,54 @@ -import multiprocessing +"""This main.py file runs the training.""" +import os import sys import threading import tensorflow as tf +from tensorflow.python.framework.errors_impl import InvalidArgumentError -from public.models.agent.vanillaAgent import VanillaAgent -from train.experience import ExperienceVanilla -from train.worker import Worker +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +try: + from public.models.strategy.TrainedStrategy import TrainedStrategy + from train.worker import Worker +except: + raise port = int(sys.argv[1]) if len(sys.argv) > 1 else 2000 -tf.reset_default_graph() # Clear the Tensorflow graph. 
+strategy = TrainedStrategy(tf_scope='global') -with tf.device("/cpu:0"): - lr = 1e-3; - s_size = 9 * 3; - a_size = 5; - h_size = 50 +num_workers = 1 +n_simultations = 10 - with tf.variable_scope('global'): - master_experience = ExperienceVanilla() - master_agent = VanillaAgent(master_experience, lr, s_size, a_size, h_size) - - num_workers = 1 # multiprocessing.cpu_count()# (2) Maybe set max number of workers / number of available CPU threads - n_simultations = 15 - - workers = [] - if num_workers > 1: - for i in range(num_workers): - with tf.variable_scope('worker_' + str(i)): - experience = ExperienceVanilla() - agent = VanillaAgent(experience, lr, s_size, a_size, h_size) - workers.append(Worker(port, i, agent)) - else: - workers.append(Worker(port, 0, master_agent)) - # We need only to save the global - global_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='global') - saver = tf.train.Saver(global_variables) - init = tf.global_variables_initializer() +workers = [] +if num_workers > 1: + for i in range(num_workers): + workers.append(Worker(port, i, TrainedStrategy(tf_scope='worker_' + str(i)))) +else: + workers.append(Worker(port, 0, strategy)) +# We need only to save the global +global_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='global') +saver = tf.train.Saver(global_variables) +init = tf.global_variables_initializer() # Launch the tensorflow graph with tf.Session() as sess: sess.run(init) try: - saver.restore(sess, './public/models/' + master_agent.name) - except Exception: + saver.restore(sess, os.path.abspath( + os.path.dirname(__file__)) + '/../public/models/variables/' + strategy.name + '/' + strategy.name) + except InvalidArgumentError: print("Model not found - initiating new one") coord = tf.train.Coordinator() worker_threads = [] - print("I'm the main thread running on CPU #%s" % multiprocessing.current_process().name) + print("I'm the main thread running on CPU") - if (num_workers == 1): - workers[0].work(sess, coord, saver, n_simultations) + if num_workers == 1: + workers[0].work(sess, saver, n_simultations) else: for worker in workers: - worker_work = lambda: worker.work(sess, coord, saver, n_simultations) + worker_work = lambda worker=worker: worker.work(sess, saver, n_simultations) t = threading.Thread(target=(worker_work)) # Process instead of threading.Thread multiprocessing.Process t.start() worker_threads.append(t) diff --git a/train/reward.py b/train/reward.py deleted file mode 100644 index 99a1ad1..0000000 --- a/train/reward.py +++ /dev/null @@ -1,108 +0,0 @@ -import numpy as np - -gamma = 0.8 -from public.hlt import NORTH, EAST, SOUTH, WEST, Move - -STRENGTH_SCALE = 255 -PRODUCTION_SCALE = 10 - - -def getGameState(game_map, myID): - game_state = np.reshape( - [[(square.owner == myID) + 0, square.strength, square.production] for square in game_map], - [game_map.width, game_map.height, 3]) - return np.swapaxes(np.swapaxes(game_state, 2, 0), 1, 2) * ( - 1 / np.array([1, STRENGTH_SCALE, PRODUCTION_SCALE])[:, np.newaxis, np.newaxis]) - - -def getGameProd(game_state): - return PRODUCTION_SCALE * np.sum(game_state[0] * game_state[2]) - - -def getStrength(game_state): - return game_state[1][1][ - 1] * STRENGTH_SCALE # np.sum([square.strength for square in game_map if square.owner == myID]) - - -def discount_rewards(r): - """ take 1D float array of rewards and compute discounted reward """ - discounted_r = np.zeros_like(r, dtype=np.float64) - running_add = 0 - for t in reversed(range(0, r.size)): - running_add = running_add * 
gamma + r[t] - discounted_r[t] = running_add - return discounted_r - - -def localStateFromGlobal(game_state, x, y, size=1): - # TODO: for now we still take a square, but a more complex shape could be better. - return np.take(np.take(game_state, range(y - size, y + size + 1), axis=1, mode='wrap'), - range(x - size, x + size + 1), axis=2, mode='wrap') - - -def rawRewards(game_states): - return np.array([game_states[i + 1][0] * game_states[i + 1][2] - game_states[i][0] * game_states[i][2] for i in - range(len(game_states) - 1)]) - - -def discountedReward(next_reward, move_before, discount_factor=1.0): - reward = np.zeros_like(next_reward) - for y in range(len(reward)): - for x in range(len(reward[0])): - d = move_before[y][x] - if d != -1: - dy = (-1 if d == NORTH else 1) if (d == SOUTH or d == NORTH) else 0 - dx = (-1 if d == WEST else 1) if (d == WEST or d == EAST) else 0 - reward[y][x] = discount_factor * np.take(np.take(next_reward, x + dx, axis=1, mode='wrap'), y + dy, - axis=0, mode='wrap') - return reward - - -def discountedRewards(raw_rewards, moves): - discounted_rewards = np.zeros_like(raw_rewards, dtype=np.float64) - running_reward = np.zeros_like(raw_rewards[0]) - for t in reversed(range(0, len(raw_rewards))): - running_reward = discountedReward(running_reward, moves[t], discount_factor=0.8) + discountedReward( - raw_rewards[t], moves[t]) - discounted_rewards[t] = running_reward - return discounted_rewards - - -def individualStatesAndRewards(game_state, move, discounted_reward): - states = [] - moves = [] - rewards = [] - for y in range(len(game_state[0])): - for x in range(len(game_state[0][0])): - if (game_state[0][y][x] == 1): - states += [localStateFromGlobal(game_state, x, y)] - moves += [move[y][x]] - rewards += [discounted_reward[y][x]] - return states, moves, rewards - - -def allIndividualStatesAndRewards(game_states, moves, discounted_rewards): - all_states = [] - all_moves = [] - all_rewards = [] - for game_state, move, discounted_reward in zip(game_states, moves, discounted_rewards): - states_, moves_, rewards_ = individualStatesAndRewards(game_state, move, discounted_reward) - all_states += states_ - all_moves += moves_ - all_rewards += rewards_ - return np.array(all_states), np.array(all_moves), np.array(all_rewards) - - -def allRewards(game_states, moves): - # game_states n+1, moves n - discounted_rewards = discountedRewards(rawRewards(game_states), moves) - return allIndividualStatesAndRewards(game_states[:-1], moves, discounted_rewards) - - -def formatMoves(game_map, moves): - moves_to_send = [] - for y in range(len(game_map.contents)): - for x in range(len(game_map.contents[0])): - if moves[y][x] != -1: - moves_to_send += [Move(game_map.contents[y][x], moves[y][x])] - return moves_to_send diff --git a/public/__init__.py b/train/reward/__init__.py similarity index 100% rename from public/__init__.py rename to train/reward/__init__.py diff --git a/train/reward/reward.py b/train/reward/reward.py new file mode 100644 index 0000000..ee09c7f --- /dev/null +++ b/train/reward/reward.py @@ -0,0 +1,76 @@ +"""The reward.py file to compute the reward""" +import numpy as np + +from public.hlt import NORTH, EAST, SOUTH, WEST +from train.reward.util import get_prod + + +class Reward: + """The reward class""" + + def __init__(self, state, discount_factor=0.6): + self.discount_factor = discount_factor + self.state = state + + def raw_rewards_function(self, game_states): + return np.array( + [0.1 * np.power(get_prod(game_states[i + 1]) - get_prod(game_states[i]), 2) + for i in 
range(len(game_states) - 1)]) + + def individual_states_and_rewards(self, game_state, move, discounted_reward): + """Self-explanatory""" + states = [] + moves = [] + rewards = [] + + for (y, x), k in np.ndenumerate(game_state[0]): + if k == 1 and move[y][x] != -1: + states += [self.state.get_local_and_normalize(game_state, x, y)] + moves += [move[y][x]] + rewards += [discounted_reward[y][x]] + return states, moves, rewards + + def discounted_reward_function(self, next_reward, move_before, strength_before, discount_factor=1.0): + """Self-explanatory""" + reward = np.zeros_like(next_reward) + + def take_value(matrix, x, y): + return np.take(np.take(matrix, x, axis=1, mode='wrap'), y, axis=0, mode='wrap') + + for (y, x), d in np.ndenumerate(move_before): + if d != -1: + dy = (-1 if d == NORTH else 1) if (d == SOUTH or d == NORTH) else 0 + dx = (-1 if d == WEST else 1) if (d == WEST or d == EAST) else 0 + reward[y][x] = discount_factor * take_value(next_reward, x + dx, y + dy) \ + if strength_before[y][x] >= take_value(strength_before, x + dx, y + dy) \ + else 0 + return reward + + def discounted_rewards_function(self, game_states, moves): + """Self-explanatory""" + raw_rewards = self.raw_rewards_function(game_states) + discounted_rewards = np.zeros_like(raw_rewards, dtype=np.float64) + running_reward = np.zeros_like(raw_rewards[0], dtype=np.float64) + for t, (raw_reward, move, game_state) in reversed(list(enumerate(zip(raw_rewards, moves, game_states)))): + running_reward = self.discounted_reward_function(running_reward, move, game_state[1], + discount_factor=self.discount_factor) + \ + self.discounted_reward_function(raw_reward, move, game_state[1]) + discounted_rewards[t] = running_reward - 0.01 + return discounted_rewards + + def all_individual_states_and_rewards(self, game_states, moves, discounted_rewards): + """Self-explanatory""" + all_states = [] + all_moves = [] + all_rewards = [] + for game_state, move, discounted_reward in zip(game_states, moves, discounted_rewards): + states_, moves_, rewards_ = self.individual_states_and_rewards( + game_state, move, discounted_reward) + all_states += states_ + all_moves += moves_ + all_rewards += rewards_ + return np.array(all_states), np.array(all_moves), np.array(all_rewards) + + def all_rewards_function(self, game_states, moves): + discounted_rewards = self.discounted_rewards_function(game_states, moves) + return self.all_individual_states_and_rewards(game_states[:-1], moves, discounted_rewards) diff --git a/train/reward/util.py b/train/reward/util.py new file mode 100644 index 0000000..de9046a --- /dev/null +++ b/train/reward/util.py @@ -0,0 +1,37 @@ +"""Useful for computing the rewards""" +import numpy as np + + +def discount_rewards(r, gamma=0.8): + """ take 1D float array of rewards and compute discounted reward """ + discounted_r = np.zeros_like(r, dtype=np.float64) + running_add = 0 + for t in reversed(range(0, r.size)): + running_add = running_add * gamma + r[t] + discounted_r[t] = running_add + return discounted_r + + +def get_total_prod(game_state): + return np.sum(game_state[0] * game_state[2]) + + +def get_prod(game_state): + return game_state[0] * game_state[2] + + +def get_total_strength(game_state): + return np.sum(game_state[0] * game_state[1]) + + +def get_strength(game_state): + return game_state[0] * game_state[1] + + +def get_total_number(game_state): + return np.sum(game_state[0]) + + +def production_increments_function(game_states): + return np.array([get_total_prod(game_states[i + 1]) - get_total_prod(game_states[i]) + 
for i in range(len(game_states) - 1)]) diff --git a/train/worker.py b/train/worker.py index edaafec..0d5ef73 100644 --- a/train/worker.py +++ b/train/worker.py @@ -1,11 +1,11 @@ +"""The worker class for training and parallel operations""" import multiprocessing -import subprocess import time import tensorflow as tf from networking.hlt_networking import HLT -from train.reward import getGameState, formatMoves +from networking.start_game import start_game def update_target_graph(from_scope, to_scope): @@ -18,48 +18,60 @@ def update_target_graph(from_scope, to_scope): return op_holder -class Worker(): - def __init__(self, port, number, agent): +class Worker: + """ + The Worker class for training. Each worker has an individual port, number, and agent. + Each of them work with the global session, and use the global saver. + """ + + def __init__(self, port, number, strategy): self.name = 'worker_' + str(number) self.number = number self.port = port + number def worker(): - subprocess.call(['./networking/runGameDebugConfig.sh', str(self.port)]) # runSimulation + start_game(self.port, quiet=True, max_game=-1, width=25, height=25, max_turn=50, + max_strength=60) # Infinite games self.p = multiprocessing.Process(target=worker) self.p.start() time.sleep(1) - self.hlt = HLT(port=self.port) - self.agent = agent + self.hlt = HLT(port=self.port) # Launching the pipe operation + self.strategy = strategy self.update_local_ops = update_target_graph('global', self.name) - def work(self, sess, coord, saver, n_simultations): - print("Starting worker " + str(self.number)) - + def work(self, sess, saver, n_simultations): + """ + Using the pipe operation launched at initialization, + the worker works `n_simultations` games to train the + agent + :param sess: The global session + :param n_simultations: Number of max simulations to run. + Afterwards the process is stopped. + :return: + """ + self.strategy.init_session(sess, saver) with sess.as_default(), sess.graph.as_default(): for i in range(n_simultations): # while not coord.should_stop(): - if (i % 10 == 1 and self.number == 0): + if i % 10 == 1 and self.number == 0: print("Simulation: " + str(i)) # self.port) sess.run(self.update_local_ops) # GET THE WORK DONE FROM OTHER - myID, game_map = self.hlt.get_init() + print("This is one simulation") + my_id, game_map = self.hlt.get_init() self.hlt.send_init("MyPythonBot") - - moves = [] - game_states = [] - while (self.hlt.get_string() == 'Get map and play!'): + self.strategy.set_id(my_id) + while True: + get_map_and_play = self.hlt.get_string() + if get_map_and_play != 'Get map and play!': + print(get_map_and_play) + break game_map.get_frame(self.hlt.get_string()) - game_states += [getGameState(game_map, myID)] - moves += [self.agent.choose_actions(sess, game_states[-1])] - self.hlt.send_frame(formatMoves(game_map, moves[-1])) - - self.agent.experience.add_episode(game_states, moves) - self.agent.update_agent(sess) + self.hlt.send_frame(self.strategy.compute_moves(game_map, train=True)) + self.strategy.add_episode() + self.strategy.update_agent() if self.number == 0: - saver.save(sess, './public/models/' + self.agent.name) - self.agent.experience.save_metric('./public/models/' + self.agent.name) - + self.strategy.save() self.p.terminate() diff --git a/visualize/hlt/README.md b/visualize/hlt/README.md new file mode 100644 index 0000000..b0c064f --- /dev/null +++ b/visualize/hlt/README.md @@ -0,0 +1,3 @@ +# HLT files + +Here should automatically go all the *.hlt files. 
\ No newline at end of file diff --git a/visualize/static/localVisualizer.js b/visualize/static/localVisualizer.js new file mode 100755 index 0000000..534aec4 --- /dev/null +++ b/visualize/static/localVisualizer.js @@ -0,0 +1,65 @@ +$(function () { + var $dropZone = $("html"); + var $filePicker = $("#filePicker"); + function handleFiles(files) { + // only use the first file. + file = files[0]; + console.log(file) + var reader = new FileReader(); + + reader.onload = (function(filename) { // finished reading file data. + return function(e2) { + $("#displayArea").empty(); + var fsHeight = $("#fileSelect").outerHeight(); + showGame(textToGame(e2.target.result, filename), $("#displayArea"), null, -fsHeight, true, false, true); + }; + })(file.name); + reader.readAsText(file); // start reading the file data. + } + + $dropZone.on('dragover', function(e) { + e.stopPropagation(); + e.preventDefault(); + }); + $dropZone.on('drop', function(e) { + e.stopPropagation(); + e.preventDefault(); + var files = e.originalEvent.dataTransfer.files; // Array of all files + handleFiles(files) + }); + $filePicker.on('change', function(e) { + var files = e.target.files + handleFiles(files) + }); + + $("li").on("click",function() { + $.ajax({ + type: "GET", + url: this.id, + success: function(text) { + $("#displayArea").empty(); + var fsHeight = $("#fileSelect").outerHeight(); + showGame(textToGame(text, "OK"), $("#displayArea"), null, -fsHeight, true, false, true); + // `text` is the file text + }, + error: function() { + // An error occurred + } + }); + }); + if($("li").length>=1){ + $.ajax({ + type: "GET", + url: $("li")[$("li").length-1].id, + success: function(text) { + $("#displayArea").empty(); + var fsHeight = $("#fileSelect").outerHeight(); + showGame(textToGame(text, "OK"), $("#displayArea"), null, -fsHeight, true, false, true); + // `text` is the file text + }, + error: function() { + // An error occurred + } + }); + } +}) diff --git a/visualize/static/parsereplay.js b/visualize/static/parsereplay.js new file mode 100755 index 0000000..9a114a7 --- /dev/null +++ b/visualize/static/parsereplay.js @@ -0,0 +1,380 @@ +function processFrame(game, frameNum) { + var checkSim = false; + var gameMap = game.frames[frameNum]; + if(checkSim) { + gameMap = _.cloneDeep(game.frames[frameNum]); + } + var moves = game.moves[frameNum]; + var productions = game.productions; + var width = game.width; + var height = game.height; + var numPlayers = game.num_players; + + var STILL = 0; + var NORTH = 1; + var EAST = 2; + var SOUTH = 3; + var WEST = 4; + + var pieces = []; + var stats = []; + + var p, q, y, x; + + function getLocation(loc, direction) { + if (direction === STILL) { + // nothing + } else if (direction === NORTH) { + loc.y -= 1; + } else if (direction === EAST) { + loc.x += 1; + } else if (direction === SOUTH) { + loc.y += 1; + } else if (direction === WEST) { + loc.x -= 1; + } + + if (loc.x < 0) { + loc.x = width - 1; + } else { + loc.x %= width; + } + + if (loc.y < 0) { + loc.y = height - 1; + } else { + loc.y %= height; + } + } + + for (p = 0; p < numPlayers; p++) { + pieces[p] = []; + stats[p] = { + actualProduction: 0, + playerDamageDealt: 0, + environmentDamageDealt: 0, + damageTaken: 0, + capLosses: 0, + overkillDamage: 0, + }; + for (y = 0; y < height; y++) { + pieces[p][y] = []; + } + } + + for (y = 0; y < height; y++) { + for (x = 0; x < width; x++) { + var direction = moves[y][x]; + var cell = gameMap[y][x]; + var player = gameMap[y][x].owner - 1; + var production = productions[y][x]; + + if 
(gameMap[y][x].owner == 0) continue + + if (direction === STILL) { + cell = { owner: gameMap[y][x].owner, strength: gameMap[y][x].strength }; + if (cell.strength + production <= 255) { + stats[player].actualProduction += production; + cell.strength += production; + } else { + stats[player].actualProduction += cell.strength - 255; + stats[player].capLosses += cell.strength + production - 255; + cell.strength = 255; + } + } + + var newLoc = { x: x, y: y }; + getLocation(newLoc, direction); + if (!_.isUndefined(pieces[player][newLoc.y][newLoc.x])) { + if (pieces[player][newLoc.y][newLoc.x] + cell.strength <= 255) { + pieces[player][newLoc.y][newLoc.x] += cell.strength; + } else { + stats[player].capLosses += pieces[player][newLoc.y][newLoc.x] + cell.strength - 255; + pieces[player][newLoc.y][newLoc.x] = 255; + } + } else { + pieces[player][newLoc.y][newLoc.x] = cell.strength; + } + + // add in a new piece with a strength of 0 if necessary + if (_.isUndefined(pieces[player][y][x])) { + pieces[player][y][x] = 0; + } + + // erase from the game map so that the player can't make another move with the same piece + // On second thought, trust that the original game took care of that. + if(checkSim) { + gameMap[y][x] = { owner: 0, strength: 0 }; + } + } + } + + var toInjure = []; + var injureMap = []; + + for (p = 0; p < numPlayers; p++) { + toInjure[p] = []; + for (y = 0; y < height; y++) { + toInjure[p][y] = []; + } + } + + for (y = 0; y < height; y++) { + injureMap[y] = []; + for (x = 0; x < width; x++) { + injureMap[y][x] = 0; + } + } + + for (y = 0; y < height; y++) { + for (x = 0; x < width; x++) { + for (p = 0; p < numPlayers; p++) { + // if player p has a piece at these coords + if (!_.isUndefined(pieces[p][y][x])) { + var damageDone = 0; + // look for other players with pieces here + for (q = 0; q < numPlayers; q++) { + // exclude the same player + if (p !== q) { + for (var dir = STILL; dir <= WEST; dir++) { + // check STILL square + var loc = { x: x, y: y }; + getLocation(loc, dir); + + // if the other player has a piece here + if (!_.isUndefined(pieces[q][loc.y][loc.x])) { + // add player p's damage + if (!_.isUndefined(toInjure[q][loc.y][loc.x])) { + toInjure[q][loc.y][loc.x] += pieces[p][y][x]; + stats[p].playerDamageDealt += pieces[p][y][x]; + damageDone += Math.min(pieces[p][y][x], pieces[q][loc.y][loc.x]); + } else { + toInjure[q][loc.y][loc.x] = pieces[p][y][x]; + stats[p].playerDamageDealt += pieces[p][y][x]; + damageDone += Math.min(pieces[p][y][x], pieces[q][loc.y][loc.x]); + } + } + } + } + } + + // if the environment can do damage back + if (gameMap[y][x].owner == 0 && gameMap[y][x].strength > 0) { + if (!_.isUndefined(toInjure[p][y][x])) { + toInjure[p][y][x] += gameMap[y][x].strength; + } else { + toInjure[p][y][x] = gameMap[y][x].strength; + } + // and apply damage to the environment + injureMap[y][x] += pieces[p][y][x]; + damageDone += Math.min(pieces[p][y][x], gameMap[y][x].strength); + stats[p].environmentDamageDealt += Math.min(pieces[p][y][x], gameMap[y][x].strength); + } + + if (damageDone > pieces[p][y][x]) { + stats[p].overkillDamage += damageDone - pieces[p][y][x]; + } + } + } + } + } + + // injure and/or delete pieces. Note >= rather than > indicates that pieces with a strength of 0 are killed. 
+ for (p = 0; p < numPlayers; p++) { + for (y = 0; y < height; y++) { + for (x = 0; x < width; x++) { + if (!_.isUndefined(toInjure[p][y][x])) { + if (toInjure[p][y][x] >= pieces[p][y][x]) { + stats[p].damageTaken += pieces[p][y][x]; + pieces[p][y][x] = undefined; + } else { + stats[p].damageTaken += toInjure[p][y][x]; + pieces[p][y][x] -= toInjure[p][y][x]; + } + } + } + } + } + + if(checkSim) { + // apply damage to map pieces + for (y = 0; y < height; y++) { + for (x = 0; x < width; x++) { + if (gameMap[y][x].strength < injureMap[y][x]) { + gameMap[y][x].strength = 0; + } else { + gameMap[y][x].strength -= injureMap[y][x] + } + gameMap[y][x].owner = 0; + } + } + + // add pieces back into the map + for (p = 0; p < numPlayers; p++) { + for (y = 0; y < height; y++) { + for (x = 0; x < width; x++) { + if (!_.isUndefined(pieces[p][y][x])) { + gameMap[y][x].owner = p + 1; + gameMap[y][x].strength = pieces[p][y][x]; + } + } + } + } + + if (frameNum + 1 < gameMap.num_frames - 1) { + if (!_.isEqual(gameMap, game.frames[frameNum + 1])) { + throw new Error("Evaluated frame did not match actual game map for frame number " + frameNum); + } + } + } + + return stats; +} + +function textToGame(text, seed) { + var startParse = new Date(); + console.log("Starting parse at", startParse); + var game = JSON.parse(text) + + if (game.version != 11) { + alert("Invalid version number: " + json_game.version); + } + + //Adds determinism (when used with https://github.com/davidbau/seedrandom) to color scramble. + console.log(seed); + + //Hardcoding colors: + var colors = [ + '0x04E6F2', + '0x424C8F',//, + '0xF577F2', + '0x23D1DE', + '0xB11243', + '0xFF704B', + '0x00B553', + '0xF8EC31' + ]; + + var x, i; + + game.players = [] + game.players.push({name: 'NULL', color: "0x888888"}); + for(i = 0; i < game.num_players; i++) { + game.players.push({name: game.player_names[i], color: colors[i] }); + console.log(game.players[game.players.length - 1].color); + } + delete game.player_names; + + console.log(game.players); + + var maxProd = 0; + for(var a = 0; a < game.height; a++) { + for(var b = 0; b < game.width; b++) { + if(game.productions[a][b] > maxProd) maxProd = game.productions[a][b]; + } + } + + game.productionNormals = [] + for(var a = 0; a < game.height; a++) { + var row = [] + for(var b = 0; b < game.width; b++) { + row.push(game.productions[a][b] / maxProd); + } + game.productionNormals.push(row) + } + + for(var a = 0; a < game.num_frames; a++) { + for(var b = 0; b < game.height; b++) { + for(var c = 0; c < game.width; c++) { + var array = game.frames[a][b][c]; + game.frames[a][b][c] = { owner: array[0], strength: array[1] }; + } + } + } + + var stats = []; + for(var a = 0; a < game.num_frames - 1; a++) { + stats[a+1] = processFrame(game, a); + } + + //Get game statistics: + for(var a = 1; a <= game.num_players; a++) { + game.players[a].territories = []; + game.players[a].productions = []; + game.players[a].strengths = []; + game.players[a].actualProduction = []; + game.players[a].playerDamageDealt = []; + game.players[a].environmentDamageDealt = []; + game.players[a].damageTaken = []; + game.players[a].capLosses = []; + + for(var b = 0; b < game.num_frames; b++) { + var ter = 0, prod = 0, str = 0; + for(var c = 0; c < game.height; c++) for(var d = 0; d < game.width; d++) { + if(game.frames[b][c][d].owner == a) { + ter++; + prod += game.productions[c][d]; + str += game.frames[b][c][d].strength; + } + } + game.players[a].territories.push(ter); + game.players[a].productions.push(prod); + 
game.players[a].strengths.push(str); + if (b == 0) { + game.players[a].actualProduction.push(0); + game.players[a].environmentDamageDealt.push(0); + game.players[a].damageTaken.push(0); + game.players[a].playerDamageDealt.push(0); + game.players[a].capLosses.push(0); + } + else { + game.players[a].actualProduction.push(game.players[a].actualProduction[b - 1] + stats[b][a - 1].actualProduction); + game.players[a].environmentDamageDealt.push(game.players[a].environmentDamageDealt[b - 1] + stats[b][a - 1].environmentDamageDealt); + game.players[a].damageTaken.push(game.players[a].damageTaken[b - 1] + stats[b][a - 1].damageTaken - stats[b][a - 1].environmentDamageDealt); + game.players[a].playerDamageDealt.push(game.players[a].playerDamageDealt[b - 1] + stats[b][a - 1].overkillDamage); + game.players[a].capLosses.push(game.players[a].capLosses[b - 1] + stats[b][a - 1].capLosses); + } + } + } + + //Normalize game statistics for display + var maxPlayerTer = 0, maxPlayerProd = 0, maxPlayerStr = 0, maxActProd = 0; + var maxPlrDmgDlt = 0, maxEnvDmgDlt = 0, maxDmgTkn = 0, maxCapLoss = 0; + for(var a = 1; a <= game.num_players; a++) { + for(var b = 0; b < game.num_frames; b++) { + if(game.players[a].territories[b] > maxPlayerTer) maxPlayerTer = game.players[a].territories[b]; + if(game.players[a].productions[b] > maxPlayerProd) maxPlayerProd = game.players[a].productions[b]; + if(game.players[a].strengths[b] > maxPlayerStr) maxPlayerStr = game.players[a].strengths[b]; + if(game.players[a].actualProduction[b] > maxActProd) maxActProd = game.players[a].actualProduction[b]; + if(game.players[a].playerDamageDealt[b] > maxPlrDmgDlt) maxPlrDmgDlt = game.players[a].playerDamageDealt[b]; + if(game.players[a].environmentDamageDealt[b] > maxEnvDmgDlt) maxEnvDmgDlt = game.players[a].environmentDamageDealt[b]; + if(game.players[a].damageTaken[b] > maxDmgTkn) maxDmgTkn = game.players[a].damageTaken[b]; + if(game.players[a].capLosses[b] > maxCapLoss) maxCapLoss = game.players[a].capLosses[b]; + } + } + for(var a = 1; a <= game.num_players; a++) { + game.players[a].normTers = []; + game.players[a].normProds = []; + game.players[a].normStrs = []; + game.players[a].normActProd = []; + game.players[a].normPlrDmgDlt = []; + game.players[a].normEnvDmgDlt = []; + game.players[a].normDmgTkn = []; + game.players[a].normCapLoss = []; + for(var b = 0; b < game.num_frames; b++) { + game.players[a].normTers.push(game.players[a].territories[b] / maxPlayerTer); + game.players[a].normProds.push(game.players[a].productions[b] / maxPlayerProd); + game.players[a].normStrs.push(game.players[a].strengths[b] / maxPlayerStr); + game.players[a].normActProd.push(game.players[a].actualProduction[b] / maxActProd); + game.players[a].normPlrDmgDlt.push(game.players[a].playerDamageDealt[b] / maxPlrDmgDlt); + game.players[a].normEnvDmgDlt.push(game.players[a].environmentDamageDealt[b] / maxEnvDmgDlt); + game.players[a].normDmgTkn.push(game.players[a].damageTaken[b] / maxDmgTkn); + game.players[a].normCapLoss.push(game.players[a].capLosses[b] / maxCapLoss); + } + } + + var endParse = new Date(); + console.log("Finished parse at", endParse); + console.log("Parse took", (endParse - startParse) / 1000); + return game +} diff --git a/visualize/static/visualizer.js b/visualize/static/visualizer.js new file mode 100755 index 0000000..bbca207 --- /dev/null +++ b/visualize/static/visualizer.js @@ -0,0 +1,705 @@ +var renderer; + +function initPixi() { + //Create the root of the scene: stage: + stage = new PIXI.Container(); + + // Initialize the pixi 
graphics class for the map: + mapGraphics = new PIXI.Graphics(); + + // Initialize the pixi graphics class for the graphs: + graphGraphics = new PIXI.Graphics(); + + // Initialize the text container; + prodContainer = new PIXI.Container(); + + possessContainer = new PIXI.Container(); + + // Initialize the text container; + strengthContainer = new PIXI.Container(); + + // Initialize the text container; + rewardContainer = new PIXI.Container(); + + // Initialize the text container; + policyContainer = new PIXI.Container(); + + renderer = PIXI.autoDetectRenderer(0, 0, { backgroundColor: 0x000000, antialias: true, transparent: true }); +} + +function showGame(game, $container, maxWidth, maxHeight, showmovement, isminimal, offline, seconds) { + if(renderer == null) initPixi(); + + $container.empty(); + + if(!isminimal) { + var $row = $("
"); + $row.append($("
")); + $row.append($("
").append($("

"+game.players.slice(1, game.num_players+1).map(function(p) { + var nameComponents = p.name.split(" "); + var name = nameComponents.slice(0, nameComponents.length-1).join(" ").trim(); + console.log(name); + var user = offline ? null : getUser(null, name); + if(user) { + return ""+p.name+"" + } else { + return ""+p.name+"" + } + }).join(" vs ")+"

"))); + $container.append($row); + } + $container.append(renderer.view); + $container.append($("
")); + + var showExtended = false; + var frame = 0; + var transit = 0; + var framespersec = seconds == null ? 3 : game.num_frames / seconds; + var shouldplay = true; + var xOffset = 0, yOffset = 0; + var zoom = 8; + if(game.num_frames / zoom < 3) zoom = game.num_frames / 3; + if(zoom < 1) zoom = 1; + function centerStartPositions() { + var minX = game.width, maxX = 0, minY = game.height, maxY = 0; + // find the initial bounding box of all players + for(var x=0; x < game.width; x++) { + for(var y=0; y < game.height; y++) { + if(game.frames[0][y][x].owner != 0) { + if(x < minX) { minX = x; } + if(x > maxX) { maxX = x; } + if(y < minY) { minY = y; } + if(y > maxY) { maxY = y; } + } + } + } + // offset by half the difference from the edges rounded toward zero + xOffset = ((game.width - 1 - maxX - minX) / 2) | 0; + yOffset = ((game.height - 1 - maxY - minY) / 2) | 0; + } + centerStartPositions(); + + discountedRewards = undefined + $.ajax({ + type: "POST", + url: '/post_discounted_rewards', + data: JSON.stringify(game), + success: function(data) {discountedRewards = JSON.parse(data)['discounted_rewards']}, + contentType: "application/json; charset=utf-8", + //dataType: "json" + }) + + policies = undefined + $.ajax({ + type: "POST", + url: '/post_policies', + data: JSON.stringify(game), + success: function(data) {policies = JSON.parse(data)['policies']}, + contentType: "application/json; charset=utf-8", + //dataType: "json" + }) + + window.onresize = function() { + var allowedWidth = (maxWidth == null ? $container.width() : maxWidth); + var allowedHeight = window.innerHeight - (25 + $("canvas").offset().top); + if(maxHeight != null) { + if(maxHeight > 0) { + allowedHeight = maxHeight - ($("canvas").offset().top - $container.offset().top); + } else { + // A negative maxHeight signifies extra space to leave for + // other page elements following the visualizer + allowedHeight += maxHeight; + } + } + + console.log(window.innerHeight) + console.log(allowedHeight) + var definingDimension = Math.min(allowedWidth, allowedHeight); + if(isminimal) { + if(allowedWidth < allowedHeight) { + sw = allowedWidth, sh = allowedWidth; + } else { + sw = allowedHeight, sh = allowedHeight; + } + mw = sh, mh = sh; + renderer.resize(sw, sh); + rw = mw / game.width, rh = mh / game.height; //Sizes of rectangles for rendering tiles. + } + else { + var splits = showExtended ? 5 : 4; + if(allowedWidth < allowedHeight*splits/3) { + sw = allowedWidth, sh = allowedWidth*3/splits; + } else { + sw = allowedHeight*splits/3, sh = allowedHeight; + } + mw = sh, mh = sh; + renderer.resize(sw, sh); + rw = mw / game.width, rh = mh / game.height; //Sizes of rectangles for rendering tiles. 
+ if(showExtended) { + LEFT_GRAPH_LEFT = mw * 1.025, LEFT_GRAPH_RIGHT = LEFT_GRAPH_LEFT + sw * 0.17; + } else { + LEFT_GRAPH_LEFT = mw * 1.025, LEFT_GRAPH_RIGHT = sw - 1; + } + RIGHT_GRAPH_LEFT = mw * 1.35, RIGHT_GRAPH_RIGHT = RIGHT_GRAPH_LEFT + sw * 0.17; + + if(showExtended) { + TER_TOP = sh * 0.09, TER_BTM = sh * 0.29; + PROD_TOP = sh * 0.33, PROD_BTM = sh * 0.53; + STR_TOP = sh * 0.57, STR_BTM = sh * 0.77; + } else { + TER_TOP = sh * 0.09, TER_BTM = sh * 0.36; + PROD_TOP = sh * 0.41, PROD_BTM = sh * 0.675; + STR_TOP = sh * 0.725, STR_BTM = sh * 0.99; + } + + ENV_DMG_TOP = sh * 0.09, ENV_DMG_BTM = sh * 0.29; + ACT_PROD_TOP = sh * 0.33, ACT_PROD_BTM = sh * 0.53; + CAP_LOSS_TOP = sh * 0.57, CAP_LOSS_BTM = sh * 0.77; + + PLR_DMG_TOP = sh * 0.81, PLR_DMG_BTM = sh * 0.99; + DMG_TKN_TOP = sh * 0.81, DMG_TKN_BTM = sh * 0.99; + + //Create the text for rendering the terrritory, strength, and prod graphs. + stage.removeChildren(); + terText = new PIXI.Text('Territory', { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + terText.anchor = new PIXI.Point(0, 1); + terText.position = new PIXI.Point(mw + sh / 32, TER_TOP - sh * 0.005); + stage.addChild(terText); + prodText = new PIXI.Text('Production', { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + prodText.anchor = new PIXI.Point(0, 1); + prodText.position = new PIXI.Point(mw + sh / 32, PROD_TOP - sh * 0.005); + stage.addChild(prodText); + strText = new PIXI.Text('Strength', { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + strText.anchor = new PIXI.Point(0, 1); + strText.position = new PIXI.Point(mw + sh / 32, STR_TOP - sh * 0.005); + stage.addChild(strText); + if(showExtended) { + envDmgText = new PIXI.Text('Environment Damage', { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + envDmgText.anchor = new PIXI.Point(0, 1); + envDmgText.position = new PIXI.Point(mw + sh / 2.75, ENV_DMG_TOP - sh * 0.005); + stage.addChild(envDmgText); + actProdText = new PIXI.Text('Realized Production', { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + actProdText.anchor = new PIXI.Point(0, 1); + actProdText.position = new PIXI.Point(mw + sh / 2.75, ACT_PROD_TOP - sh * 0.005); + stage.addChild(actProdText); + capLossText = new PIXI.Text('Strength Loss to Cap', { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + capLossText.anchor = new PIXI.Point(0, 1); + capLossText.position = new PIXI.Point(mw + sh / 2.75, CAP_LOSS_TOP - sh * 0.005); + stage.addChild(capLossText); + plrDmgDltText = new PIXI.Text('Overkill Damage', { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + plrDmgDltText.anchor = new PIXI.Point(0, 1); + plrDmgDltText.position = new PIXI.Point(mw + sh / 32, PLR_DMG_TOP - sh * 0.005); + stage.addChild(plrDmgDltText); + dmgTknText = new PIXI.Text('Damage Taken', { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + dmgTknText.anchor = new PIXI.Point(0, 1); + dmgTknText.position = new PIXI.Point(mw + sh / 2.75, DMG_TKN_TOP - sh * 0.005); + stage.addChild(dmgTknText); + } + infoText = new PIXI.Text('Frame #' + frame.toString(), { font: (sh / 38).toString() + 'px Arial', fill: 0xffffff }); + infoText.anchor = new PIXI.Point(0, 1); + infoText.position = new PIXI.Point(mw + sh / 32, TER_TOP - sh * 0.05); + stage.addChild(infoText); + stage.addChild(graphGraphics); + } + + textStr = new Array(game.height); + textProd = new Array(game.height); + textPossess = new Array(game.height); + textReward = new Array(game.height); + textPolicy = new Array(game.height); + for 
(var i = 0; i < game.height; i++) { + textProd[i] = new Array(game.width); + textStr[i] = new Array(game.width); + textPossess[i] = new Array(game.width); + textReward[i] = new Array(game.width); + textPolicy[i] = new Array(game.width); + for(var j = 0; j < game.width; j++){ + textPolicy[i][j] = new Array(5); + } + } + loc=0 + + prodContainer.removeChildren() + strengthContainer.removeChildren() + possessContainer.removeChildren() + rewardContainer.removeChildren() + policyContainer.removeChildren() + var sY = Math.round(yOffset); + for(var a = 0; a < game.height; a++) { + var sX = Math.round(xOffset); + for(var b = 0; b < game.width; b++) { + var sty = new PIXI.TextStyle({ + fontFamily: 'Arial', + fontSize: 40 + }); + site = game.frames[frame][Math.floor(loc / game.width)][loc % game.width]; + textStr[a][b] = new PIXI.Text(site.strength.toString(),sty); + textStr[a][b].anchor = new PIXI.Point(0.5, +1.5); + textStr[a][b].position = new PIXI.Point(rw * (sX+0.5) , rh * (sY+0.5)); + textStr[a][b].style.fill = "#ffffff"//"#f54601"; + + textProd[a][b] = new PIXI.Text((10*game.productionNormals[Math.floor(loc / game.width)][loc % game.width]).toString(),sty) + textProd[a][b].anchor = new PIXI.Point(0.5, -0.5); + textProd[a][b].position = new PIXI.Point(rw * (sX+0.5) , rh * (sY+0.5)); + textProd[a][b].style.fill = "#ffffff"; + + textPossess[a][b] = new PIXI.Text(site.owner.toString(),sty) + textPossess[a][b].anchor = new PIXI.Point(0.5, 0.5); + textPossess[a][b].position = new PIXI.Point(rw * (sX+0.5) , rh * (sY+0.5)); + textPossess[a][b].style.fill = "#ffffff"; + + var style_1 = new PIXI.TextStyle({ + fontFamily: 'Roboto', + fontSize: 20 + }); + + textReward[a][b] = new PIXI.Text(site.owner.toString(),style_1) + textReward[a][b].anchor = new PIXI.Point(0.5, 0.5); + textReward[a][b].position = new PIXI.Point(rw * (sX+0.5) , rh * (sY+0.5)); + textReward[a][b].style.fill = "#ffffff"; + + var style_2 = new PIXI.TextStyle({ + fontFamily: 'Roboto', + fontSize: 10 + }); + + for(var j = 0; j < 5; j++){ + textPolicy[a][b][j] = new PIXI.Text(site.owner.toString(),style_2) + textPolicy[a][b][j].position = new PIXI.Point(rw * (sX+0.5) , rh * (sY+0.5)); + textPolicy[a][b][j].style.fill = "#ABD4FF"; + } + //NORTH, EAST, SOUTH, WEST, STILL + textPolicy[a][b][0].anchor = new PIXI.Point(0.5, +2.0); + textPolicy[a][b][1].anchor = new PIXI.Point(-1.0, 0.5); + textPolicy[a][b][2].anchor = new PIXI.Point(0.5, -1.0); + textPolicy[a][b][3].anchor = new PIXI.Point(2.0, 0.5); + textPolicy[a][b][4].anchor = new PIXI.Point(0.5, 0.5); + + + prodContainer.addChild(textProd[a][b]) + strengthContainer.addChild(textStr[a][b]) + possessContainer.addChild(textPossess[a][b]) + rewardContainer.addChild(textReward[a][b]) + for(var j = 0; j < 5; j++) { + policyContainer.addChild(textPolicy[a][b][j]) + } + loc++; + sX++; + if(sX == game.width) sX = 0; + } + sY++; + if(sY == game.height) sY = 0; + } + + stage.addChild(mapGraphics); + //stage.addChild(prodContainer); + //stage.addChild(strengthContainer); + //stage.addChild(possessContainer); + stage.addChild(rewardContainer); + stage.addChild(policyContainer); + console.log(renderer.width, renderer.height); + } + window.onresize(); + + var manager = new PIXI.interaction.InteractionManager(renderer); + var mousePressed = false; + document.onmousedown = function(e) { + mousePressed = true; + }; + document.onmouseup = function(e) { + mousePressed = false; + }; + + renderer.animateFunction = animate; + requestAnimationFrame(animate); + + var pressed={}; + document.onkeydown=function(e){ + 
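+        // Key bindings handled below (some are read from the `pressed` map inside animate()):
+        // space = play/pause, e = toggle extended graphs, z / x = jump to first / last frame,
+        // , / . = step back / forward one frame, o = reset panning, + / - = zoom the graph window,
+        // 1-5 = playback speed presets (1, 3, 6, 10, 15 fps), p (held) = production heatmap,
+        // a / u = reward overlays, arrow keys = scrub frames and adjust playback speed.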
e = e || window.event; + pressed[e.keyCode] = true; + if(e.keyCode == 32) { //Space + shouldplay = !shouldplay; + } + else if(e.keyCode == 69) { //e + showExtended = !showExtended; + mapGraphics.clear(); + graphGraphics.clear(); + renderer.render(stage); + window.onresize(); + } + else if(e.keyCode == 90) { //z + frame = 0; + transit = 0; + } + else if(e.keyCode == 88) { //x + frame = game.num_frames - 1; + transit = 0; + } + else if(e.keyCode == 188) { //, + if(transit == 0) frame--; + else transit = 0; + if(frame < 0) frame = 0; + shouldplay = false; + } + else if(e.keyCode == 190) { //. + frame++; + transit = 0; + if(frame >= game.num_frames - 1) frame = game.num_frames - 1; + shouldplay = false; + } + // else if(e.keyCode == 65 || e.keyCode == 68 || e.keyCode == 87 || e.keyCode == 83) { //wasd + // xOffset = Math.round(xOffset); + // yOffset = Math.round(yOffset); + // } + else if(e.keyCode == 79) { //o + xOffset = 0; + yOffset = 0; + } + else if(e.keyCode == 187 || e.keyCode == 107) { //= or + + zoom *= 1.41421356237; + if(game.num_frames / zoom < 3) zoom = game.num_frames / 3; + } + else if(e.keyCode == 189 || e.keyCode == 109) { //- or - (dash or subtract) + zoom /= 1.41421356237; + if(zoom < 1) zoom = 1; + } + else if(e.keyCode == 49) { //1 + framespersec = 1; + } + else if(e.keyCode == 50) { //2 + framespersec = 3; + } + else if(e.keyCode == 51) { //3 + framespersec = 6; + } + else if(e.keyCode == 52) { //4 + framespersec = 10; + } + else if(e.keyCode == 53) { //5 + framespersec = 15; + } + } + + document.onkeyup=function(e){ + e = e || window.event; + delete pressed[e.keyCode]; + } + + var lastTime = Date.now(); + + function interpolate(c1, c2, v) { + var c = { r: v * c2.r + (1 - v) * c1.r, g: v * c2.g + (1 - v) * c1.g, b: v * c2.b + (1- v) * c1.b }; + function compToHex(c) { var hex = c.toString(16); return hex.length == 1 ? "0" + hex : hex; }; + return "0x" + compToHex(Math.round(c.r)) + compToHex(Math.round(c.g)) + compToHex(Math.round(c.b)); + } + + function animate() { + + if(renderer.animateFunction !== animate) { return; } + + if(!isminimal) { + //Clear graphGraphics so that we can redraw freely. + graphGraphics.clear(); + + //Draw the graphs. + var nf = Math.round(game.num_frames / zoom), graphMidFrame = frame; + var nf2 = Math.floor(nf / 2); + if(graphMidFrame + nf2 >= game.num_frames) graphMidFrame -= ((nf2 + graphMidFrame) - game.num_frames); + else if(Math.ceil(graphMidFrame - nf2) < 0) graphMidFrame = nf2; + var firstFrame = graphMidFrame - nf2, lastFrame = graphMidFrame + nf2; + if(firstFrame < 0) firstFrame = 0; + if(lastFrame >= game.num_frames) lastFrame = game.num_frames - 1; + nf = lastFrame - firstFrame; + var dw = (LEFT_GRAPH_RIGHT - LEFT_GRAPH_LEFT) / (nf); + //Normalize values with respect to the range of frames seen by the graph. 
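+            // Only a window of roughly num_frames / zoom frames centered on the current frame is
+            // plotted; the maxima below are recomputed over that window and padded by 1% so the
+            // curves do not touch the graph borders.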
+ var maxTer = 0, maxProd = 0, maxStr = 0, maxActProd = 0; + var maxPlrDmgDlt = 0, maxEnvDmgDlt = 0, maxDmgTkn = 0, maxCapLoss = 0; + for(var a = 1; a <= game.num_players; a++) { + for(var b = firstFrame; b <= lastFrame; b++) { + if(game.players[a].territories[b] > maxTer) maxTer = game.players[a].territories[b] * 1.01; + if(game.players[a].productions[b] > maxProd) maxProd = game.players[a].productions[b] * 1.01; + if(game.players[a].strengths[b] > maxStr) maxStr = game.players[a].strengths[b] * 1.01; + if(game.players[a].actualProduction[b] > maxActProd) maxActProd = game.players[a].actualProduction[b] * 1.01; + if(game.players[a].playerDamageDealt[b] > maxPlrDmgDlt) maxPlrDmgDlt = game.players[a].playerDamageDealt[b] * 1.01; + if(game.players[a].environmentDamageDealt[b] > maxEnvDmgDlt) maxEnvDmgDlt = game.players[a].environmentDamageDealt[b] * 1.01; + if(game.players[a].damageTaken[b] > maxDmgTkn) maxDmgTkn = game.players[a].damageTaken[b] * 1.01; + if(game.players[a].capLosses[b] > maxCapLoss) maxCapLoss = game.players[a].capLosses[b] * 1.01; + } + } + function drawGraph(left, top, bottom, data, maxData) { + graphGraphics.moveTo(left, (top - bottom) * data[firstFrame] / maxData + bottom); + for(var b = firstFrame + 1; b <= lastFrame; b++) { + graphGraphics.lineTo(left + dw * (b - firstFrame), (top - bottom) * data[b] / maxData + bottom); + } + } + for(var a = 1; a <= game.num_players; a++) { + graphGraphics.lineStyle(1, game.players[a].color); + //Draw ter graph. + drawGraph(LEFT_GRAPH_LEFT, TER_TOP, TER_BTM, game.players[a].territories, maxTer); + //Draw prod graph. + drawGraph(LEFT_GRAPH_LEFT, PROD_TOP, PROD_BTM, game.players[a].productions, maxProd); + //Draw str graph. + drawGraph(LEFT_GRAPH_LEFT, STR_TOP, STR_BTM, game.players[a].strengths, maxStr); + if(showExtended) { + //Draw env dmg graph. + drawGraph(RIGHT_GRAPH_LEFT, ENV_DMG_TOP, ENV_DMG_BTM, game.players[a].environmentDamageDealt, maxEnvDmgDlt); + //Draw act prod graph. + drawGraph(RIGHT_GRAPH_LEFT, ACT_PROD_TOP, ACT_PROD_BTM, game.players[a].actualProduction, maxActProd); + //Draw str loss graph. + drawGraph(RIGHT_GRAPH_LEFT, CAP_LOSS_TOP, CAP_LOSS_BTM, game.players[a].capLosses, maxCapLoss); + //Draw plr dmg dealt. + drawGraph(LEFT_GRAPH_LEFT, PLR_DMG_TOP, PLR_DMG_BTM, game.players[a].playerDamageDealt, maxPlrDmgDlt); + //Draw damage taken. + drawGraph(RIGHT_GRAPH_LEFT, DMG_TKN_TOP, DMG_TKN_BTM, game.players[a].damageTaken, maxDmgTkn); + } + } + //Draw borders. + graphGraphics.lineStyle(1, '0xffffff'); + function drawGraphBorder(left, right, top, bottom) { + graphGraphics.moveTo(left + dw * (frame - firstFrame), top); + graphGraphics.lineTo(left + dw * (frame - firstFrame), bottom); + if((frame - firstFrame) > 0) graphGraphics.lineTo(left, bottom); //Deals with odd disappearing line.; + graphGraphics.lineTo(left, top); + graphGraphics.lineTo(right, top); + graphGraphics.lineTo(right, bottom); + graphGraphics.lineTo(left + dw * (frame - firstFrame), bottom); + } + + //Draw ter border. + drawGraphBorder(LEFT_GRAPH_LEFT, LEFT_GRAPH_RIGHT, TER_TOP, TER_BTM); + //Draw prod border. + drawGraphBorder(LEFT_GRAPH_LEFT, LEFT_GRAPH_RIGHT, PROD_TOP, PROD_BTM); + //Draw str border. + drawGraphBorder(LEFT_GRAPH_LEFT, LEFT_GRAPH_RIGHT, STR_TOP, STR_BTM); + if(showExtended) { + //Draw env dmg border. + drawGraphBorder(RIGHT_GRAPH_LEFT, RIGHT_GRAPH_RIGHT, ENV_DMG_TOP, ENV_DMG_BTM); + //Draw act prod border. + drawGraphBorder(RIGHT_GRAPH_LEFT, RIGHT_GRAPH_RIGHT, ACT_PROD_TOP, ACT_PROD_BTM); + //Draw str loss border. 
+ drawGraphBorder(RIGHT_GRAPH_LEFT, RIGHT_GRAPH_RIGHT, CAP_LOSS_TOP, CAP_LOSS_BTM); + //Draw plr damage dealt. + drawGraphBorder(LEFT_GRAPH_LEFT, LEFT_GRAPH_RIGHT, PLR_DMG_TOP, PLR_DMG_BTM); + //Draw plr damage taken. + drawGraphBorder(RIGHT_GRAPH_LEFT, RIGHT_GRAPH_RIGHT, DMG_TKN_TOP, DMG_TKN_BTM); + } + //Draw frame/ter text seperator. + graphGraphics.moveTo(LEFT_GRAPH_LEFT, TER_TOP - sh * 0.045); + graphGraphics.lineTo(RIGHT_GRAPH_RIGHT, TER_TOP - sh * 0.045); + } + + //Clear mapGraphics so that we can redraw freely. + mapGraphics.clear(); + + if(pressed[80]) { //Render productions. Don't update frames or transits. [Using p now for testing] + var loc = 0; + var pY = Math.round(yOffset); + for(var a = 0; a < game.height; a++) { + var pX = Math.round(xOffset); + for(var b = 0; b < game.width; b++) { + // VISU + if(game.productionNormals[Math.floor(loc / game.width)][loc % game.width] < 0.33333) mapGraphics.beginFill(interpolate({ r: 40, g: 40, b: 40 }, { r: 128, g: 80, b: 144 }, game.productionNormals[Math.floor(loc / game.width)][loc % game.width] * 3)); + else if(game.productionNormals[Math.floor(loc / game.width)][loc % game.width] < 0.66667) mapGraphics.beginFill(interpolate({ r: 128, g: 80, b: 144 }, { r: 176, g: 48, b: 48 }, game.productionNormals[Math.floor(loc / game.width)][loc % game.width] * 3 - 1)); + else mapGraphics.beginFill(interpolate({ r: 176, g: 48, b: 48 }, { r: 255, g: 240, b: 16 }, game.productionNormals[Math.floor(loc / game.width)][loc % game.width] * 3 - 2)); + mapGraphics.drawRect(rw * pX, rh * pY, rw, rh); + mapGraphics.endFill(); + loc++; + pX++; + if(pX == game.width) pX = 0; + } + pY++; + if(pY == game.height) pY = 0; + } + } + else { //Render game and update frames and transits. + var loc = 0; + var tY = Math.round(yOffset); + for(var a = 0; a < game.height; a++) { + var tX = Math.round(xOffset); + for(var b = 0; b < game.width; b++) { + var site = game.frames[frame][Math.floor(loc / game.width)][loc % game.width]; + mapGraphics.beginFill(game.players[site.owner].color, game.productionNormals[Math.floor(loc / game.width)][loc % game.width] * 0.4 + 0.15); + mapGraphics.drawRect(rw * tX, rh * tY, rw, rh); // Production + mapGraphics.endFill(); + loc++; + tX++; + if(tX == game.width) tX = 0; + } + tY++; + if(tY == game.height) tY = 0; + } + + var t = showmovement ? (-Math.cos(transit * Math.PI) + 1) / 2 : 0; + loc = 0; + var sY = Math.round(yOffset); + for(var a = 0; a < game.height; a++) { + var sX = Math.round(xOffset); + for(var b = 0; b < game.width; b++) { + var site = game.frames[frame][Math.floor(loc / game.width)][loc % game.width]; + if(site.strength == 255) mapGraphics.lineStyle(1, '0xfffff0'); + if(site.strength != 0) mapGraphics.beginFill(game.players[site.owner].color); + + textStr[a][b].text = site.strength.toString() + textPossess[a][b].text = site.owner.toString() + textProd[a][b].style.fill = (site.owner.toString()=="1")?"#04e6f2":"#ffffff"; + + textReward[a][b].text =(pressed[65] && discountedRewards!= undefined && frame!=lastFrame && site.owner.toString()=="1")?discountedRewards[frame][Math.floor(loc / game.width)][loc % game.width].toPrecision(2):''; + + + //policies[a][b].text = policies[frame][a][b] In fact there are five... 
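+                    // Each square owns five policy labels (NORTH, EAST, SOUTH, WEST, STILL);
+                    // they are blanked every frame here and only filled in for the square under
+                    // the cursor further below.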
+ //var debug_direction = ["NORTH", "EAST", "SOUTH", "WEST", "STILL"] + for(var i = 0; i < 5; i++) { + //var value = (policies!= undefined)?policies[frame][a][b][i].toExponential(1):0 + textPolicy[a][b][i].text = '' //(value==0)?'':value.toString() + } + + //console.log(discounted_rewards_function[frame][Math.floor(loc / game.width)][loc % game.width]) + var pw = rw * Math.sqrt(site.strength > 0 ? site.strength / 255 : 0.1) / 2 + var ph = rh * Math.sqrt(site.strength > 0 ? site.strength / 255 : 0.1) / 2; + var direction = frame < game.moves.length ? game.moves[frame][Math.floor(loc / game.width)][loc % game.width] : 0; + var move = t > 0 ? direction : 0; + var sY2 = move == 1 ? sY - 1 : move == 3 ? sY + 1 : sY; + var sX2 = move == 2 ? sX + 1 : move == 4 ? sX - 1 : sX; + if(site.strength == 0 && direction != 0) mapGraphics.lineStyle(1, '0x888888') + var center = new PIXI.Point(rw * ((t * sX2 + (1 - t) * sX) + 0.5), rh * ((t * sY2 + (1 - t) * sY) + 0.5)); + var pts = new Array(); + const squarescale = 0.75; + pts.push(new PIXI.Point(center.x + squarescale * pw, center.y + squarescale * ph)); + pts.push(new PIXI.Point(center.x + squarescale * pw, center.y - squarescale * ph)); + pts.push(new PIXI.Point(center.x - squarescale * pw, center.y - squarescale * ph)); + pts.push(new PIXI.Point(center.x - squarescale * pw, center.y + squarescale * ph)); + mapGraphics.drawPolygon(pts); + if(site.strength != 0) mapGraphics.endFill(); + mapGraphics.lineStyle(0, '0xffffff'); + loc++; + sX++; + if(sX == game.width) sX = 0; + } + sY++; + if(sY == game.height) sY = 0; + } + + var time = Date.now(); + var dt = time - lastTime; + lastTime = time; + + // If we are embedding a game, + // we want people to be able to scroll with + // the arrow keys + if(!isminimal) { + //Update frames per sec if up or down arrows are pressed. + if(pressed[38]) { + framespersec += 0.05; + } else if(pressed[40]) { + framespersec -= 0.05; + } + } + + if(pressed[39]) { + transit = 0; + frame++; + } + else if(pressed[37]) { + if(transit != 0) transit = 0; + else frame--; + } + else if(shouldplay) { + transit += dt / 1000 * framespersec; + } + } + + if(!isminimal) { + //Update info text: + var mousepos = manager.mouse.global; + if(mousepos.x < 0 || mousepos.x > sw || mousepos.y < 0 || mousepos.y > sh) { //Mouse is not over renderer. + infoText.text = 'Frame #' + frame.toString(); + } + else if(!mousePressed) { + infoText.text = 'Frame #' + frame.toString(); + if(mousepos.x < mw && mousepos.y < mh) { //Over map + var x = (Math.floor(mousepos.x / rw) - xOffset) % game.width, y = (Math.floor(mousepos.y / rh) - yOffset) % game.height; + if(x < 0) x += game.width; + if(y < 0) y += game.height; + infoText.text += ' | Loc: ' + x.toString() + ',' + y.toString(); + } + } + else { //Mouse is clicked and over renderer. + if(mousepos.x < mw && mousepos.y < mh) { //Over map: + var x = (Math.floor(mousepos.x / rw) - xOffset) % game.width, y = (Math.floor(mousepos.y / rh) - yOffset) % game.height; + if(x < 0) x += game.width; + if(y < 0) y += game.height; + str = game.frames[frame][y][x].strength; + prod = game.productions[y][x]; + infoText.text = 'Str: ' + str.toString() + ' | Prod: ' + prod.toString(); + + //policies[a][b].text = policies[frame][a][b] In fact there are five... + var debug_direction = ["NORTH", "EAST", "SOUTH", "WEST", "STILL"] + for(var i = 0; i < 5; i++) { + var value = (policies != undefined) ? policies[frame][y][x][i].toExponential(1) : 0 + textPolicy[y][x][i].text = (value == 0) ? 
'' : value.toString() + } + if(pressed[85]){//u pressed + textReward[y][x].text =(discountedRewards!= undefined && frame!=lastFrame)?discountedRewards[frame][y][x].toPrecision(2):''; + } + + + if(frame < game.moves.length && game.frames[frame][y][x].owner != 0) { + move = game.moves[frame][y][x]; + if(move >= 0 && move < 5) { + move = "0NESW"[move]; + } + infoText.text += ' | Mv: ' + move.toString(); + } + } + else if(mousepos.x < RIGHT_GRAPH_RIGHT && mousepos.x > LEFT_GRAPH_LEFT) { + frame = firstFrame + Math.round((mousepos.x - LEFT_GRAPH_LEFT) / dw); + if(frame < 0) frame = 0; + if(frame >= game.num_frames) frame = game.num_frames - 1; + transit = 0; + if(mousepos.y > TER_TOP & mousepos.y < TER_BTM) { + } + } + } + } + + //Advance frame if transit moves far enough. Ensure all are within acceptable bounds. + while(transit >= 1) { + transit--; + frame++; + } + if(frame >= game.num_frames - 1) { + frame = game.num_frames - 1; + transit = 0; + } + while(transit < 0) { + transit++; + frame--; + } + if(frame < 0) { + frame = 0; + transit = 0; + } + + //Pan if desired. + const PAN_SPEED = 1; + // if(pressed[65]) xOffset += PAN_SPEED; + // if(pressed[68]) xOffset -= PAN_SPEED + // if(pressed[87]) yOffset += PAN_SPEED; + // if(pressed[83]) yOffset -= PAN_SPEED; + + //Reset pan to be in normal bounds: + if(Math.round(xOffset) >= game.width) xOffset -= game.width; + else if(Math.round(xOffset) < 0) xOffset += game.width; + if(Math.round(yOffset) >= game.height) yOffset -= game.height; + else if(Math.round(yOffset) < 0) yOffset += game.height; + + //Actually render. + renderer.render(stage); + + //Of course, we want to render in the future as well. + var idle = (Object.keys(pressed).length === 0) && !shouldplay; + setTimeout(function() { + requestAnimationFrame(animate); + }, 1000 / (idle ? 20.0 : 80.0)); + } +} \ No newline at end of file diff --git a/visualize/templates/performance.html b/visualize/templates/performance.html new file mode 100755 index 0000000..586ef5a --- /dev/null +++ b/visualize/templates/performance.html @@ -0,0 +1,25 @@ + + + + + Visualizer + + + + + +
+
+
+ + +
+
+
+ \ No newline at end of file diff --git a/visualize/templates/visualizer.html b/visualize/templates/visualizer.html new file mode 100755 index 0000000..402e54c --- /dev/null +++ b/visualize/templates/visualizer.html @@ -0,0 +1,48 @@ + + + + + Visualizer + + + + + +
+
+
+

Drag your file

+

Go to performance plot

+
+
+ +
+
+
+ +
+
+ +
+
+
+
    + {%- for item in tree.children recursive %} +
  • {{ item.name }}
  • + {%- endfor %} +
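+ <!-- Replay file tree produced by make_tree() in visualize.py; the listed replay files are served from the /hlt/ route. -->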
+
+
+ + + + + + + + + + + \ No newline at end of file diff --git a/visualize/visualize.py b/visualize/visualize.py new file mode 100755 index 0000000..7d42fb4 --- /dev/null +++ b/visualize/visualize.py @@ -0,0 +1,142 @@ +"""The visualize main file to launch the server""" +import json +import os +import sys +from io import BytesIO + +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +import numpy as np +import pandas as pd +from flask import Flask, render_template, request, make_response, send_from_directory + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +try: + from train.reward.reward import Reward + from public.state.state import State1 + from public.models.strategy.TrainedStrategy import TrainedStrategy +except: + raise + +app = Flask(__name__) + +hlt_root = os.path.join(app.root_path, 'hlt') + + +@app.route('/hlt/') +def send_hlt(path): + return send_from_directory('hlt', path) + + +@app.route("/") +def home(): + return render_template('visualizer.html', tree=make_tree(hlt_root)) + + +@app.route("/performance.html") +def performance(): + """ + Return the page for the performance + :return: + """ + return render_template('performance.html') + + +def make_tree(path): + """ + For finding the halite file, we provide their directory tree. + :return: + """ + tree = dict(name=os.path.basename(path), children=[]) + try: + lst = os.listdir(path) + except OSError: + pass + else: + for name in lst: + fn = os.path.join(path, name) + if os.path.isdir(fn): + tree['children'].append(make_tree(fn)) + else: + if name not in [".DS_Store", "README.md"]: + tree['children'].append(dict(path='hlt/' + name, name=name)) + print(np) + return tree + + +@app.route("/performance.png") +def performance_plot(): + """ + Plot the performance at this address + :return: + """ + fig = Figure() + sub1 = fig.add_subplot(111) + path_to_variables = os.path.abspath(os.path.dirname(__file__)) + '/../public/models/variables/' + list_variables = [name for name in os.listdir(path_to_variables) if name not in [".DS_Store", "README.md"]] + + path_to_npy = [path_to_variables + name + '/' + name + '.npy' for name in list_variables] + + rewards = [np.load(path) for path in path_to_npy] + + max_len = max([len(reward) for reward in rewards]) + for i, reward in enumerate(rewards): + rewards[i] = np.append(reward, np.repeat(np.nan, max_len - len(reward))) + + pd.DataFrame(np.array(rewards).T, columns=list_variables).rolling(100).mean().plot( + title="Weighted reward at each game. (Rolling average)", ax=sub1) + + plt.show() + canvas = FigureCanvas(fig) + png_output = BytesIO() + canvas.print_png(png_output) + response = make_response(png_output.getvalue()) + response.headers['Content-Type'] = 'image/png' + return response + + +def convert(r): + """ + Convert the r to the game_states/moves tuple. 
+ :param r: + :return: + """ + + def get_owner(square): + return square['owner'] + + def get_strength(square): + return square['strength'] + + get_owner = np.vectorize(get_owner) + get_strength = np.vectorize(get_strength) + owner_frames = get_owner(r.json["frames"])[:, np.newaxis, :, :] + strength_frames = get_strength(r.json["frames"])[:, np.newaxis, :, :] + production_frames = np.repeat(np.array(r.json["productions"])[np.newaxis, np.newaxis, :, :], + len(owner_frames), + axis=0) + + moves = np.array(r.json['moves']) + + game_states = np.concatenate(([owner_frames, strength_frames, production_frames]), axis=1) + + moves = ((-5 + 5 * game_states[:-1, 0, :]) + ((moves - 1) % 5)) + return game_states, moves + + +@app.route('/post_discounted_rewards', methods=['POST']) +def post_discounted_rewards(): + game_states, moves = convert(request) + r = Reward(State1(scope=2)) + discounted_rewards = r.discounted_rewards_function(game_states, moves) + return json.dumps({'discounted_rewards': discounted_rewards.tolist()}) + + +@app.route('/post_policies', methods=['POST']) +def post_policies(): + game_states, _ = convert(request) + bot = TrainedStrategy() + bot.init_session() + policies = bot.get_policies(game_states) + return json.dumps({'policies': policies.tolist()})
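+# Usage note (not part of the original module; assumes Flask's CLI is available):
+#   FLASK_APP=visualize/visualize.py flask run
+# The front end (visualizer.js) POSTs the parsed replay JSON to /post_discounted_rewards and
+# /post_policies, then overlays the returned per-square rewards and policies on the map.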