TPFA_ResSim.plotting

Convenient plot functions for the reservoir model.

  1"""Convenient plot functions for reservoir model."""
  2
  3import matplotlib as mpl
  4import matplotlib.pyplot as plt
  5import numpy as np
  6from matplotlib.ticker import MultipleLocator
  7from mpl_tools import place, place_ax
  8from mpl_tools.misc import axprops
  9
 10coord_type = "absolute"
 11"""Define scaling of `Plot2D.plt_field` axes.
 12- "relative": `(0, 1)  x (0, 1)`
 13- "absolute": `(0, Lx) x (0, Ly)`
 14- "index"   : `(0, Ny) x (0, Ny)`
 15"""
 16
 17# Colormap for saturation
 18lin_cm = mpl.colors.LinearSegmentedColormap.from_list
 19cm_ow = lin_cm("", [(0, "#1d9e97"), (.3, "#b2e0dc"), (1, "#f48974")])
 20# cOil, cWater = "red", "blue"        # Plain
 21# cOil, cWater = "#d8345f", "#01a9b4" # Pastel/neon
 22# cOil, cWater = "#e58a8a", "#086972" # Pastel
 23# ccnvrt = lambda c: np.array(mpl.colors.colorConverter.to_rgb(c))
 24# cMiddle = .3*ccnvrt(cWater) + .7*ccnvrt(cOil)
 25# cm_ow = lin_cm("", [cWater, cMiddle, cOil])
 26
 27
 28styles = dict(
 29    default = dict(
 30        title  = "",
 31        transf = lambda x: x,
 32        cmap   = "viridis",
 33        levels = 10,
 34        cticks = None,
 35        # Note that providing vmin/vmax (and not a levels list) to mpl
 36        # yields prettier colobar ticks, but destorys the consistency
 37        # of the colorbars from one figure to another.
 38        locator = None,
 39    ),
 40    oil = dict(
 41        title  = "Oil saturation",
 42        transf = lambda x: 1 - x,
 43        cmap   = cm_ow,
 44        levels = np.linspace(0 - 1e-7, 1 + 1e-7, 20),
 45        cticks = np.linspace(0, 1, 6),
 46    ),
 47)
 48"""Default `Plot2D.plt_field` plot styling values."""
 49
 50
 51class Plot2D:
 52    """Plots specialized for 2D fields."""
 53
 54    def plt_field(self, ax, Z, style="default", wells=True,
 55                  argmax=False, colorbar=True, labels=True, grid=False,
 56                  finalize=True, **kwargs):
 57        """Contour-plot of the (flat) unravelled field `Z`.
 58
 59        `kwargs` falls back to `styles[style]`, which falls back to `styles['defaults']`.
 60        """
 61        # Populate kwargs with fallback style
 62        kwargs = {**styles["default"], **styles[style], **kwargs}
 63        # Pop from kwargs. Remainder goes to countourf
 64        ax.set(**axprops(kwargs))
 65        cticks = kwargs.pop("cticks")
 66
 67        # Why extent=(0, Lx, 0, Ly), rather than merely changing ticks?
 68        # set_aspect("equal") and mouse hovering (reporting x,y).
 69        if "rel" in coord_type:
 70            Lx, Ly = 1, 1
 71        elif "abs" in coord_type:
 72            Lx, Ly = self.Lx, self.Ly
 73        elif "ind" in coord_type:
 74            Lx, Ly = self.Nx, self.Ny
 75        else:
 76            raise ValueError(f"Unsupported coord_type: {coord_type}")
 77
 78        # Apply transform
 79        Z = np.asarray(Z)
 80        Z = kwargs.pop("transf")(Z)
 81
 82        # Need to transpose coz orientation is model.shape==(Nx, Ny),
 83        # while contour() displays the same orientation as array printing.
 84        Z = Z.reshape(self.shape).T
 85
 86        # Did we bother to specify set_over/set_under/set_bad ?
 87        has_out_of_range = getattr(kwargs["cmap"], "_rgba_over", None) is not None
 88
 89        # Unlike `ax.imshow(Z[::-1])`, `contourf` does not simply fill pixels/cells (but
 90        # it does provide nice interpolation!) so there will be whitespace on the margins.
 91        # No fix is needed, and anyway it would not be trivial/fast,
 92        # ref https://github.com/matplotlib/basemap/issues/406 .
 93        collections = ax.contourf(
 94            Z, **kwargs,
 95            # origin=None,  # ⇒ NB: falsely stretches the field!!!
 96            origin="lower",
 97            extent=(0, Lx, 0, Ly),
 98            extend="both" if has_out_of_range else "neither",
 99        )
