source: grass-addons/grass7/raster/r.sample.category/r.sample.category.py

Last change on this file was 73540, checked in by neteler, 6 years ago

r.sample.category addon: minor keyword standardization

  • Property svn:eol-style set to native
  • Property svn:mime-type set to text/x-python
File size: 7.1 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4##############################################################################
5#
6# MODULE: r.sample.category
7# AUTHOR(S): Vaclav Petras <wenzeslaus gmail com>
8# Anna Petrasova <kratochanna gmail com>
9# PURPOSE: Sample points from each category
10# COPYRIGHT: (C) 2015 by Vaclav Petras, and the GRASS Development Team
11#
12# This program is free software under the GNU General Public
13# License (>=v2). Read the COPYING file that comes with GRASS
14# for details.
15#
16##############################################################################
17
18
19#%module
20#% description: Create sampling points from each category in a raster map
21#% keyword: raster
22#% keyword: sampling
23#% keyword: random
24#% keyword: points
25#% keyword: vector
26#% keyword: stratified random sampling
27#%end
28#%option G_OPT_R_INPUT
29#% description: Name of input raster map with categories (classes)
30#%end
31#%option G_OPT_V_OUTPUT
32#% description: Name of output vector map with points at random locations
33#%end
34#%option G_OPT_R_INPUTS
35#% description: Names of input raster maps to be sampled
36#% key: sampled
37#% required: no
38#%end
39#%option
40#% label: Number of sampling points per category in the input map
41#% description: You can provide multiple numbers, one for each category in input raster (sorted ascending)
42#% key: npoints
43#% required: yes
44#% multiple: yes
45#% type: integer
46#%end
47#%flag
48#% key: s
49#% description: If number of cells in category < npoints, skip category
50#%end
51
52# TODO: Python tests for more advanced things such as overwrite or attributes
53# TODO: only optional sampling of the category raster
54# TODO: preserve original raster categories as vector point categories
55# TODO: specify number of points and distribute them uniformly
56# TODO: specify number of points and distribute them according to histogram
57# TODO: ensure/check minimum and maximum number of points when doing histogram
58# TODO: create function to check for mask
59# TODO: move escape and mask functions to library
60
61import os
62import atexit
63
64import grass.script as gscript
65
66
# Names of temporary maps created by this module; main() appends to this list
# and cleanup() passes it to g.remove when the process exits.
TMP = []
68
69
def cleanup():
    """Remove the MASK of the current mapset and all registered temporary maps.

    The MASK is removed only when a raster named MASK is found in the
    current mapset. Temporary maps are the names collected in the
    module-level ``TMP`` list; nothing is removed when the list is empty.
    """
    current_mapset = gscript.gisenv()['MAPSET']
    mask = gscript.find_file(name='MASK', element='cell', mapset=current_mapset)
    if mask['name']:
        gscript.run_command('r.mask', flags='r', quiet=True)

    if not TMP:
        return
    gscript.run_command('g.remove', flags='f', type=['raster', 'vector'],
                        name=TMP, quiet=True)
75
76
def escape_sql_column(name):
    """Return *name* with every dot replaced by an underscore.

    Dots are not usable in SQL column names, so map names such as
    ``elevation.10m`` are turned into ``elevation_10m``.

    >>> escape_sql_column("elevation.10m")
    'elevation_10m'
    """
    pieces = name.split('.')
    return '_'.join(pieces)
85
86
def strip_mapset(name):
    """Return the map name without the ``@mapset`` suffix.

    Everything from the first ``@`` onwards is dropped; a name without
    ``@`` is returned unchanged.

    >>> strip_mapset('elevation@PERMANENT')
    'elevation'
    """
    base_name, unused_sep, unused_mapset = name.partition('@')
    return base_name
96
97
def _mask_is_active():
    """Return True when a raster named MASK exists in the current mapset."""
    current_mapset = gscript.gisenv()['MAPSET']
    found = gscript.find_file(name='MASK', element='cell', mapset=current_mapset)
    return bool(found['name'])


def _parse_category_labels(stats_output):
    """Parse pipe-separated ``r.stats -l`` output into (category, label) pairs.

    Only the first ``|`` on a line separates the category from its label,
    so labels containing ``|`` are kept intact. Blank lines are ignored.
    Returns a list (not an iterator) so that it can be measured with
    ``len()`` and iterated more than once.
    """
    pairs = []
    for line in stats_output.splitlines():
        if not line.strip():
            continue
        value, unused_sep, label = line.partition('|')
        pairs.append((int(value), label))
    return pairs


def _sql_string_literal(value):
    """Return *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled, which is the standard SQL escape.
    """
    return "'" + value.replace("'", "''") + "'"


