oracle/
batch.rs

1// Rust-oracle - Rust binding for Oracle database
2//
3// URL: https://github.com/kubo/rust-oracle
4//
5// ------------------------------------------------------
6//
7// Copyright 2021 Kubo Takehiro <kubo@jiubao.org>
8//
9// Redistribution and use in source and binary forms, with or without modification, are
10// permitted provided that the following conditions are met:
11//
12//    1. Redistributions of source code must retain the above copyright notice, this list of
13//       conditions and the following disclaimer.
14//
15//    2. Redistributions in binary form must reproduce the above copyright notice, this list
16//       of conditions and the following disclaimer in the documentation and/or other materials
17//       provided with the distribution.
18//
19// THIS SOFTWARE IS PROVIDED BY THE AUTHORS ''AS IS'' AND ANY EXPRESS OR IMPLIED
20// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
21// FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> OR
22// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
25// ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
27// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28//
29// The views and conclusions contained in the software and documentation are those of the
30// authors and should not be interpreted as representing official policies, either expressed
31// or implied, of the authors.
32
33use crate::chkerr;
34use crate::error::DPI_ERR_BUFFER_SIZE_TOO_SMALL;
35use crate::private;
36use crate::sql_type::OracleType;
37use crate::sql_type::ToSql;
38use crate::sql_value::BufferRowIndex;
39use crate::statement::QueryParams;
40use crate::to_rust_str;
41use crate::Connection;
42use crate::DbError;
43use crate::Error;
44use crate::OdpiStr;
45use crate::Result;
46use crate::SqlValue;
47#[cfg(doc)]
48use crate::Statement;
49use crate::StatementType;
50use odpic_sys::*;
51use std::convert::TryFrom;
52use std::fmt;
53use std::mem::MaybeUninit;
54use std::os::raw::c_char;
55use std::ptr;
56use std::slice;
57
#[cfg(test)]
const MINIMUM_TYPE_LENGTH: u32 = 1;
#[cfg(not(test))]
const MINIMUM_TYPE_LENGTH: u32 = 64;