100
101        # Contourf does not plot (at all) the bad regions. "Fake it" by facecolor
102        if has_out_of_range:
103            ax.set_facecolor(getattr(kwargs["cmap"], "_rgba_bad", "w"))
104
105        # Grid (reflecting the model grid)
106        # NB: If not showing grid, then don't locate ticks on grid, because they're
107        #     generally uglier that mpl's default/automatic tick location. But, it
108        #     should be safe to go with 'g' format instead of 'f'.
109        ax.xaxis.set_major_formatter('{x:.3g}')
110        ax.yaxis.set_major_formatter('{x:.3g}')
111        ax.tick_params(which='minor', length=0, color='r')
112        ax.tick_params(which='major', width=1.5, direction="in")
113        if grid:
114            n1 = 10
115            xStep = 1 + self.Nx//n1
116            yStep = 1 + self.Ny//n1
117            ax.xaxis.set_major_locator(MultipleLocator(self.hx*xStep))
118            ax.yaxis.set_major_locator(MultipleLocator(self.hy*yStep))
119            ax.xaxis.set_minor_locator(MultipleLocator(self.hx))
120            ax.yaxis.set_minor_locator(MultipleLocator(self.hy))
121            ax.grid(True, which="both")
122
123        # Axis lims
124        ax.set_xlim((0, Lx))
125        ax.set_ylim((0, Ly))
126        ax.set_aspect("equal")
127        # Axis labels
128        if labels:
129            if "abs" in coord_type:
130                ax.set_xlabel("x")
131                ax.set_ylabel("y")
132            else:
133                ax.set_xlabel(f"x ({coord_type})")
134                ax.set_ylabel(f"y ({coord_type})")
135
136        # Add well markers
137        if wells:
138            if wells == "color":
139                wells = {"color": [f"C{i}" for i in range(len(self.prd_xy))]}
140            elif wells in [True, 1]:
141                wells = {}
142            self.well_scatter(ax, self.prd_xy, False, **wells)
143            wells.pop("color", None)
144            self.well_scatter(ax, self.inj_xy, True, **wells)
145
146        # Add argmax marker
147        if argmax:
148            idx = Z.T.argmax()  # reverse above transpose
149            xy = self.ind2xy(idx)
150            for c, ms in zip(['b', 'r', 'y'],
151                             [10, 6, 3]):
152                ax.plot(*xy, "o", c=c, ms=ms, label="max", zorder=98)
153
154        # Add colorbar
155        if colorbar:
156            if isinstance(colorbar, type(ax)):
157                cax = dict(cax=colorbar)
158            else:
159                cax = dict(ax=ax, shrink=.8)
160            ax.figure.colorbar(collections, **cax, ticks=cticks)
161
162        tight_show(ax.figure, finalize)
163        return collections
164
165
166    def well_scatter(self, ax, ww, inj=True, text=None, color=None, size=1):
167        """Scatter-plot the wells of `ww` onto a `Plot2D.plt_field`."""
168        # Well coordinates
169        ww = self.sub2xy(*self.xy2sub(*ww.T)).T
170        # NB: make sure ww array data is not overwritten (avoid in-place)
171        if   "rel" in coord_type: s = 1/self.Lx, 1/self.Ly                     # noqa
172        elif "abs" in coord_type: s = 1, 1                                     # noqa
173        elif "ind" in coord_type: s = self.Nx/self.Lx, self.Ny/self.Ly         # noqa
174        else: raise ValueError("Unsupported coordinate type: %s" % coord_type) # noqa
175        ww = ww * s
176
177        # Style
178        if inj:
179            c  = "w"
180            ec = "gray"
181            d  = "k"
182            m  = "v"
183        else:
184            c  = "k"
185            ec = "gray"
186            d  = "w"
187            m  = "^"
188
189        if color:
190            c = color
191
192        # Markers
193        sh = ax.plot(*ww.T, 'r.', ms=3, clip_on=False)
194        sh = ax.scatter(*ww.T, s=(size * 26)**2, c=c, marker=m, ec=ec,
195                        clip_on=False,
196                        zorder=1.5,  # required on Jupypter
197                        )
198
199        # Text labels
200        if text is not False:
201            for i, w in enumerate(ww):
202                if not inj:
203                    w[1] -= 0.01
204                ax.text(*w[:2], i if text is None else text,
205                        color=d, fontsize=size*12, ha="center", va="center")
206
207        return sh
208
209    def plt_production(self, ax, production, obs=None,
210                       legend_outside=True, finalize=True):
211        """Production time series. Multiple wells in 1 axes => not ensemble compat."""
212        hh = []
213        tt = 1+np.arange(len(production))
214        for i, p in enumerate(1-production.T):
215            hh += ax.plot(tt, p, "-", label=i)
216
217        if obs is not None:
218            for i, y in enumerate(1-obs.T):
219                ax.plot(tt, y, "*", c=hh[i].get_color())
220
221        # Add legend
222        if legend_outside:
223            kws = dict(
224                  bbox_to_anchor=(1, 1),
225                  loc="upper left",
226                  ncol=1+len(production.T)//10,
227            )
228        else:
229            kws = dict(loc="lower left")
230        ax.legend(title="Well #.", **kws)
231
232        ax.set_title("Oil saturation in producers")
233        ax.set_xlabel("Time index")
234        # ax.set_ylim(-0.01, 1.01)
235        ax.axhline(0, c="xkcd:light grey", ls="--", zorder=1.8)
236        ax.axhline(1, c="xkcd:light grey", ls="--", zorder=1.8)
237
238        tight_show(ax.figure, finalize)
239        return hh
240
241    # Note: See note in mpl_setup.py about properly displaying the animation.
242    def anim(self, wsats, prod, title="", figsize=(10, 3.5), pause=200, animate=True,
243             **kwargs):
244        """Animate the saturation and production time series."""
245
246        # Create figure and axes
247        title = "Animation" + ("-- " + title if title else "")
248        fig, (ax1, ax2) = place.freshfig(title, ncols=2, figsize=figsize,
249                                         gridspec_kw=dict(width_ratios=(2, 3)))
250        fig.suptitle(title)  # coz animation never (any backend) displays title
251        # Saturations
252        kwargs.update(wells="color", colorbar=True, finalize=False)
253        ax2.cc = self.plt_field(ax2, wsats[-1], "oil", **kwargs)
254        # Production
255        hh = self.plt_production(ax1, prod, legend_outside=False, finalize=False)
256        fig.tight_layout()
257
258        if animate:
259            from matplotlib import animation
260            tt = np.arange(len(wsats))
261
262            def update_fig(iT):
263                # Update field
264                for c in ax2.cc.collections:
265                    try:
266                        ax2.collections.remove(c)
267                    except (AttributeError, ValueError):
268                        pass  # occurs when re-running script
269                kwargs.update(wells=False, colorbar=False)
270                ax2.cc = self.plt_field(ax2, wsats[iT], "oil", **kwargs)
271
272                # Update production lines
273                if iT >= 1:
274                    for h, p in zip(hh, prod.T):
275                        h.set_data(tt[1:1+iT], 1 - p[:iT])
276
277            ani = animation.FuncAnimation(
278                fig, update_fig, len(tt), blit=False, interval=pause,
279                # Prevent busy/idle indicator constantly flashing, despite %%capture
280                # and even manually clearing the output of the calling cell.
281                repeat=False,  # flashing stops once the (unshown) animation finishes.
282                # An alternative solution is to do this in the next cell:
283                # animation.event_source.stop()
284                # but it does not work if using "run all", even with time.sleep(1).
285            )
286
287            return ani
288
def tight_show(figure, enabled):
    """Tighten `figure`'s layout and display it via `plt.show()`, if `enabled`.

    Nothing happens when `enabled` is falsy.
    """
    if not enabled:
        return
    figure.tight_layout()
    plt.show()