def main():
    """Sample random points from each category of the input raster.

    Steps:

    1. Parse the module options and refuse to run when a MASK is active.
    2. For every category of the input raster, mask the raster to that
       category and generate the requested number of random points.
    3. Patch the per-category points into the output vector map and give
       the points unique categories.
    4. Create an attribute table and fill it with the values of the input
       raster and of all rasters given in the ``sampled`` option.
    5. Add a ``label`` column when the input raster has category labels.
    """
    options, flags = gscript.parser()

    input_raster = options['input']
    points = options['output']
    if options['sampled']:
        sampled_rasters = options['sampled'].split(',')
    else:
        sampled_rasters = []
    npoints = [int(num) for num in options['npoints'].split(',')]
    flag_s = flags['s']

    if _mask_is_active():
        gscript.fatal(_("MASK is active. Please remove it before proceeding."))

    # we clean up mask too, so register after we know that mask is not present
    atexit.register(cleanup)

    temp_name = 'tmp_r_sample_category_{}_'.format(os.getpid())
    points_nocats = temp_name + 'points_nocats'
    TMP.append(points_nocats)

    # input must be CELL: report it clearly instead of failing later
    # while converting floating point ranges to integers
    input_datatype = gscript.parse_command('r.info', flags='g', map=input_raster)['datatype']
    if input_datatype != 'CELL':
        gscript.fatal(_("Input raster map <{name}> must be of type CELL, but it is {datatype}").format(
            name=input_raster, datatype=input_datatype))

    rdescribe = gscript.read_command('r.stats', flags='ln', input=input_raster, separator='pipe')
    cat_labels = _parse_category_labels(rdescribe)
    # a real list is needed here: it is measured with len() and iterated twice
    categories = [cat for cat, unused_label in cat_labels]
    if len(npoints) == 1:
        npoints = npoints * len(categories)
    elif len(categories) != len(npoints):
        gscript.fatal(_("Number of categories in raster does not match the number of provided sampling points numbers."))

    # Create sample points per category
    vectors = []
    for i, cat in enumerate(categories):
        # skip generating points if none are required
        if npoints[i] == 0:
            continue
        gscript.info(_("Selecting {n} sampling locations at category {cat}...").format(n=npoints[i], cat=cat))
        # restrict the raster to the current category
        # overwrite mask for an easy loop
        gscript.run_command('r.mask', raster=input_raster, maskcats=cat, overwrite=True, quiet=True)
        try:
            # Check number of cells in category
            nrc = int(gscript.parse_command('r.univar', map=input_raster, flags='g')['n'])
            if nrc < npoints[i]:
                if flag_s:
                    gscript.info(_("Not enough points in category {cat}. Skipping").format(cat=cat))
                    continue
                gscript.warning(_("Number of raster cells in category {cat} < {np}. Sampling {n} points").format(cat=cat, np=npoints[i], n=nrc))
                npoints[i] = nrc

            # Create the points
            # '-' of negative categories is not allowed in vector map names
            vector = temp_name + str(cat).replace('-', 'n')
            vectors.append(vector)
            gscript.run_command('r.random', input=input_raster, npoints=npoints[i], vector=vector, quiet=True)
            TMP.append(vector)
        finally:
            # always remove the mask, also when the category is skipped,
            # so that it cannot influence the sampling of rasters below
            gscript.run_command('r.mask', flags='r', quiet=True)

    if not vectors:
        gscript.fatal(_("No sampling points were generated, output vector map was not created"))

    gscript.run_command('v.patch', input=vectors, output=points, quiet=True)
    # remove and add again cats so that they are unique
    gscript.run_command('v.category', input=points, option='del', cat=-1, output=points_nocats, quiet=True)
    # overwrite to reuse the map
    gscript.run_command('v.category', input=points_nocats, option='add', output=points, overwrite=True, quiet=True)

    # Sample layers
    columns = []
    column_names = []
    # the category raster itself is always sampled and gets the first column
    sampled_rasters.insert(0, input_raster)
    for raster in sampled_rasters:
        column = escape_sql_column(strip_mapset(raster).lower())
        column_names.append(column)
        datatype = gscript.parse_command('r.info', flags='g', map=raster)['datatype']
        if datatype == 'CELL':
            datatype = 'integer'
        else:
            datatype = 'double precision'
        columns.append("{column} {datatype}".format(column=column, datatype=datatype))
    gscript.run_command('v.db.addtable', map=points, columns=','.join(columns), quiet=True)
    for raster, column in zip(sampled_rasters, column_names):
        gscript.info(_("Sampling raster map %s...") % raster)
        gscript.run_command('v.what.rast', map=points, type='point', raster=raster, column=column, quiet=True)

    # Add category labels
    if not any(label for unused_cat, label in cat_labels):
        gscript.verbose(_("There are no category labels in the raster to add"))
    else:
        gscript.run_command("v.db.addcolumn", map=points, columns="label varchar(250)")
        table_name = escape_sql_column(strip_mapset(points).lower())
        for cat, label in cat_labels:
            # labels are quoted as SQL literals and the standard '=' operator
            # is used so that the statement is valid for any label text
            sqlstat = "UPDATE {table} SET label={label} WHERE {column} = {cat}".format(
                table=table_name, label=_sql_string_literal(label),
                column=column_names[0], cat=cat)
            gscript.run_command("db.execute", sql=sqlstat)
191
192
# Run main() only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
Note: See TracBrowser for help on using the repository browser.