diff --git a/shodan/__main__.py b/shodan/__main__.py
index 4093b94..d4ec5ea 100644
--- a/shodan/__main__.py
+++ b/shodan/__main__.py
@@ -37,6 +37,7 @@
import requests
import time
import json
+from shodan.cli.validation import check_input_file_type, check_filename_filepath, check_not_null
# The file converters that are used to go from .json.gz to various other formats
from shodan.cli.converter import CsvConverter, KmlConverter, GeoJsonConverter, ExcelConverter, ImagesConverter
@@ -93,7 +94,7 @@ def main():
@main.command()
@click.option('--fields', help='List of properties to output.', default=None)
-@click.argument('input', metavar=' ', type=click.Path(exists=True))
+@click.argument('input', metavar=' ', type=click.Path(exists=True), callback=check_input_file_type)
@click.argument('format', metavar='', type=click.Choice(CONVERTERS.keys()))
def convert(fields, input, format):
"""Convert the given input data file into a different format. The following file formats are supported:
@@ -263,8 +264,8 @@ def count(query):
@main.command()
@click.option('--fields', help='Specify the list of properties to download instead of grabbing the full banner', default=None, type=str)
@click.option('--limit', help='The number of results you want to download. -1 to download all the data possible.', default=1000, type=int)
-@click.argument('filename', metavar='')
-@click.argument('query', metavar='', nargs=-1)
+@click.argument('filename', metavar='', callback=check_filename_filepath)
+@click.argument('query', metavar='', nargs=-1, callback=check_not_null)
def download(fields, limit, filename, query):
"""Download search results and save them in a compressed JSON file."""
key = get_api_key()
@@ -272,14 +273,6 @@ def download(fields, limit, filename, query):
# Create the query string out of the provided tuple
query = ' '.join(query).strip()
- # Make sure the user didn't supply an empty string
- if query == '':
- raise click.ClickException('Empty search query')
-
- filename = filename.strip()
- if filename == '':
- raise click.ClickException('Empty filename')
-
# Add the appropriate extension if it's not there atm
if not filename.endswith('.json.gz'):
filename += '.json.gz'
@@ -471,7 +464,7 @@ def myip(ipv6):
@click.option('--fields', help='List of properties to show in the search results.', default='ip_str,port,hostnames,data')
@click.option('--limit', help='The number of search results that should be returned. Maximum: 1000', default=100, type=int)
@click.option('--separator', help='The separator between the properties of the search results.', default='\t')
-@click.argument('query', metavar='', nargs=-1)
+@click.argument('query', metavar='', nargs=-1, callback=check_not_null)
def search(color, fields, limit, separator, query):
"""Search the Shodan database"""
key = get_api_key()
@@ -479,10 +472,6 @@ def search(color, fields, limit, separator, query):
# Create the query string out of the provided tuple
query = ' '.join(query).strip()
- # Make sure the user didn't supply an empty string
- if query == '':
- raise click.ClickException('Empty search query')
-
# For now we only allow up to 1000 results at a time
if limit > 1000:
raise click.ClickException('Too many results requested, maximum is 1,000')
@@ -542,7 +531,7 @@ def search(color, fields, limit, separator, query):
@click.option('--limit', help='The number of results to return.', default=10, type=int)
@click.option('--facets', help='List of facets to get statistics for.', default='country,org')
@click.option('--filename', '-O', help='Save the results in a CSV file of the provided name.', default=None)
-@click.argument('query', metavar='', nargs=-1)
+@click.argument('query', metavar='', nargs=-1, callback=check_not_null)
def stats(limit, facets, filename, query):
"""Provide summary information about a search query"""
# Setup Shodan
@@ -809,7 +798,7 @@ def _create_stream(name, args, timeout):
@click.option('--facets', help='List of facets to get summary information on, if empty then show query total results over time', default='', type=str)
@click.option('--filename', '-O', help='Save the full results in the given file (append if file exists).', default=None)
@click.option('--save', '-S', help='Save the full results in the a file named after the query (append if file exists).', default=False, is_flag=True)
-@click.argument('query', metavar='', nargs=-1)
+@click.argument('query', metavar='', nargs=-1, callback=check_not_null)
def trends(filename, save, facets, query):
"""Search Shodan historical database"""
key = get_api_key()
@@ -819,10 +808,6 @@ def trends(filename, save, facets, query):
query = ' '.join(query).strip()
facets = facets.strip()
- # Make sure the user didn't supply an empty query or facets
- if query == '':
- raise click.ClickException('Empty search query')
-
# Convert comma-separated facets string to list
parsed_facets = []
for facet in facets.split(','):
diff --git a/shodan/cli/validation.py b/shodan/cli/validation.py
new file mode 100644
index 0000000..183e37a
--- /dev/null
+++ b/shodan/cli/validation.py
@@ -0,0 +1,49 @@
+import click
+from os import path
+
+
+def check_not_null(ctx, param, value):
+ """
+ Click callback method used to verify command line parameter is not an empty string.
+ :param ctx: Python Click library Context object.
+ :param param: Python Click Context object params attribute.
+ :param value: Value passed in for a given command line parameter.
+ """
+ if not value:
+ raise click.BadParameter("Value cannot be empty / null")
+ return value
+
+
+def check_input_file_type(ctx, param, value):
+ """
+ Click callback method used for file type input validation.
+ :param ctx: Python Click library Context object.
+ :param param: Python Click Context object params attribute.
+ :param value: Value passed in for a given command line parameter.
+ """
+ idx = value.find(".")
+
+ if idx == -1 or value[idx:] != ".json.gz":
+ raise click.BadParameter("Input file type must be '.json.gz'")
+ return value
+
+
+def check_filename_filepath(ctx, param, value):
+ """
+ Click callback method used for file path input validation.
+ :param ctx: Python Click library Context object.
+ :param param: Python Click Context object params attribute.
+ :param value: Value passed in for a given command line parameter.
+ """
+ filename = value.strip()
+ folder_idx = filename.rfind('/')
+
+ if filename == '':
+ raise click.click.BadParameter('Empty filename')
+
+ if folder_idx != -1:
+ parent_folder = filename[0: folder_idx + 1]
+ if not path.exists(parent_folder):
+ raise click.BadParameter('File path does not exist.')
+
+ return value