coord_type = 'absolute'

Define scaling of Plot2D.plt_field axes.

  • "relative": (0, 1) x (0, 1)
  • "absolute": (0, Lx) x (0, Ly)
  • "index": (0, Nx) x (0, Ny)
@staticmethod
def lin_cm(name, colors, N=256, gamma=1.0):
    @staticmethod
    def from_list(name, colors, N=256, gamma=1.0):
        """
        Create a `LinearSegmentedColormap` from a list of colors.

        Parameters
        ----------
        name : str
            The name of the colormap.
        colors : array-like of colors or array-like of (value, color)
            If only colors are given, they are equidistantly mapped from the
            range :math:`[0, 1]`; i.e. 0 maps to ``colors[0]`` and 1 maps to
            ``colors[-1]``.
            If (value, color) pairs are given, the mapping is from *value*
            to *color*. This can be used to divide the range unevenly.
        N : int
            The number of RGB quantization levels.
        gamma : float
            Forwarded unchanged to the `LinearSegmentedColormap` constructor
            (its gamma parameter).

        Returns
        -------
        LinearSegmentedColormap

        Raises
        ------
        ValueError
            If *colors* is not iterable.
        """
        if not np.iterable(colors):
            raise ValueError('colors must be iterable')

        # A sized, non-string first entry of length 2 marks the
        # (value, color) form.
        if (isinstance(colors[0], Sized) and len(colors[0]) == 2
                and not isinstance(colors[0], str)):
            # List of value, color pairs
            vals, colors = zip(*colors)
        else:
            vals = np.linspace(0, 1, len(colors))

        r, g, b, a = to_rgba_array(colors).T
        # Per channel: rows of (anchor, value, value), i.e. the same value is
        # used on both sides of each anchor.
        cdict = {
            "red": np.column_stack([vals, r, r]),
            "green": np.column_stack([vals, g, g]),
            "blue": np.column_stack([vals, b, b]),
            "alpha": np.column_stack([vals, a, a]),
        }

        return LinearSegmentedColormap(name, cdict, N, gamma)