/// Rounds `size` up to the nearest power of two, never returning less than
/// `MINIMUM_TYPE_LENGTH`.
///
/// Sizes above `2^31` have no `u32` power of two at or above them. Such sizes
/// are returned unchanged, so the resulting buffer is still large enough. The
/// previous shift-based formula overflowed the shift for them: it panicked in
/// debug builds and returned `1` in release builds.
fn po2(size: u32) -> u32 {
    let size = size.max(MINIMUM_TYPE_LENGTH);
    size.checked_next_power_of_two().unwrap_or(size)
}
70
71fn oratype_size(oratype: &OracleType) -> Option<u32> {
72    match oratype {
73        &OracleType::Varchar2(size)
74        | &OracleType::NVarchar2(size)
75        | &OracleType::Char(size)
76        | &OracleType::NChar(size)
77        | &OracleType::Raw(size) => Some(size),
78        _ => None,
79    }
80}
81
82#[derive(Clone)]
83struct BindType {
84    oratype: Option<OracleType>,
85}
86
87impl BindType {
88    fn new(oratype: &OracleType) -> BindType {
89        BindType {
90            oratype: match oratype {
91                OracleType::Varchar2(size) => Some(OracleType::Varchar2(po2(*size))),
92                OracleType::NVarchar2(size) => Some(OracleType::NVarchar2(po2(*size))),
93                OracleType::Char(size) => Some(OracleType::Char(po2(*size))),
94                OracleType::NChar(size) => Some(OracleType::NChar(po2(*size))),
95                OracleType::Raw(size) => Some(OracleType::Raw(po2(*size))),
96                _ => None,
97            },
98        }
99    }
100
101    fn reset_size(&mut self, new_size: u32) {
102        self.oratype = match self.oratype {
103            Some(OracleType::Varchar2(_)) => Some(OracleType::Varchar2(po2(new_size))),
104            Some(OracleType::NVarchar2(_)) => Some(OracleType::NVarchar2(po2(new_size))),
105            Some(OracleType::Char(_)) => Some(OracleType::Char(po2(new_size))),
106            Some(OracleType::NChar(_)) => Some(OracleType::NChar(po2(new_size))),
107            Some(OracleType::Raw(_)) => Some(OracleType::Raw(po2(new_size))),
108            _ => None,
109        };
110    }
111
112    fn as_oratype(&self) -> Option<&OracleType> {
113        self.oratype.as_ref()
114    }
115}
116
117/// A builder to create a [`Batch`] with various configuration
118pub struct BatchBuilder<'conn, 'sql> {
119    conn: &'conn Connection,
120    sql: &'sql str,
121    batch_size: usize,
122    with_batch_errors: bool,
123    with_row_counts: bool,
124    query_params: QueryParams,
125}
126
127impl<'conn, 'sql> BatchBuilder<'conn, 'sql> {
128    pub(crate) fn new(
129        conn: &'conn Connection,
130        sql: &'sql str,
131        batch_size: usize,
132    ) -> BatchBuilder<'conn, 'sql> {
133        BatchBuilder {
134            conn,
135            sql,
136            batch_size,
137            with_batch_errors: false,
138            with_row_counts: false,
139            query_params: QueryParams::new(),
140        }
141    }
142
143    /// See ["Error Handling"](Batch#error-handling)
144    pub fn with_batch_errors(&mut self) -> &mut BatchBuilder<'conn, 'sql> {
145        self.with_batch_errors = true;
146        self
147    }
148
149    /// See ["Affected Rows"](Batch#affected-rows)
150    pub fn with_row_counts(&mut self) -> &mut BatchBuilder<'conn, 'sql> {
151        self.with_row_counts = true;
152        self
153    }
154
155    pub fn build(&self) -> Result<Batch<'conn>> {
156        let batch_size = u32::try_from(self.batch_size).map_err(|err| {
157            Error::out_of_range(format!("too large batch size {}", self.batch_size)).add_source(err)
158        })?;
159        let conn = self.conn;
160        let sql = OdpiStr::new(self.sql);
161        let mut handle: *mut dpiStmt = ptr::null_mut();
162        chkerr!(
163            conn.ctxt(),
164            dpiConn_prepareStmt(
165                conn.handle(),
166                0,
167                sql.ptr,
168                sql.len,
169                ptr::null(),
170                0,
171                &mut handle
172            )
173        );
174        let mut info = MaybeUninit::uninit();
175        chkerr!(
176            conn.ctxt(),
177            dpiStmt_getInfo(handle, info.as_mut_ptr()),
178            unsafe {
179                dpiStmt_release(handle);
180            }
181        );
182        let info = unsafe { info.assume_init() };
183        if info.isDML == 0 && info.isPLSQL == 0 {
184            unsafe {
185                dpiStmt_release(handle);
186            }
187            let msg = format!(
188                "could not use {} statement",
189                StatementType::from_enum(info.statementType)
190            );
191            return Err(Error::invalid_operation(msg));
192        };
193        let mut num = 0;
194        chkerr!(
195            conn.ctxt(),
196            dpiStmt_getBindCount(handle, &mut num),
197            unsafe {
198                dpiStmt_release(handle);
199            }
200        );
201        let bind_count = num as usize;
202        let mut bind_names = Vec::with_capacity(bind_count);
203        let mut bind_values = Vec::with_capacity(bind_count);
204        if bind_count > 0 {
205            let mut names: Vec<*const c_char> = vec![ptr::null_mut(); bind_count];
206            let mut lengths = vec![0; bind_count];
207            chkerr!(
208                conn.ctxt(),
209                dpiStmt_getBindNames(handle, &mut num, names.as_mut_ptr(), lengths.as_mut_ptr()),
210                unsafe {
211                    dpiStmt_release(handle);
212                }
213            );
214            bind_names = Vec::with_capacity(num as usize);
215            for i in 0..(num as usize) {
216                bind_names.push(to_rust_str(names[i], lengths[i]));
217                bind_values.push(SqlValue::for_bind(
218                    conn.conn.clone(),
219                    self.query_params.clone(),
220                    batch_size,
221                ));
222            }
223        };
224        Ok(Batch {
225            conn,
226            handle,
227            statement_type: StatementType::from_enum(info.statementType),
228            bind_count,
229            bind_names,
230            bind_values,
231            bind_types: vec![None; bind_count],
232            batch_index: 0,
233            batch_size,
234            with_batch_errors: self.with_batch_errors,
235            with_row_counts: self.with_row_counts,
236            query_params: self.query_params.clone(),
237        })
238    }
239}
240
241/// Statement batch, which inserts, updates or deletes more than one row at once
242///
243/// Batching is efficient when the network distance between the client and
244/// the server is long. When a network round trip requires 1ms, inserting
245/// 10k rows using [`Statement`] consumes at least 10s excluding time spent
246/// in the client and the server. If 1000 rows are sent in a batch, it
247/// decreases to 10ms.
248///
249/// # Usage
250///
251/// 1. [`conn.batch(sql_stmt, batch_size).build()`](Connection::batch) to create [`Batch`].
252/// 2. [`append_row()`](#method.append_row) for each row. Rows in the batch are sent to
253///    the server when the number of appended rows reaches the batch size.  
254///    **Note:** The "batch errors" option mentioned later changes this behavior.
255/// 3. [`execute()`](#method.execute) in the end to send rows which
256///    have not been sent by `append_row()`.
257///
258/// ```
259/// # use oracle::Error;
260/// # use oracle::test_util;
261/// # let conn = test_util::connect()?;
262/// # conn.execute("delete from TestTempTable", &[])?;
263/// let sql_stmt = "insert into TestTempTable values(:1, :2)";
264/// let batch_size = 100;
265/// let mut batch = conn.batch(sql_stmt, batch_size).build()?;
266/// for i in 0..1234 { // iterate 1234 times.
267///     // send rows internally every 100 iterations.
268///     batch.append_row(&[&i, &format!("value {}", i)])?;
269/// }
/// batch.execute()?; // send the remaining 34 rows.
271/// // Check the number of inserted rows.
272/// assert_eq!(conn.query_row_as::<i32>("select count(*) from TestTempTable", &[])?, 1234);
273/// # Ok::<(), Error>(())
274/// ```
275///
276/// # Error Handling
277///
/// There are two modes for handling invalid data in a batch.
///
/// 1. Stop execution at the first failure and return the error information.
/// 2. Execute all rows in the batch and return an array of the error information.
282///
283/// ## Default Error Handling
284///
/// `append_row()` and `execute()` stop execution at the first failure and return
/// the error information. There is no way to know which row failed.
287///
288/// ```
289/// # use oracle::Error;
290/// # use oracle::test_util;
291/// # let conn = test_util::connect()?;
292/// # conn.execute("delete from TestTempTable", &[])?;
293/// let sql_stmt = "insert into TestTempTable values(:1, :2)";
294/// let batch_size = 10;
295/// let mut batch = conn.batch(sql_stmt, batch_size).build()?;
296/// batch.append_row(&[&1, &"first row"])?;
297/// batch.append_row(&[&2, &"second row"])?;
298/// batch.append_row(&[&1, &"first row again"])?; // -> ORA-00001: unique constraint violated.
299/// batch.append_row(&[&3, &"third row ".repeat(11)])?; // -> ORA-12899: value too large for column
300/// batch.append_row(&[&4, &"fourth row"])?;
301/// let result = batch.execute();
302/// match result {
303///     Err(Error::OciError(dberr)) => {
304///         assert_eq!(dberr.code(), 1);
305///         assert!(dberr.message().starts_with("ORA-00001: "));
306///     }
307///     _ => panic!("Unexpected batch result: {:?}", result),
308/// }
309///
310/// // Check the inserted rows.
311/// let mut stmt = conn
312///     .statement("select count(*) from TestTempTable where intCol = :1")
313///     .build()?;
314/// assert_eq!(stmt.query_row_as::<i32>(&[&1])?, 1);
315/// assert_eq!(stmt.query_row_as::<i32>(&[&2])?, 1);
316/// assert_eq!(stmt.query_row_as::<i32>(&[&3])?, 0);
317/// assert_eq!(stmt.query_row_as::<i32>(&[&4])?, 0);
318/// # Ok::<(), Error>(())
319/// ```
320///
321/// ## Error Handling with batch errors
322///
/// **Note:** This feature is available only when both the client and the server are Oracle 12.1 or later.
324///
325/// [`BatchBuilder::with_batch_errors`] changes
326/// the behavior of `Batch` as follows:
/// * `execute()` executes all rows in the batch and returns an array of the error information
328///   with row positions in the batch when the errors are caused by invalid data.
329/// * `append_row()` doesn't send rows internally when the number of appended rows reaches
330///   the batch size. It returns an error when the number exceeds the size instead.
331///
332/// ```
333/// # use oracle::Error;
334/// # use oracle::test_util::{self, check_version, VER12_1};
335/// # let conn = test_util::connect()?;
336/// # if !check_version(&conn, &VER12_1, &VER12_1)? {
337/// #     return Ok(()); // skip this test
338/// # }
339/// # conn.execute("delete from TestTempTable", &[])?;
340/// let sql_stmt = "insert into TestTempTable values(:1, :2)";
341/// let batch_size = 10;
342/// let mut batch = conn.batch(sql_stmt, batch_size).with_batch_errors().build()?;
343/// batch.append_row(&[&1, &"first row"])?;
344/// batch.append_row(&[&2, &"second row"])?;
345/// batch.append_row(&[&1, &"first row again"])?; // -> ORA-00001: unique constraint violated.
346/// batch.append_row(&[&3, &"third row ".repeat(11)])?; // -> ORA-12899: value too large for column
347/// batch.append_row(&[&4, &"fourth row"])?;
348/// let result = batch.execute();
349/// match result {
350///     Err(Error::BatchErrors(mut errs)) => {
351///         // sort by position because errs may not preserve order.
352///         errs.sort_by(|a, b| a.offset().cmp(&b.offset()));
353///         assert_eq!(errs.len(), 2);
354///         assert_eq!(errs[0].code(), 1);
355///         assert_eq!(errs[1].code(), 12899);
356///         assert_eq!(errs[0].offset(), 2); // position of `[&1, &"first row again"]`
357///         assert_eq!(errs[1].offset(), 3); // position of `[&3, &"third row ".repeat(11)]`
358///         assert!(errs[0].message().starts_with("ORA-00001: "));
359///         assert!(errs[1].message().starts_with("ORA-12899: "));
360///     }
361///     _ => panic!("Unexpected batch result: {:?}", result),
362/// }
363///
364/// // Check the inserted rows.
365/// let mut stmt = conn
366///     .statement("select count(*) from TestTempTable where intCol = :1")
367///     .build()?;
368/// assert_eq!(stmt.query_row_as::<i32>(&[&1])?, 1);
369/// assert_eq!(stmt.query_row_as::<i32>(&[&2])?, 1);
370/// assert_eq!(stmt.query_row_as::<i32>(&[&3])?, 0); // value too large for column
371/// assert_eq!(stmt.query_row_as::<i32>(&[&4])?, 1);
372/// # Ok::<(), Error>(())
373/// ```
374///
375/// # Affected Rows
376///
/// **Note:** This feature is available only when both the client and the server are Oracle 12.1 or later.
378///
379/// Use [`BatchBuilder::with_row_counts`] and [`Batch::row_counts`] to get affected rows
380/// for each input row.
381///
382/// ```
383/// # use oracle::Error;
384/// # use oracle::sql_type::OracleType;
385/// # use oracle::test_util::{self, check_version, VER12_1};
386/// # let conn = test_util::connect()?;
387/// # if !check_version(&conn, &VER12_1, &VER12_1)? {
388/// #     return Ok(()); // skip this test
389/// # }
390/// # conn.execute("delete from TestTempTable", &[])?;
391/// # let sql_stmt = "insert into TestTempTable values(:1, :2)";
392/// # let batch_size = 10;
393/// # let mut batch = conn.batch(sql_stmt, batch_size).build()?;
394/// # batch.set_type(1, &OracleType::Int64)?;
395/// # batch.set_type(2, &OracleType::Varchar2(1))?;
396/// # for i in 0..10 {
397/// #    batch.append_row(&[&i])?;
398/// # }
399/// # batch.execute()?;
400/// let sql_stmt = "update TestTempTable set stringCol = :stringCol where intCol >= :intCol";
401/// let mut batch = conn.batch(sql_stmt, 3).with_row_counts().build()?;
402/// batch.append_row_named(&[("stringCol", &"a"), ("intCol", &9)])?; // update 1 row
403/// batch.append_row_named(&[("stringCol", &"b"), ("intCol", &7)])?; // update 3 rows
404/// batch.append_row_named(&[("stringCol", &"c"), ("intCol", &5)])?; // update 5 rows
405/// batch.execute()?;
406/// assert_eq!(batch.row_counts()?, &[1, 3, 5]);
407/// # Ok::<(), Error>(())
408/// ```
409///
410/// # Bind Parameter Types
411///
412/// Parameter types are decided by the value of [`Batch::append_row`], [`Batch::append_row_named`]
413/// or [`Batch::set`]; or by the type specified by [`Batch::set_type`]. Once the
/// type is determined, there is no way to change it except in the following case.
415///
416/// For user's convenience, when the length of character data types is too short,
417/// the length is extended automatically. For example:
418/// ```no_run
419/// # use oracle::Error;
420/// # use oracle::sql_type::OracleType;
421/// # use oracle::test_util;
422/// # let conn = test_util::connect()?;
423/// # let sql_stmt = "dummy";
424/// # let batch_size = 10;
425/// let mut batch = conn.batch(sql_stmt, batch_size).build()?;
426/// batch.append_row(&[&"first row"])?; // allocate 64 bytes for each row
427/// batch.append_row(&[&"second row"])?;
428/// //....
429/// // The following line extends the internal buffer length for each row.
430/// batch.append_row(&[&"assume that data length is over 64 bytes"])?;
431/// # Ok::<(), Error>(())
432/// ```
/// Note that extending the internal buffer requires copying data from the existing buffer
/// to a newly allocated one. If you know the maximum data length, it is better
435/// to set the size by [`Batch::set_type`].
436pub struct Batch<'conn> {
437    pub(crate) conn: &'conn Connection,
438    handle: *mut dpiStmt,
439    statement_type: StatementType,
440    bind_count: usize,
441    bind_names: Vec<String>,
442    bind_values: Vec<SqlValue<'conn>>,
443    bind_types: Vec<Option<BindType>>,
444    batch_index: u32,
445    batch_size: u32,
446    with_batch_errors: bool,
447    with_row_counts: bool,
448    query_params: QueryParams,
449}
450
451impl Batch<'_> {
452    /// Closes the batch before the end of its lifetime.
453    pub fn close(&mut self) -> Result<()> {
454        chkerr!(self.conn.ctxt(), dpiStmt_close(self.handle, ptr::null(), 0));
455        Ok(())
456    }
457
458    pub fn append_row(&mut self, params: &[&dyn ToSql]) -> Result<()> {
459        self.check_batch_index()?;
460        for (i, param) in params.iter().enumerate() {
461            self.bind_internal(i + 1, *param)?;
462        }
463        self.append_row_common()
464    }
465
466    pub fn append_row_named(&mut self, params: &[(&str, &dyn ToSql)]) -> Result<()> {
467        self.check_batch_index()?;
468        for param in params {
469            self.bind_internal(param.0, param.1)?;
470        }
471        self.append_row_common()
472    }
473
474    fn append_row_common(&mut self) -> Result<()> {
475        if self.with_batch_errors {
476            self.set_batch_index(self.batch_index + 1);
477        } else {
478            self.set_batch_index(self.batch_index + 1);
479            if self.batch_index == self.batch_size {
480                self.execute()?;
481            }
482        }
483        Ok(())
484    }
485
486    pub fn execute(&mut self) -> Result<()> {
487        let result = self.execute_sub();
488        // reset all values to null regardless of the result
489        let num_rows = self.batch_index;
490        self.batch_index = 0;
491        for bind_value in &mut self.bind_values {
492            for i in 0..num_rows {
493                bind_value.buffer_row_index = BufferRowIndex::Owned(i);
494                bind_value.set_null()?;
495            }
496            bind_value.buffer_row_index = BufferRowIndex::Owned(0);
497        }
498        result
499    }
500
501    fn execute_sub(&mut self) -> Result<()> {
502        if self.batch_index == 0 {
503            return Ok(());
504        }
505        let mut exec_mode = DPI_MODE_EXEC_DEFAULT;
506        if self.conn.autocommit() {
507            exec_mode |= DPI_MODE_EXEC_COMMIT_ON_SUCCESS;
508        }
509        if self.with_batch_errors {
510            exec_mode |= DPI_MODE_EXEC_BATCH_ERRORS;
511        }
512        if self.with_row_counts {
513            exec_mode |= DPI_MODE_EXEC_ARRAY_DML_ROWCOUNTS;
514        }
515        chkerr!(
516            self.conn.ctxt(),
517            dpiStmt_executeMany(self.handle, exec_mode, self.batch_index)
518        );
519        self.conn.ctxt().set_warning();
520        if self.with_batch_errors {
521            let mut errnum = 0;
522            chkerr!(
523                self.conn.ctxt(),
524                dpiStmt_getBatchErrorCount(self.handle, &mut errnum)
525            );
526            if errnum != 0 {
527                let mut errs = Vec::with_capacity(errnum as usize);
528                chkerr!(
529                    self.conn.ctxt(),
530                    dpiStmt_getBatchErrors(self.handle, errnum, errs.as_mut_ptr())
531                );
532                unsafe { errs.set_len(errnum as usize) };
533                return Err(Error::make_batch_errors(
534                    errs.iter().map(DbError::from_dpi_error).collect(),
535                ));
536            }
537        }
538        Ok(())
539    }
540
541    /// Returns the number of bind parameters
542    ///
543    /// ```
544    /// # use oracle::Error;
545    /// # use oracle::test_util;
546    /// # let conn = test_util::connect()?;
547    /// # conn.execute("delete from TestTempTable", &[])?;
548    /// let sql_stmt = "insert into TestTempTable values(:intCol, :stringCol)";
549    /// let mut batch = conn.batch(sql_stmt, 100).build()?;
550    /// assert_eq!(batch.bind_count(), 2);
551    /// # Ok::<(), Error>(())
552    /// ```
553    pub fn bind_count(&self) -> usize {
554        self.bind_count
555    }
556
557    /// Returns an array of bind parameter names
558    ///
559    /// ```
560    /// # use oracle::Error;
561    /// # use oracle::test_util;
562    /// # let conn = test_util::connect()?;
563    /// # conn.execute("delete from TestTempTable", &[])?;
564    /// let sql_stmt = "insert into TestTempTable values(:intCol, :stringCol)";
565    /// let batch = conn.batch(sql_stmt, 100).build()?;
566    /// assert_eq!(batch.bind_names(), &["INTCOL", "STRINGCOL"]);
567    /// # Ok::<(), Error>(())
568    /// ```
569    pub fn bind_names(&self) -> Vec<&str> {
570        self.bind_names.iter().map(|name| name.as_str()).collect()
571    }
572
573    fn check_batch_index(&self) -> Result<()> {
574        if self.batch_index < self.batch_size {
575            Ok(())
576        } else {
577            Err(Error::out_of_range(format!(
578                "over the max batch size {}",
579                self.batch_size
580            )))
581        }
582    }
583
584    /// Set the data type of a bind parameter
585    ///
586    /// ```
587    /// # use oracle::Error;
588    /// # use oracle::test_util;
589    /// # use oracle::sql_type::OracleType;
590    /// # let conn = test_util::connect()?;
591    /// # conn.execute("delete from TestTempTable", &[])?;
592    /// let sql_stmt = "insert into TestTempTable values(:intCol, :stringCol)";
593    /// let mut batch = conn.batch(sql_stmt, 100).build()?;
594    /// batch.set_type(1, &OracleType::Int64)?;
595    /// batch.set_type(2, &OracleType::Varchar2(10))?;
596    /// # Ok::<(), Error>(())
597    /// ```
598    pub fn set_type<I>(&mut self, bindidx: I, oratype: &OracleType) -> Result<()>
599    where
600        I: BatchBindIndex,
601    {
602        let pos = bindidx.idx(self)?;
603        if self.bind_types[pos].is_some() {
604            return Err(Error::invalid_operation(format!(
605                "type at {} has set already",
606                bindidx
607            )));
608        }
609        self.bind_values[pos].init_handle(oratype)?;
610        chkerr!(
611            self.conn.ctxt(),
612            bindidx.bind(self.handle, self.bind_values[pos].handle()?)
613        );
614        self.bind_types[pos] = Some(BindType::new(oratype));
615        Ok(())
616    }
617
618    /// Set a parameter value
619    ///
620    /// ```
621    /// # use oracle::Error;
622    /// # use oracle::test_util;
623    /// # let conn = test_util::connect()?;
624    /// # conn.execute("delete from TestTempTable", &[])?;
625    /// let sql_stmt = "insert into TestTempTable values(:intCol, :stringCol)";
626    /// let mut batch = conn.batch(sql_stmt, 100).build()?;
627    /// // The below three lines are same with `batch.append_row(&[&100, &"hundred"])?`.
628    /// batch.set(1, &100)?; // set by position 1
629    /// batch.set(2, &"hundred")?; // set at position 2
630    /// batch.append_row(&[])?;
631    /// // The below three lines are same with `batch.append_row(&[("intCol", &101), ("stringCol", &"hundred one")])?`
632    /// batch.set("intCol", &101)?; // set by name "intCol"
633    /// batch.set("stringCol", &"hundred one")?; // set by name "stringCol"
634    /// batch.append_row(&[])?;
635    /// batch.execute()?;
636    /// let sql_stmt = "select * from TestTempTable where intCol = :1";
637    /// assert_eq!(conn.query_row_as::<(i32, String)>(sql_stmt, &[&100])?, (100, "hundred".to_string()));
638    /// assert_eq!(conn.query_row_as::<(i32, String)>(sql_stmt, &[&101])?, (101, "hundred one".to_string()));
639    /// # Ok::<(), Error>(())
640    /// ```
641    pub fn set<I>(&mut self, index: I, value: &dyn ToSql) -> Result<()>
642    where
643        I: BatchBindIndex,
644    {
645        self.check_batch_index()?;
646        self.bind_internal(index, value)
647    }
648
649    fn bind_internal<I>(&mut self, bindidx: I, value: &dyn ToSql) -> Result<()>
650    where
651        I: BatchBindIndex,
652    {
653        let pos = bindidx.idx(self)?;
654        if self.bind_types[pos].is_none() {
655            // When the parameter type has not bee specified yet,
656            // assume the type from the value
657            let oratype = value.oratype(self.conn)?;
658            let bind_type = BindType::new(&oratype);
659            self.bind_values[pos].init_handle(bind_type.as_oratype().unwrap_or(&oratype))?;
660            chkerr!(
661                self.conn.ctxt(),
662                bindidx.bind(self.handle, self.bind_values[pos].handle()?)
663            );
664            self.bind_types[pos] = Some(bind_type);
665        }
666        match self.bind_values[pos].set(value) {
667            Err(err) if err.dpi_code() == Some(DPI_ERR_BUFFER_SIZE_TOO_SMALL) => {
668                let bind_type = self.bind_types[pos].as_mut().unwrap();
669                if bind_type.as_oratype().is_none() {
670                    return Err(err);
671                }
672                let new_oratype = value.oratype(self.conn)?;
673                let new_size = oratype_size(&new_oratype).ok_or(err)?;
674                bind_type.reset_size(new_size);
675                // allocate new bind handle.
676                let mut new_sql_value = SqlValue::for_bind(
677                    self.conn.conn.clone(),
678                    self.query_params.clone(),
679                    self.batch_size,
680                );
681                new_sql_value.init_handle(bind_type.as_oratype().unwrap())?;
682                // copy values in old to new.
683                for idx in 0..self.batch_index {
684                    chkerr!(
685                        self.conn.ctxt(),
686                        dpiVar_copyData(
687                            new_sql_value.handle()?,
688                            idx,
689                            self.bind_values[pos].handle()?,
690                            idx
691                        )
692                    );
693                }
694                new_sql_value.buffer_row_index = BufferRowIndex::Owned(self.batch_index);
695                new_sql_value.set(value)?;
696                chkerr!(
697                    self.conn.ctxt(),
698                    bindidx.bind(self.handle, new_sql_value.handle()?)
699                );
700                self.bind_values[pos] = new_sql_value;
701                Ok(())
702            }
703            x => x,
704        }
705    }
706
707    fn set_batch_index(&mut self, batch_index: u32) {
708        self.batch_index = batch_index;
709        for bind_value in &mut self.bind_values {
710            bind_value.buffer_row_index = BufferRowIndex::Owned(batch_index);
711        }
712    }
713
714    /// Returns the number of affected rows
715    ///
716    /// See ["Affected Rows"](Batch#affected-rows)
717    pub fn row_counts(&self) -> Result<Vec<u64>> {
718        let mut num_row_counts = 0;
719        let mut row_counts = ptr::null_mut();
720        chkerr!(
721            self.conn.ctxt(),
722            dpiStmt_getRowCounts(self.handle, &mut num_row_counts, &mut row_counts)
723        );
724        Ok(unsafe { slice::from_raw_parts(row_counts, num_row_counts as usize) }.to_vec())
725    }
726
727    /// Returns statement type
728    pub fn statement_type(&self) -> StatementType {
729        self.statement_type
730    }
731
732    /// Returns true when the SQL statement is a PL/SQL block.
733    pub fn is_plsql(&self) -> bool {
734        matches!(
735            self.statement_type,
736            StatementType::Begin | StatementType::Declare | StatementType::Call
737        )
738    }
739
740    /// Returns true when the SQL statement is DML (data manipulation language).
741    pub fn is_dml(&self) -> bool {
742        matches!(
743            self.statement_type,
744            StatementType::Insert
745                | StatementType::Update
746                | StatementType::Delete
747                | StatementType::Merge
748        )
749    }
750}
751
752impl Drop for Batch<'_> {
753    fn drop(&mut self) {
754        unsafe { dpiStmt_release(self.handle) };
755    }
756}
757
758/// A trait implemented by types that can index into bind values of a batch
759///
760/// This trait is sealed and cannot be implemented for types outside of the `oracle` crate.
761pub trait BatchBindIndex: private::Sealed + fmt::Display {
762    /// Returns the index of the bind value specified by `self`.
763    #[doc(hidden)]
764    fn idx(&self, batch: &Batch) -> Result<usize>;
765    /// Binds the specified value by using a private method.
766    #[doc(hidden)]
767    unsafe fn bind(&self, stmt_handle: *mut dpiStmt, var_handle: *mut dpiVar) -> i32;
768}
769
770impl BatchBindIndex for usize {
771    #[doc(hidden)]
772    fn idx(&self, batch: &Batch) -> Result<usize> {
773        let num = batch.bind_count();
774        if 0 < num && *self <= num {
775            Ok(*self - 1)
776        } else {
777            Err(Error::invalid_bind_index(*self))
778        }
779    }
780
781    #[doc(hidden)]
782    unsafe fn bind(&self, stmt_handle: *mut dpiStmt, var_handle: *mut dpiVar) -> i32 {
783        dpiStmt_bindByPos(stmt_handle, *self as u32, var_handle)
784    }
785}
786
787impl BatchBindIndex for &str {
788    #[doc(hidden)]
789    fn idx(&self, batch: &Batch) -> Result<usize> {
790        let bindname = self.to_uppercase();
791        batch
792            .bind_names()
793            .iter()
794            .position(|&name| name == bindname)
795            .ok_or_else(|| Error::invalid_bind_name(*self))
796    }
797
798    #[doc(hidden)]
799    unsafe fn bind(&self, stmt_handle: *mut dpiStmt, var_handle: *mut dpiVar) -> i32 {
800        let s = OdpiStr::new(self);
801        dpiStmt_bindByName(stmt_handle, s.ptr, s.len, var_handle)
802    }
803}
804
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util;
    use crate::ErrorKind;

    /// A row to insert into `TestTempTable` together with its expected outcome.
    #[derive(Debug)]
    struct TestData {
        int_val: i32,
        string_val: &'static str,
        // `Some(code)` when inserting this row is expected to fail with that ORA error code.
        error_code: Option<i32>,
    }

    impl TestData {
        const fn new(int_val: i32, string_val: &'static str, error_code: Option<i32>) -> TestData {
            TestData {
                int_val,
                string_val,
                error_code,
            }
        }
    }

    // ORA-00001: unique constraint violated
    const ERROR_UNIQUE_INDEX_VIOLATION: Option<i32> = Some(1);

    // ORA-12899: value too large for column
    const ERROR_TOO_LARGE_VALUE: Option<i32> = Some(12899);

    // Rows flagged with an error either reuse the `int_val` of an earlier row
    // (unique violation) or carry a string too long for the string column.
    const TEST_DATA: [TestData; 10] = [
        TestData::new(0, "0", None),
        TestData::new(1, "1111", None),
        TestData::new(2, "222222222222", None),
        TestData::new(3, "3333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333", None),
        TestData::new(4, "44444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444", ERROR_TOO_LARGE_VALUE),
        TestData::new(1, "55555555555555", ERROR_UNIQUE_INDEX_VIOLATION),
        TestData::new(6, "66666666666", None),
        TestData::new(2, "7", ERROR_UNIQUE_INDEX_VIOLATION),
        TestData::new(8, "8", None),
        TestData::new(3, "9999999999999999999999999", ERROR_UNIQUE_INDEX_VIOLATION),
    ];

    /// Appends every row in `rows` to `batch`, then executes the batch.
    fn append_rows_then_execute(batch: &mut Batch, rows: &[&TestData]) -> Result<()> {
        for row in rows {
            batch.append_row(&[&row.int_val, &row.string_val])?;
        }
        batch.execute()?;
        Ok(())
    }

    /// Asserts that `TestTempTable` contains exactly `expected_rows`, in any order.
    fn check_rows_inserted(conn: &Connection, expected_rows: &[&TestData]) -> Result<()> {
        let mut rows =
            conn.query_as::<(i32, String)>("select * from TestTempTable order by intCol", &[])?;
        // The query is ordered by intCol, so sort the expectations the same way.
        let mut expected_rows = expected_rows.to_vec();
        expected_rows.sort_by_key(|data| data.int_val);
        for expected_row in expected_rows {
            let (int_val, string_val) = match rows.next() {
                Some(row) => row?,
                None => panic!("missing row with intCol = {}", expected_row.int_val),
            };
            assert_eq!(int_val, expected_row.int_val);
            assert_eq!(string_val, expected_row.string_val);
        }
        assert!(rows.next().is_none(), "TestTempTable has unexpected extra rows");
        Ok(())
    }

    /// Inserts only the rows that are expected to succeed, in a single execution.
    #[test]
    fn batch_insert() {
        let conn = test_util::connect().unwrap();
        let rows: Vec<&TestData> = TEST_DATA
            .iter()
            .filter(|data| data.error_code.is_none())
            .collect();
        let mut batch = conn
            .batch("insert into TestTempTable values(:1, :2)", rows.len())
            .build()
            .unwrap();
        append_rows_then_execute(&mut batch, &rows).unwrap();
        check_rows_inserted(&conn, &rows).unwrap();
    }

    /// Executes the same batch twice, with half of the successful rows each time.
    #[test]
    fn batch_execute_twice() {
        let conn = test_util::connect().unwrap();
        let rows_total: Vec<&TestData> = TEST_DATA
            .iter()
            .filter(|data| data.error_code.is_none())
            .collect();
        let (rows_first, rows_second) = rows_total.split_at(rows_total.len() / 2);
        let mut batch = conn
            .batch("insert into TestTempTable values(:1, :2)", rows_first.len())
            .build()
            .unwrap();
        append_rows_then_execute(&mut batch, rows_first).unwrap();
        append_rows_then_execute(&mut batch, rows_second).unwrap();
        check_rows_inserted(&conn, &rows_total).unwrap();
    }

    /// Without batch errors enabled, execution stops at the first failing row.
    /// The rows before it stay inserted.
    #[test]
    fn batch_with_error() {
        let conn = test_util::connect().unwrap();
        let rows: Vec<&TestData> = TEST_DATA.iter().collect();
        let expected_rows: Vec<&TestData> = TEST_DATA
            .iter()
            .take_while(|data| data.error_code.is_none())
            .collect();
        let expected_code = TEST_DATA.iter().find_map(|data| data.error_code);
        let mut batch = conn
            .batch("insert into TestTempTable values(:1, :2)", rows.len())
            .build()
            .unwrap();
        match append_rows_then_execute(&mut batch, &rows) {
            Err(err) if err.kind() == ErrorKind::OciError => {
                assert_eq!(err.oci_code(), expected_code);
            }
            x => {
                panic!("got {:?}", x);
            }
        }
        check_rows_inserted(&conn, &expected_rows).unwrap();
    }

    /// With batch errors enabled, every failing row is reported and all other rows
    /// are inserted.
    #[test]
    fn batch_with_batch_errors() {
        let conn = test_util::connect().unwrap();
        let rows: Vec<&TestData> = TEST_DATA.iter().collect();
        let expected_rows: Vec<&TestData> = TEST_DATA
            .iter()
            .filter(|row| row.error_code.is_none())
            .collect();
        let mut batch = conn
            .batch("insert into TestTempTable values(:1, :2)", rows.len())
            .with_batch_errors()
            .build()
            .unwrap();
        match append_rows_then_execute(&mut batch, &rows) {
            Err(err) if err.batch_errors().is_some() => {
                // Each reported offset should equal the failed row's 0-based
                // position in TEST_DATA. TEST_DATA has 10 rows, so `as u32`
                // cannot truncate.
                let expected_errors: Vec<(u32, i32)> = TEST_DATA
                    .iter()
                    .enumerate()
                    .filter_map(|(offset, data)| data.error_code.map(|code| (offset as u32, code)))
                    .collect();
                let actual_errors: Vec<(u32, i32)> = err
                    .batch_errors()
                    .expect("checked by the match guard")
                    .iter()
                    .map(|dberr| (dberr.offset(), dberr.code()))
                    .collect();
                assert_eq!(expected_errors, actual_errors);
            }
            x => {
                panic!("got {:?}", x);
            }
        }
        check_rows_inserted(&conn, &expected_rows).unwrap();
    }
}