Create a LinearSegmentedColormap from a list of colors.

Parameters

  • name : str — The name of the colormap.
  • colors : array-like of colors or array-like of (value, color) — If only colors are given, they are equidistantly mapped from the range \([0, 1]\); i.e. 0 maps to colors[0] and 1 maps to colors[-1]. If (value, color) pairs are given, the mapping is from value to color. This can be used to divide the range unevenly.
  • N : int — The number of RGB quantization levels.
  • gamma : float — Passed unchanged to the LinearSegmentedColormap constructor.

cm_ow = <matplotlib.colors.LinearSegmentedColormap object>
styles = {'default': {'title': '', 'transf': <function <lambda>>, 'cmap': 'viridis', 'levels': 10, 'cticks': None, 'locator': None}, 'oil': {'title': 'Oil saturation', 'transf': <function <lambda>>, 'cmap': <matplotlib.colors.LinearSegmentedColormap object>, 'levels': array([-1.00000000e-07, 5.26314895e-02, 1.05263079e-01, 1.57894668e-01, 2.10526258e-01, 2.63157847e-01, 3.15789437e-01, 3.68421026e-01, 4.21052616e-01, 4.73684205e-01, 5.26315795e-01, 5.78947384e-01, 6.31578974e-01, 6.84210563e-01, 7.36842153e-01, 7.89473742e-01, 8.42105332e-01, 8.94736921e-01, 9.47368511e-01, 1.00000010e+00]), 'cticks': array([0. , 0.2, 0.4, 0.6, 0.8, 1. ])}}

Default Plot2D.plt_field plot styling values.

class Plot2D:
 52class Plot2D:
 53    """Plots specialized for 2D fields."""
 54
 55    def plt_field(self, ax, Z, style="default", wells=True,
 56                  argmax=False, colorbar=True, labels=True, grid=False,
 57                  finalize=True, **kwargs):
 58        """Contour-plot of the (flat) unravelled field `Z`.
 59
 60        `kwargs` falls back to `styles[style]`, which falls back to `styles['defaults']`.
 61        """
 62        # Populate kwargs with fallback style
 63        kwargs = {**styles["default"], **styles[style], **kwargs}
 64        # Pop from kwargs. Remainder goes to countourf
 65        ax.set(**axprops(kwargs))
 66        cticks = kwargs.pop("cticks")
 67
 68        # Why extent=(0, Lx, 0, Ly), rather than merely changing ticks?
 69        # set_aspect("equal") and mouse hovering (reporting x,y).
 70        if "rel" in coord_type:
 71            Lx, Ly = 1, 1
 72        elif "abs" in coord_type:
 73            Lx, Ly = self.Lx, self.Ly
 74        elif "ind" in coord_type:
 75            Lx, Ly = self.Nx, self.Ny
 76        else:
 77            raise ValueError(f"Unsupported coord_type: {coord_type}")
 78
 79        # Apply transform
 80        Z = np.asarray(Z)
 81        Z = kwargs.pop("transf")(Z)
 82
 83        # Need to transpose coz orientation is model.shape==(Nx, Ny),
 84        # while contour() displays the same orientation as array printing.
 85        Z = Z.reshape(self.shape).T
 86
 87        # Did we bother to specify set_over/set_under/set_bad ?
 88        has_out_of_range = getattr(kwargs["cmap"], "_rgba_over", None) is not None
 89
 90        # Unlike `ax.imshow(Z[::-1])`, `contourf` does not simply fill pixels/cells (but
 91        # it does provide nice interpolation!) so there will be whitespace on the margins.
 92        # No fix is needed, and anyway it would not be trivial/fast,
 93        # ref https://github.com/matplotlib/basemap/issues/406 .
 94        collections = ax.contourf(
 95            Z, **kwargs,
 96            # origin=None,  # ⇒ NB: falsely stretches the field!!!
 97            origin="lower",
 98            extent=(0, Lx, 0, Ly),
 99            extend="both" if has_out_of_range else "neither",
100        )
101
102        # Contourf does not plot (at all) the bad regions. "Fake it" by facecolor
103        if has_out_of_range:
104            ax.set_facecolor(getattr(kwargs["cmap"], "_rgba_bad", "w"))
105
106        # Grid (reflecting the model grid)
107        # NB: If not showing grid, then don't locate ticks on grid, because they're
108        #     generally uglier that mpl's default/automatic tick location. But, it
109        #     should be safe to go with 'g' format instead of 'f'.
110        ax.xaxis.set_major_formatter('{x:.3g}')
111        ax.yaxis.set_major_formatter('{x:.3g}')
112        ax.tick_params(which='minor', length=0, color='r')
113        ax.tick_params(which='major', width=1.5, direction="in")
114        if grid:
115            n1 = 10
116            xStep = 1 + self.Nx//n1
117            yStep = 1 + self.Ny//n1
118            ax.xaxis.set_major_locator(MultipleLocator(self.hx*xStep))
119            ax.yaxis.set_major_locator(MultipleLocator(self.hy*yStep))
120            ax.xaxis.set_minor_locator(MultipleLocator(self.hx))
121            ax.yaxis.set_minor_locator(MultipleLocator(self.hy))
122            ax.grid(True, which="both")
123
124        # Axis lims
125        ax.set_xlim((0, Lx))
126        ax.set_ylim((0, Ly))
127        ax.set_aspect("equal")
128        # Axis labels
129        if labels:
130            if "abs" in coord_type:
131                ax.set_xlabel("x")
132                ax.set_ylabel("y")
133            else:
134                ax.set_xlabel(f"x ({coord_type})")
135                ax.set_ylabel(f"y ({coord_type})")
136
137        # Add well markers
138        if wells:
139            if wells == "color":
140                wells = {"color": [f"C{i}" for i in range(len(self.prd_xy))]}
141            elif wells in [True, 1]:
142                wells = {}
143            self.well_scatter(ax, self.prd_xy, False, **wells)
144            wells.pop("color", None)
145            self.well_scatter(ax, self.inj_xy, True, **wells)
146
147        # Add argmax marker
148        if argmax:
149            idx = Z.T.argmax()  # reverse above transpose
150            xy = self.ind2xy(idx)
151            for c, ms in zip(['b', 'r', 'y'],
152                             [10, 6, 3]):
153                ax.plot(*xy, "o", c=c, ms=ms, label="max", zorder=98)
154
155        # Add colorbar
156        if colorbar:
157            if isinstance(colorbar, type(ax)):
158                cax = dict(cax=colorbar)
159            else:
160                cax = dict(ax=ax, shrink=.8)
161            ax.figure.colorbar(collections, **cax, ticks=cticks)
162
163        tight_show(ax.figure, finalize)
164        return collections
165
166
167    def well_scatter(self, ax, ww, inj=True, text=None, color=None, size=1):
168        """Scatter-plot the wells of `ww` onto a `Plot2D.plt_field`."""
169        # Well coordinates
170        ww = self.sub2xy(*self.xy2sub(*ww.T)).T
171        # NB: make sure ww array data is not overwritten (avoid in-place)
172        if   "rel" in coord_type: s = 1/self.Lx, 1/self.Ly                     # noqa
173        elif "abs" in coord_type: s = 1, 1                                     # noqa
174        elif "ind" in coord_type: s = self.Nx/self.Lx, self.Ny/self.Ly         # noqa
175        else: raise ValueError("Unsupported coordinate type: %s" % coord_type) # noqa
176        ww = ww * s
177
178        # Style
179        if inj:
180            c  = "w"
181            ec = "gray"
182            d  = "k"
183            m  = "v"
184        else:
185            c  = "k"
186            ec = "gray"
187            d  = "w"
188            m  = "^"
189
190        if color:
191            c = color
192
193        # Markers
194        sh = ax.plot(*ww.T, 'r.', ms=3, clip_on=False)
195        sh = ax.scatter(*ww.T, s=(size * 26)**2, c=c, marker=m, ec=ec,
196                        clip_on=False,
197                        zorder=1.5,  # required on Jupypter
198                        )
199
200        # Text labels
201        if text is not False:
202            for i, w in enumerate(ww):
203                if not inj:
204                    w[1] -= 0.01
205                ax.text(*w[:2], i if text is None else text,
206                        color=d, fontsize=size*12, ha="center", va="center")
207
208        return sh
209
210    def plt_production(self, ax, production, obs=None,
211                       legend_outside=True, finalize=True):
212        """Production time series. Multiple wells in 1 axes => not ensemble compat."""
213        hh = []
214        tt = 1+np.arange(len(production))
215        for i, p in enumerate(1-production.T):
216            hh += ax.plot(tt, p, "-", label=i)
217
218        if obs is not None:
219            for i, y in enumerate(1-obs.T):
220                ax.plot(tt, y, "*", c=hh[i].get_color())
221
222        # Add legend
223        if legend_outside:
224            kws = dict(
225                  bbox_to_anchor=(1, 1),
226                  loc="upper left",
227                  ncol=1+len(production.T)//10,
228            )
229        else:
230            kws = dict(loc="lower left")
231        ax.legend(title="Well #.", **kws)
232
233        ax.set_title("Oil saturation in producers")
234        ax.set_xlabel("Time index")
235        # ax.set_ylim(-0.01, 1.01)
236        ax.axhline(0, c="xkcd:light grey", ls="--", zorder=1.8)
237        ax.axhline(1, c="xkcd:light grey", ls="--", zorder=1.8)
238
239        tight_show(ax.figure, finalize)
240        return hh
241
242    # Note: See note in mpl_setup.py about properly displaying the animation.
243    def anim(self, wsats, prod, title="", figsize=(10, 3.5), pause=200, animate=True,
244             **kwargs):
245        """Animate the saturation and production time series."""
246
247        # Create figure and axes
248        title = "Animation" + ("-- " + title if title else "")
249        fig, (ax1, ax2) = place.freshfig(title, ncols=2, figsize=figsize,
250                                         gridspec_kw=dict(width_ratios=(2, 3)))
251        fig.suptitle(title)  # coz animation never (any backend) displays title
252        # Saturations
253        kwargs.update(wells="color", colorbar=True, finalize=False)
254        ax2.cc = self.plt_field(ax2, wsats[-1], "oil", **kwargs)
255        # Production
256        hh = self.plt_production(ax1, prod, legend_outside=False, finalize=False)
257        fig.tight_layout()
258
259        if animate:
260            from matplotlib import animation
261            tt = np.arange(len(wsats))
262
263            def update_fig(iT):
264                # Update field
265                for c in ax2.cc.collections:
266                    try:
267                        ax2.collections.remove(c)
268                    except (AttributeError, ValueError):
269                        pass  # occurs when re-running script
270                kwargs.update(wells=False, colorbar=False)
271                ax2.cc = self.plt_field(ax2, wsats[iT], "oil", **kwargs)
272
273                # Update production lines
274                if iT >= 1:
275                    for h, p in zip(hh, prod.T):
276                        h.set_data(tt[1:1+iT], 1 - p[:iT])
277
278            ani = animation.FuncAnimation(
279                fig, update_fig, len(tt), blit=False, interval=pause,
280                # Prevent busy/idle indicator constantly flashing, despite %%capture
281                # and even manually clearing the output of the calling cell.
282                repeat=False,  # flashing stops once the (unshown) animation finishes.
283                # An alternative solution is to do this in the next cell:
284                # animation.event_source.stop()
285                # but it does not work if using "run all", even with time.sleep(1).
286            )
287
288            return ani

Plots specialized for 2D fields.

def plt_field( self, ax, Z, style='default', wells=True, argmax=False, colorbar=True, labels=True, grid=False, finalize=True, **kwargs):
 55    def plt_field(self, ax, Z, style="default", wells=True,
 56                  argmax=False, colorbar=True, labels=True, grid=False,
 57                  finalize=True, **kwargs):
 58        """Contour-plot of the (flat) unravelled field `Z`.
 59
 60        `kwargs` falls back to `styles[style]`, which falls back to `styles['defaults']`.
 61        """
 62        # Populate kwargs with fallback style
 63        kwargs = {**styles["default"], **styles[style], **kwargs}
 64        # Pop from kwargs. Remainder goes to countourf
 65        ax.set(**axprops(kwargs))
 66        cticks = kwargs.pop("cticks")
 67
 68        # Why extent=(0, Lx, 0, Ly), rather than merely changing ticks?
 69        # set_aspect("equal") and mouse hovering (reporting x,y).
 70        if "rel" in coord_type:
 71            Lx, Ly = 1, 1
 72        elif "abs" in coord_type:
 73            Lx, Ly = self.Lx, self.Ly
 74        elif "ind" in coord_type:
 75            Lx, Ly = self.Nx, self.Ny
 76        else:
 77            raise ValueError(f"Unsupported coord_type: {coord_type}")
 78
 79        # Apply transform
 80        Z = np.asarray(Z)
 81        Z = kwargs.pop("transf")(Z)
 82
 83        # Need to transpose coz orientation is model.shape==(Nx, Ny),
 84        # while contour() displays the same orientation as array printing.
 85        Z = Z.reshape(self.shape).T
 86
 87        # Did we bother to specify set_over/set_under/set_bad ?
 88        has_out_of_range = getattr(kwargs["cmap"], "_rgba_over", None) is not None
 89
 90        # Unlike `ax.imshow(Z[::-1])`, `contourf` does not simply fill pixels/cells (but
 91        # it does provide nice interpolation!) so there will be whitespace on the margins.
 92        # No fix is needed, and anyway it would not be trivial/fast,
 93        # ref https://github.com/matplotlib/basemap/issues/406 .
 94        collections = ax.contourf(
 95            Z, **kwargs,
 96            # origin=None,  # ⇒ NB: falsely stretches the field!!!
 97            origin="lower",
 98            extent=(0, Lx, 0, Ly),
 99            extend="both" if has_out_of_range else "neither",
100        )
101
102        # Contourf does not plot (at all) the bad regions. "Fake it" by facecolor
103        if has_out_of_range:
104            ax.set_facecolor(getattr(kwargs["cmap"], "_rgba_bad", "w"))
105
106        # Grid (reflecting the model grid)
107        # NB: If not showing grid, then don't locate ticks on grid, because they're
108        #     generally uglier that mpl's default/automatic tick location. But, it
109        #     should be safe to go with 'g' format instead of 'f'.
110        ax.xaxis.set_major_formatter('{x:.3g}')
111        ax.yaxis.set_major_formatter('{x:.3g}')
112        ax.tick_params(which='minor', length=0, color='r')
113        ax.tick_params(which='major', width=1.5, direction="in")
114        if grid:
115            n1 = 10
116            xStep = 1 + self.Nx//n1
117            yStep = 1 + self.Ny//n1
118            ax.xaxis.set_major_locator(MultipleLocator(self.hx*xStep))
119            ax.yaxis.set_major_locator(MultipleLocator(self.hy*yStep))
120            ax.xaxis.set_minor_locator(MultipleLocator(self.hx))
121            ax.yaxis.set_minor_locator(MultipleLocator(self.hy))
122            ax.grid(True, which="both")
123
124        # Axis lims
125        ax.set_xlim((0, Lx))
126        ax.set_ylim((0, Ly))
127        ax.set_aspect("equal")
128        # Axis labels
129        if labels:
130            if "abs" in coord_type:
131                ax.set_xlabel("x")
132                ax.set_ylabel("y")
133            else:
134                ax.set_xlabel(f"x ({coord_type})")
135                ax.set_ylabel(f"y ({coord_type})")
136
137        # Add well markers
138        if wells:
139            if wells == "color":
140                wells = {"color": [f"C{i}" for i in range(len(self.prd_xy))]}
141            elif wells in [True, 1]:
142                wells = {}
143            self.well_scatter(ax, self.prd_xy, False, **wells)
144            wells.pop("color", None)
145            self.well_scatter(ax, self.inj_xy, True, **wells)
146
147        # Add argmax marker
148        if argmax:
149            idx = Z.T.argmax()  # reverse above transpose
150            xy = self.ind2xy(idx)
151            for c, ms in zip(['b', 'r', 'y'],
152                             [10, 6, 3]):
153                ax.plot(*xy, "o", c=c, ms=ms, label="max", zorder=98)
154
155        # Add colorbar
156        if colorbar:
157            if isinstance(colorbar, type(ax)):
158                cax = dict(cax=colorbar)
159            else:
160                cax = dict(ax=ax, shrink=.8)
161            ax.figure.colorbar(collections, **cax, ticks=cticks)
162
163        tight_show(ax.figure, finalize)
164        return collections

Contour-plot of the (flat) unravelled field Z.

kwargs falls back to styles[style], which falls back to styles['default'].

def well_scatter(self, ax, ww, inj=True, text=None, color=None, size=1):
167    def well_scatter(self, ax, ww, inj=True, text=None, color=None, size=1):
168        """Scatter-plot the wells of `ww` onto a `Plot2D.plt_field`."""
169        # Well coordinates
170        ww = self.sub2xy(*self.xy2sub(*ww.T)).T
171        # NB: make sure ww array data is not overwritten (avoid in-place)
172        if   "rel" in coord_type: s = 1/self.Lx, 1/self.Ly                     # noqa
173        elif "abs" in coord_type: s = 1, 1                                     # noqa
174        elif "ind" in coord_type: s = self.Nx/self.Lx, self.Ny/self.Ly         # noqa
175        else: raise ValueError("Unsupported coordinate type: %s" % coord_type) # noqa
176        ww = ww * s
177
178        # Style
179        if inj:
180            c  = "w"
181            ec = "gray"
182            d  = "k"
183            m  = "v"
184        else:
185            c  = "k"
186            ec = "gray"
187            d  = "w"
188            m  = "^"
189
190        if color:
191            c = color
192
193        # Markers
194        sh = ax.plot(*ww.T, 'r.', ms=3, clip_on=False)
195        sh = ax.scatter(*ww.T, s=(size * 26)**2, c=c, marker=m, ec=ec,
196                        clip_on=False,
197                        zorder=1.5,  # required on Jupypter
198                        )
199
200        # Text labels
201        if text is not False:
202            for i, w in enumerate(ww):
203                if not inj:
204                    w[1] -= 0.01
205                ax.text(*w[:2], i if text is None else text,
206                        color=d, fontsize=size*12, ha="center", va="center")
207
208        return sh

Scatter-plot the wells of ww onto a Plot2D.plt_field.

def plt_production(self, ax, production, obs=None, legend_outside=True, finalize=True):
210    def plt_production(self, ax, production, obs=None,
211                       legend_outside=True, finalize=True):
212        """Production time series. Multiple wells in 1 axes => not ensemble compat."""
213        hh = []
214        tt = 1+np.arange(len(production))
215        for i, p in enumerate(1-production.T):
216            hh += ax.plot(tt, p, "-", label=i)
217
218        if obs is not None:
219            for i, y in enumerate(1-obs.T):
220                ax.plot(tt, y, "*", c=hh[i].get_color())
221
222        # Add legend
223        if legend_outside:
224            kws = dict(
225                  bbox_to_anchor=(1, 1),
226                  loc="upper left",
227                  ncol=1+len(production.T)//10,
228            )
229        else:
230            kws = dict(loc="lower left")
231        ax.legend(title="Well #.", **kws)
232
233        ax.set_title("Oil saturation in producers")
234        ax.set_xlabel("Time index")
235        # ax.set_ylim(-0.01, 1.01)
236        ax.axhline(0, c="xkcd:light grey", ls="--", zorder=1.8)
237        ax.axhline(1, c="xkcd:light grey", ls="--", zorder=1.8)
238
239        tight_show(ax.figure, finalize)
240        return hh

Production time series. Plots multiple wells in a single axes, so it is not compatible with ensembles.

def anim( self, wsats, prod, title='', figsize=(10, 3.5), pause=200, animate=True, **kwargs):
243    def anim(self, wsats, prod, title="", figsize=(10, 3.5), pause=200, animate=True,
244             **kwargs):
245        """Animate the saturation and production time series."""
246
247        # Create figure and axes
248        title = "Animation" + ("-- " + title if title else "")
249        fig, (ax1, ax2) = place.freshfig(title, ncols=2, figsize=figsize,
250                                         gridspec_kw=dict(width_ratios=(2, 3)))
251        fig.suptitle(title)  # coz animation never (any backend) displays title
252        # Saturations
253        kwargs.update(wells="color", colorbar=True, finalize=False)
254        ax2.cc = self.plt_field(ax2, wsats[-1], "oil", **kwargs)
255        # Production
256        hh = self.plt_production(ax1, prod, legend_outside=False, finalize=False)
257        fig.tight_layout()
258
259        if animate:
260            from matplotlib import animation
261            tt = np.arange(len(wsats))
262
263            def update_fig(iT):
264                # Update field
265                for c in ax2.cc.collections:
266                    try:
267                        ax2.collections.remove(c)
268                    except (AttributeError, ValueError):
269                        pass  # occurs when re-running script
270                kwargs.update(wells=False, colorbar=False)
271                ax2.cc = self.plt_field(ax2, wsats[iT], "oil", **kwargs)
272
273                # Update production lines
274                if iT >= 1:
275                    for h, p in zip(hh, prod.T):
276                        h.set_data(tt[1:1+iT], 1 - p[:iT])
277
278            ani = animation.FuncAnimation(
279                fig, update_fig, len(tt), blit=False, interval=pause,
280                # Prevent busy/idle indicator constantly flashing, despite %%capture
281                # and even manually clearing the output of the calling cell.
282                repeat=False,  # flashing stops once the (unshown) animation finishes.
283                # An alternative solution is to do this in the next cell:
284                # animation.event_source.stop()
285                # but it does not work if using "run all", even with time.sleep(1).
286            )
287
288            return ani

Animate the saturation and production time series.

def tight_show(figure, enabled):
def tight_show(figure, enabled):
    """Tighten `figure`'s layout and display it via `plt.show()`, if `enabled`.

    Nothing happens when `enabled` is falsy.
    """
    if not enabled:
        return
    figure.tight_layout()
    plt.show()